torch_em.data.datasets.medical.segthy

The SegThy dataset contains annotations for thyroid segmentation in MRI and US scans, and additional annotations for vein and artery segmentation in MRI.

NOTE: The label legends are as follows:

  1. For thyroid-only labels (at 'MRI_thyroid' or 'US_thyroid'):
     • background: 0 and thyroid: 1
  2. For thyroid, jugular veins and carotid arteries (at 'MRI_thyroid+jugular+carotid_label'):
     • background: 0, thyroid: 1, jugular vein: 3 and 5, carotid artery: 2 and 4.

The dataset is located at https://www.cs.cit.tum.de/camp/publications/segthy-dataset/.

This dataset is from the publication https://doi.org/10.1371/journal.pone.0268550. Please cite it if you use this dataset in your research.

  1"""The SegThy dataset contains annotations for thyroid segmentation in MRI and US scans,
  2and additional annotations for vein and artery segmentation in MRI.
  3
  4NOTE: The label legends are described as following:
  51: For thyroid-only labels: (at 'MRI_thyroid' or 'US_thyroid')
  6- background: 0 and thyroid: 1
  72: For thyroid, jugular veins and carotid arteries (at 'MRI_thyroid+jugular+carotid_label')
  8- background: 0, thyroid: 1, jugular vein: 3 and 5, carotid artery: 2 and 4.
  9
 10The dataset is located at https://www.cs.cit.tum.de/camp/publications/segthy-dataset/.
 11
 12This dataset is from the publication https://doi.org/10.1371/journal.pone.0268550.
 13Please cite it if you use this dataset in your research.
 14"""
 15
 16import os
 17from glob import glob
 18from natsort import natsorted
 19from typing import Union, Tuple, Literal, List
 20
 21import numpy as np
 22
 23from torch.utils.data import Dataset, DataLoader
 24
 25import torch_em
 26
 27from .. import util
 28
 29
 30URLS = {
 31    "MRI": "https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/MRI_data.zip",
 32    "US": "https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/US_data.zip",
 33}
 34
 35CHECKSUMS = {
 36    "MRI": "e9d0599b305dfe36795c45282a8495d3bfb4a872851c221b321d59ed0b11e7eb",
 37    "US": "52c59ef4db08adfa0e6ea562c7fe747c612f2064e01f907a78b170b02fb459bb",
 38}
 39
 40
 41def get_segthy_data(path: Union[os.PathLike, str], source: Literal['MRI', 'US'], download: bool = False):
 42    """Download the SegThy dataset.
 43
 44    Args:
 45        path: Filepath to a folder where the data is downloaded for further processing.
 46        download: Whether to download the data if it is not present.
 47    """
 48    data_dir = os.path.join(path, f"{source}_volunteer_dataset")
 49    if os.path.exists(data_dir):
 50        return
 51
 52    os.makedirs(path, exist_ok=True)
 53
 54    zip_path = os.path.join(path, f"{source}_data.zip")
 55    util.download_source(path=zip_path, url=URLS[source], download=download, checksum=CHECKSUMS[source])
 56    util.unzip(zip_path=zip_path, dst=path)
 57
 58    # NOTE: There is one label with an empty channel.
 59    if source == "MRI":
 60        lpath = os.path.join(data_dir, "MRI_thyroid_label", "005_MRI_thyroid_label.nii.gz")
 61
 62        import nibabel as nib
 63        # Load the label volume and remove the empty channel.
 64        label = nib.load(lpath).get_fdata()
 65        label = label[..., 0]
 66
 67        # Store the updated label.
 68        label_nifti = nib.Nifti2Image(label, np.eye(4))
 69        nib.save(label_nifti, lpath)
 70
 71
 72def get_segthy_paths(
 73    path: Union[os.PathLike, str],
 74    split: Literal['train', 'val', 'test'],
 75    source: Literal['MRI', 'US'],
 76    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
 77    download: bool = False
 78) -> Tuple[List[str], List[str]]:
 79    """Get paths to the SegThy data.
 80
 81    Args:
 82        path: Filepath to a folder where the data is downloaded for further processing.
 83        split: The choice of data split.
 84        source: The source of dataset. Either 'MRI' or 'US.
 85        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
 86        download: Whether to download the data if it is not present.
 87
 88    Returns:
 89        List of filepaths for the image data.
 90        List of filepaths for the label data.
 91    """
 92    get_segthy_data(path, source, download)
 93
 94    if source == "MRI":
 95        ldir = "MRI_thyroid_label" if region == "thyroid" else "MRI_thyroid+jugular+carotid_label"
 96        label_paths = natsorted(glob(os.path.join(path, f"{source}_volunteer_dataset", ldir, "*.nii.gz")))
 97        raw_paths = [p.replace(ldir, "MRI") for p in label_paths]
 98
 99        if split == "train":
100            raw_paths = raw_paths[:15] if region == "thyroid" else raw_paths[:8]
101            label_paths = label_paths[:15] if region == "thyroid" else label_paths[:8]
102        elif split == "val":
103            raw_paths = raw_paths[15:20] if region == "thyroid" else raw_paths[8:10]
104            label_paths = label_paths[15:20] if region == "thyroid" else label_paths[8:10]
105        elif split == "test":
106            raw_paths = raw_paths[20:] if region == "thyroid" else raw_paths[10:]
107            label_paths = label_paths[20:] if region == "thyroid" else label_paths[10:]
108        else:
109            raise ValueError(f"'{split}' is not a valid split.")
110
111    else:  # US data
112        assert region != "thyroid_and_vessels", "US source does not have labels for both thyroid and vessels."
113        ldir = "ground_truth_data/US_thyroid_label"
114        label_paths = natsorted(glob(os.path.join(path, f"{source}_volunteer_dataset", ldir, "*.nii")))
115
116        raw_paths = [p.replace(ldir, "ground_truth_data/US") for p in label_paths]
117        raw_paths = [p.replace(".nii", "_US.nii") for p in raw_paths]
118
119        if split == "train":
120            raw_paths, label_paths = raw_paths[:20], label_paths[:20]
121        elif split == "val":
122            raw_paths, label_paths = raw_paths[20:25], label_paths[20:25]
123        elif split == "test":
124            raw_paths, label_paths = raw_paths[25:], label_paths[25:]
125        else:
126            raise ValueError(f"'{split}' is not a valid split.")
127
128    return raw_paths, label_paths
129
130
def get_segthy_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the SegThy dataset for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_segthy_paths(
        path=path, split=split, source=source, region=region, download=download
    )

    if resize_inputs:
        # The helper returns the updated dataset kwargs together with the patch shape to use.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are both read with the key "data".
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )
173
174
def get_segthy_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the SegThy dataloader for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    segthy_dataset = get_segthy_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        source=source,
        region=region,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )
    return torch_em.get_data_loader(segthy_dataset, batch_size, **dataloader_kwargs)
# Download locations of the SegThy archives, keyed by imaging source.
URLS = {'MRI': 'https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/MRI_data.zip', 'US': 'https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/US_data.zip'}
# Expected checksums of the archives above, keyed by imaging source.
CHECKSUMS = {'MRI': 'e9d0599b305dfe36795c45282a8495d3bfb4a872851c221b321d59ed0b11e7eb', 'US': '52c59ef4db08adfa0e6ea562c7fe747c612f2064e01f907a78b170b02fb459bb'}
def get_segthy_data( path: Union[os.PathLike, str], source: Literal['MRI', 'US'], download: bool = False):
def get_segthy_data(path: Union[os.PathLike, str], source: Literal['MRI', 'US'], download: bool = False):
    """Download and unpack the SegThy dataset for one imaging source.

    The archive is extracted into `path`. If the folder `<path>/<source>_volunteer_dataset`
    already exists, the function returns immediately and nothing is downloaded.
    For the 'MRI' source, one label volume that carries an extra empty channel is
    rewritten in place right after extraction.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        source: The source of dataset. Either 'MRI' or 'US'.
        download: Whether to download the data if it is not present.
    """
    data_dir = os.path.join(path, f"{source}_volunteer_dataset")
    if os.path.exists(data_dir):
        return

    if source == "MRI":
        # Import up front: if nibabel is missing we must fail *before* extracting the data.
        # Otherwise the extracted folder would exist, the early return above would fire on
        # every later call, and the label fix below would never be applied.
        import nibabel as nib

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, f"{source}_data.zip")
    util.download_source(path=zip_path, url=URLS[source], download=download, checksum=CHECKSUMS[source])
    util.unzip(zip_path=zip_path, dst=path)

    # NOTE: There is one label with an empty channel.
    if source == "MRI":
        lpath = os.path.join(data_dir, "MRI_thyroid_label", "005_MRI_thyroid_label.nii.gz")

        # Load the label volume and remove the empty channel.
        label = nib.load(lpath).get_fdata()

        # Only strip the trailing axis when it is really there; slicing a volume that is
        # already 3D would silently drop a spatial dimension.
        if label.ndim > 3:
            label = label[..., 0]

            # Store the updated label.
            label_nifti = nib.Nifti2Image(label, np.eye(4))
            nib.save(label_nifti, lpath)

Download the SegThy dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • source: The source of dataset. Either 'MRI' or 'US'.
  • download: Whether to download the data if it is not present.
def get_segthy_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], source: Literal['MRI', 'US'], region: Literal['thyroid', 'thyroid_and_vessels'] = 'thyroid', download: bool = False) -> Tuple[List[str], List[str]]:
 73def get_segthy_paths(
 74    path: Union[os.PathLike, str],
 75    split: Literal['train', 'val', 'test'],
 76    source: Literal['MRI', 'US'],
 77    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
 78    download: bool = False
 79) -> Tuple[List[str], List[str]]:
 80    """Get paths to the SegThy data.
 81
 82    Args:
 83        path: Filepath to a folder where the data is downloaded for further processing.
 84        split: The choice of data split.
 85        source: The source of dataset. Either 'MRI' or 'US.
 86        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
 87        download: Whether to download the data if it is not present.
 88
 89    Returns:
 90        List of filepaths for the image data.
 91        List of filepaths for the label data.
 92    """
 93    get_segthy_data(path, source, download)
 94
 95    if source == "MRI":
 96        ldir = "MRI_thyroid_label" if region == "thyroid" else "MRI_thyroid+jugular+carotid_label"
 97        label_paths = natsorted(glob(os.path.join(path, f"{source}_volunteer_dataset", ldir, "*.nii.gz")))
 98        raw_paths = [p.replace(ldir, "MRI") for p in label_paths]
 99
100        if split == "train":
101            raw_paths = raw_paths[:15] if region == "thyroid" else raw_paths[:8]
102            label_paths = label_paths[:15] if region == "thyroid" else label_paths[:8]
103        elif split == "val":
104            raw_paths = raw_paths[15:20] if region == "thyroid" else raw_paths[8:10]
105            label_paths = label_paths[15:20] if region == "thyroid" else label_paths[8:10]
106        elif split == "test":
107            raw_paths = raw_paths[20:] if region == "thyroid" else raw_paths[10:]
108            label_paths = label_paths[20:] if region == "thyroid" else label_paths[10:]
109        else:
110            raise ValueError(f"'{split}' is not a valid split.")
111
112    else:  # US data
113        assert region != "thyroid_and_vessels", "US source does not have labels for both thyroid and vessels."
114        ldir = "ground_truth_data/US_thyroid_label"
115        label_paths = natsorted(glob(os.path.join(path, f"{source}_volunteer_dataset", ldir, "*.nii")))
116
117        raw_paths = [p.replace(ldir, "ground_truth_data/US") for p in label_paths]
118        raw_paths = [p.replace(".nii", "_US.nii") for p in raw_paths]
119
120        if split == "train":
121            raw_paths, label_paths = raw_paths[:20], label_paths[:20]
122        elif split == "val":
123            raw_paths, label_paths = raw_paths[20:25], label_paths[20:25]
124        elif split == "test":
125            raw_paths, label_paths = raw_paths[25:], label_paths[25:]
126        else:
127            raise ValueError(f"'{split}' is not a valid split.")
128
129    return raw_paths, label_paths

Get paths to the SegThy data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • source: The source of dataset. Either 'MRI' or 'US'.
  • region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data. List of filepaths for the label data.

def get_segthy_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], source: Literal['MRI', 'US'], region: Literal['thyroid', 'thyroid_and_vessels'] = 'thyroid', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_segthy_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the SegThy dataset for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_segthy_paths(
        path=path, split=split, source=source, region=region, download=download
    )

    if resize_inputs:
        # The helper returns the updated dataset kwargs together with the patch shape to use.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are both read with the key "data".
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

Get the SegThy dataset for thyroid (and vessel) segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • source: The source of dataset. Either 'MRI' or 'US'.
  • region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_segthy_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], source: Literal['MRI', 'US'], region: Literal['thyroid', 'thyroid_and_vessels'] = 'thyroid', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_segthy_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the SegThy dataloader for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    segthy_dataset = get_segthy_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        source=source,
        region=region,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )
    return torch_em.get_data_loader(segthy_dataset, batch_size, **dataloader_kwargs)

Get the SegThy dataloader for thyroid (and vessel) segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • source: The source of dataset. Either 'MRI' or 'US'.
  • region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.