torch_em.data.datasets.medical.segthy
The SegThy dataset contains annotations for thyroid segmentation in MRI and US scans, and additional annotations for vein and artery segmentation in MRI.
NOTE: The label legends are as follows:
1. Thyroid-only labels (in 'MRI_thyroid_label' or 'US_thyroid_label'): background: 0 and thyroid: 1.
2. Thyroid, jugular vein and carotid artery labels (in 'MRI_thyroid+jugular+carotid_label'): background: 0, thyroid: 1, jugular vein: 3 and 5, carotid artery: 2 and 4.
The dataset is located at https://www.cs.cit.tum.de/camp/publications/segthy-dataset/.
This dataset is from the publication https://doi.org/10.1371/journal.pone.0268550. Please cite it if you use this dataset in your research.
"""The SegThy dataset contains annotations for thyroid segmentation in MRI and US scans,
and additional annotations for vein and artery segmentation in MRI.

NOTE: The label legends are as follows:
1. Thyroid-only labels (in 'MRI_thyroid_label' or 'US_thyroid_label'):
   - background: 0 and thyroid: 1
2. Thyroid, jugular vein and carotid artery labels (in 'MRI_thyroid+jugular+carotid_label'):
   - background: 0, thyroid: 1, jugular vein: 3 and 5, carotid artery: 2 and 4.

The dataset is located at https://www.cs.cit.tum.de/camp/publications/segthy-dataset/.

This dataset is from the publication https://doi.org/10.1371/journal.pone.0268550.
Please cite it if you use this dataset in your research.
"""

import os
from glob import glob
from natsort import natsorted
from typing import Union, Tuple, Literal, List

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URLS = {
    "MRI": "https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/MRI_data.zip",
    "US": "https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/US_data.zip",
}

CHECKSUMS = {
    "MRI": "e9d0599b305dfe36795c45282a8495d3bfb4a872851c221b321d59ed0b11e7eb",
    "US": "52c59ef4db08adfa0e6ea562c7fe747c612f2064e01f907a78b170b02fb459bb",
}


def get_segthy_data(path: Union[os.PathLike, str], source: Literal['MRI', 'US'], download: bool = False):
    """Download the SegThy dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        source: The source of dataset. Either 'MRI' or 'US'.
        download: Whether to download the data if it is not present.
    """
    data_dir = os.path.join(path, f"{source}_volunteer_dataset")
    if os.path.exists(data_dir):
        return

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, f"{source}_data.zip")
    util.download_source(path=zip_path, url=URLS[source], download=download, checksum=CHECKSUMS[source])
    util.unzip(zip_path=zip_path, dst=path)

    # NOTE: There is one MRI label with an additional, empty trailing channel. Keep only the first channel.
    if source == "MRI":
        import nibabel as nib

        lpath = os.path.join(data_dir, "MRI_thyroid_label", "005_MRI_thyroid_label.nii.gz")
        label_image = nib.load(lpath)
        label = label_image.get_fdata()

        # Only drop the channel if it is actually there; slicing a 3D volume would cut off a spatial axis.
        if label.ndim == 4:
            label = label[..., 0]
            # Re-use the original affine so the voxel-to-world mapping is not replaced by an identity matrix.
            nib.save(nib.Nifti2Image(label, label_image.affine), lpath)


def get_segthy_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SegThy data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. Either 'train', 'val' or 'test'.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
            Only 'thyroid' is available for the 'US' source.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` or `region` is not one of the supported choices.
    """
    # Validate the arguments before triggering a (potentially large) download.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")
    if region not in ("thyroid", "thyroid_and_vessels"):
        raise ValueError(f"'{region}' is not a valid region.")
    if source == "US":
        assert region != "thyroid_and_vessels", "US source does not have labels for both thyroid and vessels."

    get_segthy_data(path, source, download)

    data_dir = os.path.join(path, f"{source}_volunteer_dataset")
    if source == "MRI":
        ldir = "MRI_thyroid_label" if region == "thyroid" else "MRI_thyroid+jugular+carotid_label"
        label_paths = natsorted(glob(os.path.join(data_dir, ldir, "*.nii.gz")))
        # The raw volumes live in the 'MRI' folder, named like the labels with `ldir` replaced by 'MRI'.
        # Only the filename is rewritten, so the user-provided root path is never modified.
        raw_dir = os.path.join(data_dir, "MRI")
        raw_paths = [os.path.join(raw_dir, os.path.basename(p).replace(ldir, "MRI")) for p in label_paths]
        n_train, n_val = (15, 5) if region == "thyroid" else (8, 2)

    else:  # US data
        label_paths = natsorted(glob(os.path.join(data_dir, "ground_truth_data", "US_thyroid_label", "*.nii")))
        # The raw volume for '<name>.nii' is stored as '<name>_US.nii' in the 'ground_truth_data/US' folder.
        raw_dir = os.path.join(data_dir, "ground_truth_data", "US")
        raw_paths = [
            os.path.join(raw_dir, os.path.splitext(os.path.basename(p))[0] + "_US.nii") for p in label_paths
        ]
        n_train, n_val = 20, 5

    # Fixed split over the natsorted volumes: the first `n_train` for training,
    # the next `n_val` for validation and all remaining ones for testing.
    split_slice = {
        "train": slice(0, n_train),
        "val": slice(n_train, n_train + n_val),
        "test": slice(n_train + n_val, None),
    }[split]

    return raw_paths[split_slice], label_paths[split_slice]


def get_segthy_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the SegThy dataset for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_segthy_paths(path, split, source, region, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )


def get_segthy_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the SegThy dataloader for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_segthy_dataset(path, patch_shape, split, source, region, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
def get_segthy_data(path: Union[os.PathLike, str], source: Literal['MRI', 'US'], download: bool = False):
    """Download the SegThy dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        source: The source of dataset. Either 'MRI' or 'US'.
        download: Whether to download the data if it is not present.
    """
    data_dir = os.path.join(path, f"{source}_volunteer_dataset")
    if os.path.exists(data_dir):
        return

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, f"{source}_data.zip")
    util.download_source(path=zip_path, url=URLS[source], download=download, checksum=CHECKSUMS[source])
    util.unzip(zip_path=zip_path, dst=path)

    # NOTE: There is one MRI label with an additional, empty trailing channel. Keep only the first channel.
    if source == "MRI":
        import nibabel as nib

        lpath = os.path.join(data_dir, "MRI_thyroid_label", "005_MRI_thyroid_label.nii.gz")
        label_image = nib.load(lpath)
        label = label_image.get_fdata()

        # Only drop the channel if it is actually there; slicing a 3D volume would cut off a spatial axis.
        if label.ndim == 4:
            label = label[..., 0]
            # Re-use the original affine so the voxel-to-world mapping is not replaced by an identity matrix.
            nib.save(nib.Nifti2Image(label, label_image.affine), lpath)
Download the SegThy dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- source: The source of the dataset. Either 'MRI' or 'US'.
- download: Whether to download the data if it is not present.
def get_segthy_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SegThy data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. Either 'train', 'val' or 'test'.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
            Only 'thyroid' is available for the 'US' source.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` or `region` is not one of the supported choices.
    """
    # Validate the arguments before triggering a (potentially large) download.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")
    if region not in ("thyroid", "thyroid_and_vessels"):
        raise ValueError(f"'{region}' is not a valid region.")
    if source == "US":
        assert region != "thyroid_and_vessels", "US source does not have labels for both thyroid and vessels."

    get_segthy_data(path, source, download)

    data_dir = os.path.join(path, f"{source}_volunteer_dataset")
    if source == "MRI":
        ldir = "MRI_thyroid_label" if region == "thyroid" else "MRI_thyroid+jugular+carotid_label"
        label_paths = natsorted(glob(os.path.join(data_dir, ldir, "*.nii.gz")))
        # The raw volumes live in the 'MRI' folder, named like the labels with `ldir` replaced by 'MRI'.
        # Only the filename is rewritten, so the user-provided root path is never modified.
        raw_dir = os.path.join(data_dir, "MRI")
        raw_paths = [os.path.join(raw_dir, os.path.basename(p).replace(ldir, "MRI")) for p in label_paths]
        n_train, n_val = (15, 5) if region == "thyroid" else (8, 2)

    else:  # US data
        label_paths = natsorted(glob(os.path.join(data_dir, "ground_truth_data", "US_thyroid_label", "*.nii")))
        # The raw volume for '<name>.nii' is stored as '<name>_US.nii' in the 'ground_truth_data/US' folder.
        raw_dir = os.path.join(data_dir, "ground_truth_data", "US")
        raw_paths = [
            os.path.join(raw_dir, os.path.splitext(os.path.basename(p))[0] + "_US.nii") for p in label_paths
        ]
        n_train, n_val = 20, 5

    # Fixed split over the natsorted volumes: the first `n_train` for training,
    # the next `n_val` for validation and all remaining ones for testing.
    split_slice = {
        "train": slice(0, n_train),
        "val": slice(n_train, n_train + n_val),
        "test": slice(n_train + n_val, None),
    }[split]

    return raw_paths[split_slice], label_paths[split_slice]
Get paths to the SegThy data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- source: The source of the dataset. Either 'MRI' or 'US'.
- region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data. List of filepaths for the label data.
def get_segthy_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the SegThy dataset for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_segthy_paths(path, split, source, region, download)

    if resize_inputs:
        # The resize transform may update both the transform kwargs and the effective patch shape.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )
Get the SegThy dataset for thyroid (and vessel) segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- source: The source of the dataset. Either 'MRI' or 'US'.
- region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def get_segthy_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    source: Literal['MRI', 'US'],
    region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the SegThy dataloader for thyroid (and vessel) segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        source: The source of dataset. Either 'MRI' or 'US'.
        region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_segthy_dataset(path, patch_shape, split, source, region, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the SegThy dataloader for thyroid (and vessel) segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- source: The source of the dataset. Either 'MRI' or 'US'.
- region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.