torch_em.data.datasets.medical.amos
The AMOS dataset contains annotations for abdominal multi-organ segmentation in CT and MRI scans.
This dataset is located at https://doi.org/10.5281/zenodo.7155725. The dataset is from the AMOS 2022 Challenge (https://doi.org/10.48550/arXiv.2206.08023). Please cite them if you use this dataset for your research.
"""The AMOS dataset contains annotations for abdominal multi-organ segmentation in CT and MRI scans.

This dataset is located at https://doi.org/10.5281/zenodo.7155725.
The dataset is from AMOS 2022 Challenge https://doi.org/10.48550/arXiv.2206.08023.
Please cite them if you use this dataset for your research.
"""

import os
import shutil
from glob import glob
from pathlib import Path
from typing import Union, Tuple, Optional, Literal, List

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = "https://zenodo.org/records/7155725/files/amos22.zip"
CHECKSUM = "d2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2"


def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the AMOS dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "amos22")
    if os.path.exists(data_dir):
        # The extracted folder is already there; skip download and extraction.
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "amos22.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    # Clean up the "__MACOSX" folder that is expected next to the extracted data.
    # Only remove it when it exists, so that a missing folder does not turn a
    # successful extraction into a FileNotFoundError.
    macosx_dir = os.path.join(path, "__MACOSX")
    if os.path.isdir(macosx_dir):
        shutil.rmtree(macosx_dir)

    return data_dir


def get_amos_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the AMOS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        modality: The choice of imaging modality. The comparison is case-insensitive.
            If None, the paths of both modalities are returned.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` or `modality` is not one of the supported choices.
    """
    # Validate the arguments first, so that a typo never triggers the download.
    if split == "train":
        im_dir, gt_dir = "imagesTr", "labelsTr"
    elif split == "val":
        im_dir, gt_dir = "imagesVa", "labelsVa"
    elif split == "test":
        im_dir, gt_dir = "imagesTs", "labelsTs"
    else:
        raise ValueError(f"'{split}' is not a valid split.")

    want_ct = None
    if modality is not None:
        modality_key = modality.upper()
        if modality_key not in ("CT", "MRI"):
            raise ValueError(f"'{modality}' is not a valid modality.")
        want_ct = modality_key == "CT"

    data_dir = get_amos_data(path=path, download=download)

    image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz")))
    gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz")))

    if want_ct is None:
        return image_paths, gt_paths

    chosen_image_paths, chosen_gt_paths = [], []
    # Images and labels are paired by position in the two sorted lists.
    for image_path, gt_path in zip(image_paths, gt_paths):
        # Take the id from the file name only: splitting the full path on "."
        # breaks as soon as a parent directory contains a dot (e.g. "./data").
        patient_id = Path(image_path).name.split(".")[0]
        id_value = int(patient_id.split("_")[-1])

        # Ids below 500 are treated as CT, all other ids as MRI.
        is_ct = id_value < 500
        if is_ct == want_ct:
            chosen_image_paths.append(image_path)
            chosen_gt_paths.append(gt_path)

    return chosen_image_paths, chosen_gt_paths


def get_amos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_amos_paths(path, split, modality, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )


def get_amos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_amos_dataset(path, patch_shape, split, modality, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
'https://zenodo.org/records/7155725/files/amos22.zip'
CHECKSUM =
'd2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2'
def
get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the AMOS dataset.

    If the extracted "amos22" folder already exists under `path`, it is returned
    right away and nothing is downloaded or extracted.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "amos22")
    if os.path.exists(data_dir):
        # The extracted folder is already there; skip download and extraction.
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "amos22.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    # Clean up the "__MACOSX" folder that is expected next to the extracted data.
    # Only remove it when it exists, so that a missing folder does not turn a
    # successful extraction into a FileNotFoundError.
    macosx_dir = os.path.join(path, "__MACOSX")
    if os.path.isdir(macosx_dir):
        shutil.rmtree(macosx_dir)

    return data_dir
Download the AMOS dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_amos_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, download: bool = False) -> Tuple[List[str], List[str]]:
def get_amos_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the AMOS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        modality: The choice of imaging modality. The comparison is case-insensitive.
            If None, the paths of both modalities are returned.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` or `modality` is not one of the supported choices.
    """
    # Validate the arguments first, so that a typo never triggers the download.
    if split == "train":
        im_dir, gt_dir = "imagesTr", "labelsTr"
    elif split == "val":
        im_dir, gt_dir = "imagesVa", "labelsVa"
    elif split == "test":
        im_dir, gt_dir = "imagesTs", "labelsTs"
    else:
        raise ValueError(f"'{split}' is not a valid split.")

    want_ct = None
    if modality is not None:
        modality_key = modality.upper()
        if modality_key not in ("CT", "MRI"):
            raise ValueError(f"'{modality}' is not a valid modality.")
        want_ct = modality_key == "CT"

    data_dir = get_amos_data(path=path, download=download)

    image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz")))
    gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz")))

    if want_ct is None:
        return image_paths, gt_paths

    chosen_image_paths, chosen_gt_paths = [], []
    # Images and labels are paired by position in the two sorted lists.
    for image_path, gt_path in zip(image_paths, gt_paths):
        # Take the id from the file name only: splitting the full path on "."
        # breaks as soon as a parent directory contains a dot (e.g. "./data").
        patient_id = Path(image_path).name.split(".")[0]
        id_value = int(patient_id.split("_")[-1])

        # Ids below 500 are treated as CT, all other ids as MRI.
        is_ct = id_value < 500
        if is_ct == want_ct:
            chosen_image_paths.append(image_path)
            chosen_gt_paths.append(gt_path)

    return chosen_image_paths, chosen_gt_paths
Get paths to the AMOS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- modality: The choice of imaging modality.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data. List of filepaths for the label data.
def
get_amos_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_amos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the AMOS segmentation dataset (abdominal multi-organ segmentation in CT and MRI scans).

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_amos_paths(path=path, split=split, modality=modality, download=download)

    if resize_inputs:
        # The helper hands back the updated dataset kwargs together with the
        # patch shape that should be passed on to the dataset.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )
    return dataset
Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- modality: The choice of imaging modality.
- resize_inputs: Whether to resize the inputs.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset
.
Returns:
The segmentation dataset.
def
get_amos_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_amos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the AMOS dataloader (abdominal multi-organ segmentation in CT and MRI scans).

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Split the extra keyword arguments into the ones for the dataset and the ones for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    amos_dataset = get_amos_dataset(
        path, patch_shape, split, modality, resize_inputs, download, **dataset_kwargs
    )

    loader = torch_em.get_data_loader(amos_dataset, batch_size, **dataloader_kwargs)
    return loader
Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- modality: The choice of imaging modality.
- resize_inputs: Whether to resize the inputs.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset
or for the PyTorch DataLoader.
Returns:
The DataLoader.