torch_em.data.datasets.medical.amos

The AMOS dataset contains annotations for abdominal multi-organ segmentation in CT and MRI scans.

This dataset is located at https://doi.org/10.5281/zenodo.7155725. The dataset is from the AMOS 2022 Challenge (https://doi.org/10.48550/arXiv.2206.08023). Please cite them if you use this dataset for your research.

  1"""The AMOS dataset contains annotations for abdominal multi-organ segmentation in CT and MRI scans.
  2
  3This dataset is located at https://doi.org/10.5281/zenodo.7155725.
  4The dataset is from AMOS 2022 Challenge https://doi.org/10.48550/arXiv.2206.08023.
  5Please cite them if you use this dataset for your research.
  6"""
  7
  8import os
  9from glob import glob
 10from pathlib import Path
 11from typing import Union, Tuple, Optional, Literal, List
 12
 13from torch.utils.data import Dataset, DataLoader
 14
 15import torch_em
 16
 17from .. import util
 18
 19
 20URL = "https://zenodo.org/records/7155725/files/amos22.zip"
 21CHECKSUM = "d2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2"
 22
 23
 24def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 25    """Download the AMOS dataset.
 26
 27    Args:
 28        path: Filepath to a folder where the data is downloaded for further processing.
 29        download: Whether to download the data if it is not present.
 30
 31    Returns:
 32        Filepath where the data is downloaded.
 33    """
 34    data_dir = os.path.join(path, "amos22")
 35    if os.path.exists(data_dir):
 36        return data_dir
 37
 38    os.makedirs(path, exist_ok=True)
 39
 40    zip_path = os.path.join(path, "amos22.zip")
 41    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 42    util.unzip(zip_path=zip_path, dst=path)
 43
 44    return data_dir
 45
 46
 47def get_amos_paths(
 48    path: Union[os.PathLike, str],
 49    split: Literal['train', 'val', 'test'],
 50    modality: Optional[Literal['CT', 'MRI']] = None,
 51    download: bool = False
 52) -> Tuple[List[str], List[str]]:
 53    """Get paths to the AMOS data.
 54
 55    Args:
 56        path: Filepath to a folder where the data is downloaded for further processing.
 57        split: The choice of data split.
 58        modality: The choice of imaging modality.
 59        download: Whether to download the data if it is not present.
 60
 61    Returns:
 62        List of filepaths for the image data.
 63        List of filepaths for the label data.
 64    """
 65    data_dir = get_amos_data(path=path, download=download)
 66
 67    if split == "train":
 68        im_dir, gt_dir = "imagesTr", "labelsTr"
 69    elif split == "val":
 70        im_dir, gt_dir = "imagesVa", "labelsVa"
 71    elif split == "test":
 72        im_dir, gt_dir = "imagesTs", "labelsTs"
 73    else:
 74        raise ValueError(f"'{split}' is not a valid split.")
 75
 76    image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz")))
 77    gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz")))
 78
 79    if modality is None:
 80        chosen_image_paths, chosen_gt_paths = image_paths, gt_paths
 81    else:
 82        ct_image_paths, ct_gt_paths = [], []
 83        mri_image_paths, mri_gt_paths = [], []
 84        for image_path, gt_path in zip(image_paths, gt_paths):
 85            patient_id = Path(image_path.split(".")[0]).stem
 86            id_value = int(patient_id.split("_")[-1])
 87
 88            is_ct = id_value < 500
 89
 90            if is_ct:
 91                ct_image_paths.append(image_path)
 92                ct_gt_paths.append(gt_path)
 93            else:
 94                mri_image_paths.append(image_path)
 95                mri_gt_paths.append(gt_path)
 96
 97        if modality.upper() == "CT":
 98            chosen_image_paths, chosen_gt_paths = ct_image_paths, ct_gt_paths
 99        elif modality.upper() == "MRI":
100            chosen_image_paths, chosen_gt_paths = mri_image_paths, mri_gt_paths
101        else:
102            raise ValueError(f"'{modality}' is not a valid modality.")
103
104    return chosen_image_paths, chosen_gt_paths
105
106
def get_amos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        modality: The choice of imaging modality, 'CT' or 'MRI'. If None, both modalities are used.
        resize_inputs: Whether to resize the inputs. If True, the resize transform is configured via
            `util.update_kwargs_for_resize_trafo`, which also returns the patch shape to use.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_amos_paths(path, split, modality, download)

    if resize_inputs:
        # Inputs are configured as non-RGB for the resize transform.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # NOTE(review): "data" is used as the key for both image and label files;
    # presumably the key torch_em uses for NIfTI volumes — confirm.
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )
147
148
def get_amos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        modality: The choice of imaging modality, 'CT' or 'MRI'. If None, both modalities are used.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route every extra keyword argument either to the dataset or to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    amos_dataset = get_amos_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        modality=modality,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(amos_dataset, batch_size, **dataloader_kwargs)
# Download location of the AMOS 2022 archive on Zenodo (record 7155725).
URL = "https://zenodo.org/records/7155725/files/amos22.zip"
# Expected checksum of the downloaded archive (64 hex characters; presumably SHA-256 — confirm).
CHECKSUM = "d2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2"
def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_amos_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download and extract the AMOS dataset unless the extracted folder already exists.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath of the extracted 'amos22' folder inside `path`.
    """
    extracted_dir = os.path.join(path, "amos22")

    if not os.path.exists(extracted_dir):
        os.makedirs(path, exist_ok=True)
        archive_path = os.path.join(path, "amos22.zip")
        # Fetch the archive (with the expected CHECKSUM) and unpack it into `path`.
        util.download_source(path=archive_path, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=archive_path, dst=path)

    return extracted_dir

Download the AMOS dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_amos_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, download: bool = False) -> Tuple[List[str], List[str]]:
 48def get_amos_paths(
 49    path: Union[os.PathLike, str],
 50    split: Literal['train', 'val', 'test'],
 51    modality: Optional[Literal['CT', 'MRI']] = None,
 52    download: bool = False
 53) -> Tuple[List[str], List[str]]:
 54    """Get paths to the AMOS data.
 55
 56    Args:
 57        path: Filepath to a folder where the data is downloaded for further processing.
 58        split: The choice of data split.
 59        modality: The choice of imaging modality.
 60        download: Whether to download the data if it is not present.
 61
 62    Returns:
 63        List of filepaths for the image data.
 64        List of filepaths for the label data.
 65    """
 66    data_dir = get_amos_data(path=path, download=download)
 67
 68    if split == "train":
 69        im_dir, gt_dir = "imagesTr", "labelsTr"
 70    elif split == "val":
 71        im_dir, gt_dir = "imagesVa", "labelsVa"
 72    elif split == "test":
 73        im_dir, gt_dir = "imagesTs", "labelsTs"
 74    else:
 75        raise ValueError(f"'{split}' is not a valid split.")
 76
 77    image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz")))
 78    gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz")))
 79
 80    if modality is None:
 81        chosen_image_paths, chosen_gt_paths = image_paths, gt_paths
 82    else:
 83        ct_image_paths, ct_gt_paths = [], []
 84        mri_image_paths, mri_gt_paths = [], []
 85        for image_path, gt_path in zip(image_paths, gt_paths):
 86            patient_id = Path(image_path.split(".")[0]).stem
 87            id_value = int(patient_id.split("_")[-1])
 88
 89            is_ct = id_value < 500
 90
 91            if is_ct:
 92                ct_image_paths.append(image_path)
 93                ct_gt_paths.append(gt_path)
 94            else:
 95                mri_image_paths.append(image_path)
 96                mri_gt_paths.append(gt_path)
 97
 98        if modality.upper() == "CT":
 99            chosen_image_paths, chosen_gt_paths = ct_image_paths, ct_gt_paths
100        elif modality.upper() == "MRI":
101            chosen_image_paths, chosen_gt_paths = mri_image_paths, mri_gt_paths
102        else:
103            raise ValueError(f"'{modality}' is not a valid modality.")
104
105    return chosen_image_paths, chosen_gt_paths

Get paths to the AMOS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • modality: The choice of imaging modality.
  • download: Whether to download the data if it is not present.
Returns:

A list of filepaths for the image data and a list of filepaths for the label data.

def get_amos_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_amos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        modality: The choice of imaging modality, 'CT' or 'MRI'. If None, both modalities are used.
        resize_inputs: Whether to resize the inputs. If True, the resize transform is configured via
            `util.update_kwargs_for_resize_trafo`, which also returns the patch shape to use.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_amos_paths(path, split, modality, download)

    if resize_inputs:
        # Inputs are configured as non-RGB for the resize transform.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # NOTE(review): "data" is used as the key for both image and label files;
    # presumably the key torch_em uses for NIfTI volumes — confirm.
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • modality: The choice of imaging modality.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_amos_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], modality: Optional[Literal['CT', 'MRI']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_amos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        modality: The choice of imaging modality, 'CT' or 'MRI'. If None, both modalities are used.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route every extra keyword argument either to the dataset or to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    amos_dataset = get_amos_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        modality=modality,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(amos_dataset, batch_size, **dataloader_kwargs)

Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • modality: The choice of imaging modality.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.