torch_em.data.datasets.medical.sega
The SegA dataset contains annotations for aorta segmentation in CT scans.
The dataset is from the publication https://doi.org/10.1007/978-3-031-53241-2. Please cite it if you use this dataset for your research.
"""The SegA dataset contains annotations for aorta segmentation in CT scans.

The dataset is from the publication https://doi.org/10.1007/978-3-031-53241-2.
Please cite it if you use this dataset for your research.
"""

import os
from glob import glob
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, Optional, Literal, List

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


# Download location, SHA256 checksum and archive name per subset, keyed by the lower-cased subset name.
URL = {
    "kits": "https://figshare.com/ndownloader/files/30950821",
    "rider": "https://figshare.com/ndownloader/files/30950914",
    "dongyang": "https://figshare.com/ndownloader/files/30950971"
}

CHECKSUMS = {
    "kits": "6c9c2ea31e5998348acf1c4f6683ae07041bd6c8caf309dd049adc7f222de26e",
    "rider": "7244038a6a4f70ae70b9288a2ce874d32128181de2177c63a7612d9ab3c4f5fa",
    "dongyang": "0187e90038cba0564e6304ef0182969ff57a31b42c5969d2b9188a27219da541"
}

ZIPFILES = {
    "kits": "KiTS.zip",
    "rider": "Rider.zip",
    "dongyang": "Dongyang.zip"
}


def get_sega_data(
    path: Union[os.PathLike, str],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    download: bool = False
) -> str:
    """Download one subset of the SegA dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        data_choice: The subset to fetch: 'KiTS', 'Rider' or 'Dongyang' (matched case-insensitively).
            Exactly one subset must be named; `None` is rejected here.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath of the subset folder, i.e. `path` joined with the archive name without '.zip'.

    Raises:
        ValueError: If `data_choice` is not a string or does not name a known subset.
    """
    # The signature defaults to None, which previously crashed on `.lower()`; fail with a clear message.
    if not isinstance(data_choice, str):
        raise ValueError(f"'data_choice' must be the name of a single SegA subset, got {data_choice!r}.")

    choice = data_choice.lower()
    if choice not in ZIPFILES:
        raise ValueError(f"'{data_choice}' is not a valid 'data_choice'. Choose from {list(ZIPFILES.values())}.")

    zip_fid = ZIPFILES[choice]
    data_dir = os.path.join(path, Path(zip_fid).stem)
    if os.path.exists(data_dir):
        # The subset folder is already there, nothing to download or extract.
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, zip_fid)
    util.download_source(path=zip_path, url=URL[choice], download=download, checksum=CHECKSUMS[choice])
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir


def get_sega_paths(
    path: Union[os.PathLike, str],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SegA data.

    The '.nrrd' volumes of the selected subsets are converted to '.nii.gz' files (identity affine)
    under `<path>/data/images` and `<path>/data/labels`. A volume is skipped when both of its
    converted files already exist.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        data_choice: The choice of dataset. `None` selects all subsets. A single name or
            an iterable of names is accepted.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    if data_choice is None:
        data_choices = list(URL)
    elif isinstance(data_choice, str):
        data_choices = [data_choice]
    else:
        # Previously any non-string value left `data_choices` unbound; accept several subset names.
        data_choices = list(data_choice)

    data_dirs = [get_sega_data(path=path, data_choice=choice, download=download) for choice in data_choices]

    image_paths, gt_paths = [], []
    for data_dir in data_dirs:
        for volume_path in glob(os.path.join(data_dir, "*", "*.nrrd")):
            # Files ending in '.seg.nrrd' are the segmentations, all other '.nrrd' files are images.
            if volume_path.endswith(".seg.nrrd"):
                gt_paths.append(volume_path)
            else:
                image_paths.append(volume_path)

    # now let's wrap the volumes to nifti format
    fimage_dir = os.path.join(path, "data", "images")
    fgt_dir = os.path.join(path, "data", "labels")

    os.makedirs(fimage_dir, exist_ok=True)
    os.makedirs(fgt_dir, exist_ok=True)

    fimage_paths, fgt_paths = [], []
    # NOTE(review): images and labels are paired purely by natural sort order, and `zip` silently
    # drops unmatched trailing entries. Confirm every image has exactly one '.seg.nrrd' partner.
    for image_path, gt_path in zip(natsorted(image_paths), natsorted(gt_paths)):
        # Both outputs share the image's stem, so an image and its label have the same file name.
        fname = f"{Path(image_path).stem}.nii.gz"
        fimage_path = os.path.join(fimage_dir, fname)
        fgt_path = os.path.join(fgt_dir, fname)

        fimage_paths.append(fimage_path)
        fgt_paths.append(fgt_path)

        if os.path.exists(fimage_path) and os.path.exists(fgt_path):
            continue

        # Imported lazily so the conversion dependencies are only needed when a conversion happens.
        import nrrd
        import numpy as np
        import nibabel as nib

        image = nrrd.read(image_path)[0]
        gt = nrrd.read(gt_path)[0]

        image_nifti = nib.Nifti2Image(image, np.eye(4))
        gt_nifti = nib.Nifti2Image(gt, np.eye(4))

        nib.save(image_nifti, fimage_path)
        nib.save(gt_nifti, fgt_path)

    return natsorted(fimage_paths), natsorted(fgt_paths)


def get_sega_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the SegA dataset for segmentation of aorta in computed tomography angiography (CTA) scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_sega_paths(path, data_choice, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs,
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )


def get_sega_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the SegA dataloader for segmentation of aorta in computed tomography angiography (CTA) scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each extra keyword either to the dataset constructor or to the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_sega_dataset(path, patch_shape, data_choice, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
{'kits': 'https://figshare.com/ndownloader/files/30950821', 'rider': 'https://figshare.com/ndownloader/files/30950914', 'dongyang': 'https://figshare.com/ndownloader/files/30950971'}
CHECKSUMS =
{'kits': '6c9c2ea31e5998348acf1c4f6683ae07041bd6c8caf309dd049adc7f222de26e', 'rider': '7244038a6a4f70ae70b9288a2ce874d32128181de2177c63a7612d9ab3c4f5fa', 'dongyang': '0187e90038cba0564e6304ef0182969ff57a31b42c5969d2b9188a27219da541'}
ZIPFILES =
{'kits': 'KiTS.zip', 'rider': 'Rider.zip', 'dongyang': 'Dongyang.zip'}
def
get_sega_data( path: Union[os.PathLike, str], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, download: bool = False) -> str:
def get_sega_data(
    path: Union[os.PathLike, str],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    download: bool = False
) -> str:
    """Download one subset of the SegA dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        data_choice: The subset to fetch: 'KiTS', 'Rider' or 'Dongyang' (matched case-insensitively).
            Exactly one subset must be named; `None` is rejected here.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath of the subset folder, i.e. `path` joined with the archive name without '.zip'.

    Raises:
        ValueError: If `data_choice` is not a string or does not name a known subset.
    """
    # The signature defaults to None, which previously crashed on `.lower()`; fail with a clear message.
    if not isinstance(data_choice, str):
        raise ValueError(f"'data_choice' must be the name of a single SegA subset, got {data_choice!r}.")

    choice = data_choice.lower()
    if choice not in ZIPFILES:
        raise ValueError(f"'{data_choice}' is not a valid 'data_choice'. Choose from {list(ZIPFILES.values())}.")

    zip_fid = ZIPFILES[choice]
    data_dir = os.path.join(path, Path(zip_fid).stem)
    if os.path.exists(data_dir):
        # The subset folder is already there, nothing to download or extract.
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, zip_fid)
    util.download_source(path=zip_path, url=URL[choice], download=download, checksum=CHECKSUMS[choice])
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir
Download the SegA dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- data_choice: The choice of dataset.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_sega_paths( path: Union[os.PathLike, str], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, download: bool = False) -> Tuple[List[str], List[str]]:
def get_sega_paths(
    path: Union[os.PathLike, str],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SegA data.

    The '.nrrd' volumes of the selected subsets are converted to '.nii.gz' files (identity affine)
    under `<path>/data/images` and `<path>/data/labels`. A volume is skipped when both of its
    converted files already exist.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        data_choice: The choice of dataset. `None` selects all subsets. A single name or
            an iterable of names is accepted.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    if data_choice is None:
        data_choices = list(URL)
    elif isinstance(data_choice, str):
        data_choices = [data_choice]
    else:
        # Previously any non-string value left `data_choices` unbound; accept several subset names.
        data_choices = list(data_choice)

    data_dirs = [get_sega_data(path=path, data_choice=choice, download=download) for choice in data_choices]

    image_paths, gt_paths = [], []
    for data_dir in data_dirs:
        for volume_path in glob(os.path.join(data_dir, "*", "*.nrrd")):
            # Files ending in '.seg.nrrd' are the segmentations, all other '.nrrd' files are images.
            if volume_path.endswith(".seg.nrrd"):
                gt_paths.append(volume_path)
            else:
                image_paths.append(volume_path)

    # now let's wrap the volumes to nifti format
    fimage_dir = os.path.join(path, "data", "images")
    fgt_dir = os.path.join(path, "data", "labels")

    os.makedirs(fimage_dir, exist_ok=True)
    os.makedirs(fgt_dir, exist_ok=True)

    fimage_paths, fgt_paths = [], []
    # NOTE(review): images and labels are paired purely by natural sort order, and `zip` silently
    # drops unmatched trailing entries. Confirm every image has exactly one '.seg.nrrd' partner.
    for image_path, gt_path in zip(natsorted(image_paths), natsorted(gt_paths)):
        # Both outputs share the image's stem, so an image and its label have the same file name.
        fname = f"{Path(image_path).stem}.nii.gz"
        fimage_path = os.path.join(fimage_dir, fname)
        fgt_path = os.path.join(fgt_dir, fname)

        fimage_paths.append(fimage_path)
        fgt_paths.append(fgt_path)

        if os.path.exists(fimage_path) and os.path.exists(fgt_path):
            continue

        # Imported lazily so the conversion dependencies are only needed when a conversion happens.
        import nrrd
        import numpy as np
        import nibabel as nib

        image = nrrd.read(image_path)[0]
        gt = nrrd.read(gt_path)[0]

        image_nifti = nib.Nifti2Image(image, np.eye(4))
        gt_nifti = nib.Nifti2Image(gt, np.eye(4))

        nib.save(image_nifti, fimage_path)
        nib.save(gt_nifti, fgt_path)

    return natsorted(fimage_paths), natsorted(fgt_paths)
Get paths to the SegA data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- data_choice: The choice of dataset.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data. List of filepaths for the label data.
def
get_sega_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_sega_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the SegA segmentation dataset (aorta in computed tomography angiography scans).

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_sega_paths(path, data_choice, download)

    if resize_inputs:
        # The helper returns the updated kwargs together with the patch shape to use from here on.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    fixed_kwargs = dict(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
    )
    return torch_em.default_segmentation_dataset(**fixed_kwargs, **kwargs)
Get the SegA dataset for segmentation of aorta in computed tomography angiography (CTA) scans.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- data_choice: The choice of dataset.
- resize_inputs: Whether to resize the inputs to the patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_sega_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_sega_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the SegA dataloader (aorta segmentation in computed tomography angiography scans).

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each extra keyword either to the dataset constructor or to the loader.
    split = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset_kwargs, loader_kwargs = split

    sega_dataset = get_sega_dataset(
        path,
        patch_shape,
        data_choice,
        resize_inputs,
        download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(sega_dataset, batch_size, **loader_kwargs)
Get the SegA dataloader for segmentation of aorta in computed tomography angiography (CTA) scans.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- data_choice: The choice of dataset.
- resize_inputs: Whether to resize the inputs to the patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.