torch_em.data.datasets.medical.amos

The AMOS dataset contains annotations for abdominal multi-organ segmentation in CT and MRI scans.

This dataset is located at https://doi.org/10.5281/zenodo.7155725. The dataset is from the AMOS 2022 Challenge (https://doi.org/10.48550/arXiv.2206.08023). Please cite them if you use this dataset for your research.

  1"""The AMOS dataset contains annotations for abdominal multi-organ segmentation in CT and MRI scans.
  2
  3This dataset is located at https://doi.org/10.5281/zenodo.7155725.
  4The dataset is from AMOS 2022 Challenge https://doi.org/10.48550/arXiv.2206.08023.
  5Please cite them if you use this dataset for your research.
  6"""
  7
  8import os
  9import shutil
 10from glob import glob
 11from pathlib import Path
 12from typing import Union, Tuple, Optional, Literal, List
 13
 14from torch.utils.data import Dataset, DataLoader
 15
 16import torch_em
 17
 18from .. import util
 19
 20
 21URL = "https://zenodo.org/records/7155725/files/amos22.zip"
 22CHECKSUM = "d2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2"
 23
 24
 25def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 26    """Download the AMOS dataset.
 27
 28    Args:
 29        path: Filepath to a folder where the data is downloaded for further processing.
 30        download: Whether to download the data if it is not present.
 31
 32    Returns:
 33        Filepath where the data is downloaded.
 34    """
 35    data_dir = os.path.join(path, "amos22")
 36    if os.path.exists(data_dir):
 37        return data_dir
 38
 39    os.makedirs(path, exist_ok=True)
 40
 41    zip_path = os.path.join(path, "amos22.zip")
 42    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 43    util.unzip(zip_path=zip_path, dst=path)
 44
 45    shutil.rmtree(os.path.join(path, "__MACOSX"))
 46
 47    return data_dir
 48
 49
 50def get_amos_paths(
 51    path: Union[os.PathLike, str],
 52    split: Literal['train', 'val', 'test'],
 53    modality: Optional[Literal['CT', 'MRI']] = None,
 54    download: bool = False
 55) -> Tuple[List[str], List[str]]:
 56    """Get paths to the AMOS data.
 57
 58    Args:
 59        path: Filepath to a folder where the data is downloaded for further processing.
 60        split: The choice of data split.
 61        modality: The choice of imaging modality.
 62        download: Whether to download the data if it is not present.
 63
 64    Returns:
 65        List of filepaths for the image data.
 66        List of filepaths for the label data.
 67    """
 68    data_dir = get_amos_data(path=path, download=download)
 69
 70    if split == "train":
 71        im_dir, gt_dir = "imagesTr", "labelsTr"
 72    elif split == "val":
 73        im_dir, gt_dir = "imagesVa", "labelsVa"
 74    elif split == "test":
 75        im_dir, gt_dir = "imagesTs", "labelsTs"
 76    else:
 77        raise ValueError(f"'{split}' is not a valid split.")
 78
 79    image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz")))
 80    gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz")))
 81
 82    if modality is None:
 83        chosen_image_paths, chosen_gt_paths = image_paths, gt_paths
 84    else:
 85        ct_image_paths, ct_gt_paths = [], []
 86        mri_image_paths, mri_gt_paths = [], []
 87        for image_path, gt_path in zip(image_paths, gt_paths):
 88            patient_id = Path(image_path.split(".")[0]).stem
 89            id_value = int(patient_id.split("_")[-1])
 90
 91            is_ct = id_value < 500
 92
 93            if is_ct:
 94                ct_image_paths.append(image_path)
 95                ct_gt_paths.append(gt_path)
 96            else:
 97                mri_image_paths.append(image_path)
 98                mri_gt_paths.append(gt_path)
 99
100        if modality.upper() == "CT":
101            chosen_image_paths, chosen_gt_paths = ct_image_paths, ct_gt_paths
102        elif modality.upper() == "MRI":
103            chosen_image_paths, chosen_gt_paths = mri_image_paths, mri_gt_paths
104        else:
105            raise ValueError(f"'{modality}' is not a valid modality.")
106
107    return chosen_image_paths, chosen_gt_paths
108
109
def get_amos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_amos_paths(path, split, modality, download)

    if resize_inputs:
        # The helper returns updated dataset kwargs together with the patch shape to use afterwards,
        # so both are rebound here before building the dataset.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )
150
151
def get_amos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The number of samples per batch.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments, routed either to `torch_em.default_segmentation_dataset`
            or to the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Split the extra keyword arguments: those accepted by the dataset factory go to the dataset,
    # everything else is forwarded to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    amos_dataset = get_amos_dataset(
        path, patch_shape, split, modality, resize_inputs, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(amos_dataset, batch_size, **dataloader_kwargs)
# Zenodo download location of the AMOS 2022 release archive.
URL = "https://zenodo.org/records/7155725/files/amos22.zip"
# Expected checksum of the downloaded archive.
CHECKSUM = "d2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2"
def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the AMOS dataset.

    If the extracted "amos22" folder already exists inside `path`, it is returned directly
    and nothing is downloaded.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "amos22")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "amos22.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    # Zip archives created on macOS can contain a "__MACOSX" metadata folder; remove it if it
    # was extracted. Its absence must not turn a successful download into a FileNotFoundError.
    macosx_dir = os.path.join(path, "__MACOSX")
    if os.path.isdir(macosx_dir):
        shutil.rmtree(macosx_dir)

    return data_dir

Download the AMOS dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_amos_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, download: bool = False) -> Tuple[List[str], List[str]]:
 51def get_amos_paths(
 52    path: Union[os.PathLike, str],
 53    split: Literal['train', 'val', 'test'],
 54    modality: Optional[Literal['CT', 'MRI']] = None,
 55    download: bool = False
 56) -> Tuple[List[str], List[str]]:
 57    """Get paths to the AMOS data.
 58
 59    Args:
 60        path: Filepath to a folder where the data is downloaded for further processing.
 61        split: The choice of data split.
 62        modality: The choice of imaging modality.
 63        download: Whether to download the data if it is not present.
 64
 65    Returns:
 66        List of filepaths for the image data.
 67        List of filepaths for the label data.
 68    """
 69    data_dir = get_amos_data(path=path, download=download)
 70
 71    if split == "train":
 72        im_dir, gt_dir = "imagesTr", "labelsTr"
 73    elif split == "val":
 74        im_dir, gt_dir = "imagesVa", "labelsVa"
 75    elif split == "test":
 76        im_dir, gt_dir = "imagesTs", "labelsTs"
 77    else:
 78        raise ValueError(f"'{split}' is not a valid split.")
 79
 80    image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz")))
 81    gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz")))
 82
 83    if modality is None:
 84        chosen_image_paths, chosen_gt_paths = image_paths, gt_paths
 85    else:
 86        ct_image_paths, ct_gt_paths = [], []
 87        mri_image_paths, mri_gt_paths = [], []
 88        for image_path, gt_path in zip(image_paths, gt_paths):
 89            patient_id = Path(image_path.split(".")[0]).stem
 90            id_value = int(patient_id.split("_")[-1])
 91
 92            is_ct = id_value < 500
 93
 94            if is_ct:
 95                ct_image_paths.append(image_path)
 96                ct_gt_paths.append(gt_path)
 97            else:
 98                mri_image_paths.append(image_path)
 99                mri_gt_paths.append(gt_path)
100
101        if modality.upper() == "CT":
102            chosen_image_paths, chosen_gt_paths = ct_image_paths, ct_gt_paths
103        elif modality.upper() == "MRI":
104            chosen_image_paths, chosen_gt_paths = mri_image_paths, mri_gt_paths
105        else:
106            raise ValueError(f"'{modality}' is not a valid modality.")
107
108    return chosen_image_paths, chosen_gt_paths

Get paths to the AMOS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • modality: The choice of imaging modality.
  • download: Whether to download the data if it is not present.
Returns:

A list of filepaths for the image data and a list of filepaths for the label data.

def get_amos_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_amos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_amos_paths(path, split, modality, download)

    if resize_inputs:
        # The helper returns updated dataset kwargs together with the patch shape to use afterwards,
        # so both are rebound here before building the dataset.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • modality: The choice of imaging modality.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_amos_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_amos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The number of samples per batch.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        modality: The choice of imaging modality.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments, routed either to `torch_em.default_segmentation_dataset`
            or to the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Split the extra keyword arguments: those accepted by the dataset factory go to the dataset,
    # everything else is forwarded to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    amos_dataset = get_amos_dataset(
        path, patch_shape, split, modality, resize_inputs, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(amos_dataset, batch_size, **dataloader_kwargs)

Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • modality: The choice of imaging modality.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.