torch_em.data.datasets.medical.chaos

The CHAOS dataset contains annotations for segmentation of abdominal organs in CT and MRI scans.

This dataset is from the publication https://doi.org/10.1016/j.media.2020.101950. Please cite it if you use this dataset for your research.

  1"""The CHAOS dataset contains annotations for segmentation of abdominal organs in
  2CT and MRI scans.
  3
  4This dataset is from the publication https://doi.org/10.1016/j.media.2020.101950.
  5Please cite it if you use this dataset for your research.
  6"""
  7
  8import os
  9from glob import glob
 10from tqdm import tqdm
 11from natsort import natsorted
 12from typing import Union, Tuple, Optional, Literal, List
 13
 14import numpy as np
 15
 16from torch.utils.data import Dataset, DataLoader
 17
 18import torch_em
 19
 20from .. import util
 21
 22
 23URL = {
 24    "train": "https://zenodo.org/records/3431873/files/CHAOS_Train_Sets.zip",
 25    "test": "https://zenodo.org/records/3431873/files/CHAOS_Test_Sets.zip"
 26}
 27
 28CHECKSUM = {
 29    "train": "535f7d3417a0e0f0d9133fb3d962423d2a9cf3f103e4f09a3d8a1daf87d5d2fc",
 30    "test": "80e9e4d4c4e363f142de4570e9b698e3f92dcb5140cc25a9c1cf4963e5ae7541"
 31}
 32
 33
 34def get_chaos_data(
 35    path: Union[os.PathLike, str], split: Literal['train', 'test'] = "train", download: bool = False
 36) -> str:
 37    """Download the CHAOS dataset.
 38
 39    Args:
 40        path: Filepath to a folder where the data is downloaded for further processing.
 41        download: Whether to download the data if it is not present.
 42
 43    Returns:
 44        Filepath where the data is downloaded.
 45    """
 46    assert split == "train", "'train' is the only split with ground truth annotations."
 47
 48    data_dir = os.path.join(path, "data", "Train_Sets" if split == "train" else "Test_Sets")
 49    if os.path.exists(data_dir):
 50        return data_dir
 51
 52    os.makedirs(path, exist_ok=True)
 53
 54    zip_path = os.path.join(path, f"chaos_{split}.zip")
 55    util.download_source(path=zip_path, url=URL[split], download=download, checksum=CHECKSUM[split])
 56    util.unzip(zip_path=zip_path, dst=os.path.join(path, "data"))
 57
 58    return data_dir
 59
 60
 61def _open_image(input_path):
 62    ext = os.path.splitext(input_path)[-1]
 63
 64    if ext == ".dcm":
 65        import pydicom as dicom
 66        inputs = dicom.dcmread(input_path)
 67        inputs = inputs.pixel_array
 68
 69    elif ext == ".png":
 70        import imageio.v3 as imageio
 71        inputs = imageio.imread(input_path)
 72
 73    else:
 74        raise ValueError
 75
 76    return inputs
 77
 78
 79def _preprocess_inputs(data_dir, modality):
 80    image_paths, gt_paths = [], []
 81    for m in modality:
 82        if m.upper() == "CT":
 83            m = m.upper()
 84            image_exts = ["DICOM_anon/*"]
 85            gt_exts = ["Ground/*"]
 86
 87        elif m.upper().startswith("MR"):
 88            m = "MR"
 89            image_exts = ["T1DUAL/DICOM_anon/InPhase/*", "T2SPIR/DICOM_anon/*"]
 90            gt_exts = ["T1DUAL/Ground/*", "T2SPIR/Ground/*"]
 91
 92        else:
 93            raise ValueError
 94
 95        series_uids = glob(os.path.join(data_dir, m, "*"))
 96
 97        for uid in tqdm(series_uids):
 98            _id = os.path.split(uid)[-1]
 99
100            base_dir = os.path.join(data_dir, "preprocessed", m.upper())
101
102            os.makedirs(os.path.join(base_dir, "image"), exist_ok=True)
103            os.makedirs(os.path.join(base_dir, "ground_truth"), exist_ok=True)
104
105            for image_ext, gt_ext in zip(image_exts, gt_exts):
106                if m == "MR":
107                    modname = image_ext.split("/")[0] + "_MR"
108                else:
109                    modname = m
110
111                image_path = os.path.join(base_dir, "image", f"{_id}_{modname}.nii.gz")
112                gt_path = os.path.join(base_dir, "ground_truth", f"{_id}_{modname}.nii.gz")
113
114                image_paths.append(image_path)
115                gt_paths.append(gt_path)
116
117                if os.path.exists(image_path) and os.path.exists(gt_path):
118                    continue
119
120                raw_slices = natsorted(glob(os.path.join(uid, image_ext)))
121                gt_slices = natsorted(glob(os.path.join(uid, gt_ext)))
122
123                raw = np.stack([_open_image(raw_slice) for raw_slice in raw_slices])
124                gt = np.stack([_open_image(gt_slice) for gt_slice in gt_slices]).astype("uint8")
125
126                raw = raw.transpose(1, 2, 0)
127                gt = gt.transpose(1, 2, 0)
128
129                import nibabel as nib
130                raw_nifti = nib.Nifti2Image(raw, np.eye(4))
131                nib.save(raw_nifti, image_path)
132
133                gt_nifti = nib.Nifti2Image(gt, np.eye(4))
134                nib.save(gt_nifti, gt_path)
135
136    return image_paths, gt_paths
137
138
def get_chaos_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'test'] = "train",
    modality: Optional[Literal['CT', 'MRI']] = None,
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the CHAOS data.

    The DICOM / PNG slices are converted to NIfTI volumes the first time they are requested;
    later calls reuse the converted files.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
        modality: The choice of modality. Either 'CT' or 'MRI'. By default, both modalities are used.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    data_dir = get_chaos_data(path=path, split=split, download=download)

    if modality is None:
        modality = ["CT", "MRI"]
    elif isinstance(modality, str):
        modality = [modality]

    image_paths, gt_paths = _preprocess_inputs(data_dir, modality)
    return image_paths, gt_paths
168
169
def get_chaos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'test'] = "train",
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CHAOS dataset for abdominal organ segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
        modality: The choice of modality. Either 'CT' or 'MRI'. By default, both modalities are used.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_chaos_paths(path=path, split=split, modality=modality, download=download)

    if resize_inputs:
        # The resize transform is configured from the originally requested patch shape.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are both read with the key "data".
    segmentation_dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        **kwargs,
    )
    # Allow up to 5000 attempts when sampling patches.
    segmentation_dataset.max_sampling_attempts = 5000
    return segmentation_dataset
207
208
def get_chaos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'test'] = "train",
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the CHAOS dataloader for abdominal organ segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
        modality: The choice of modality. Either 'CT' or 'MRI'. By default, both modalities are used.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset or to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_chaos_dataset(path, patch_shape, split, modality, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
# Zenodo download locations of the CHAOS archives, keyed by split.
URL = dict(
    train="https://zenodo.org/records/3431873/files/CHAOS_Train_Sets.zip",
    test="https://zenodo.org/records/3431873/files/CHAOS_Test_Sets.zip",
)

# Expected checksums of the archives above (64 hex digits, presumably SHA-256),
# handed to the download helper to verify the downloaded files.
CHECKSUM = dict(
    train="535f7d3417a0e0f0d9133fb3d962423d2a9cf3f103e4f09a3d8a1daf87d5d2fc",
    test="80e9e4d4c4e363f142de4570e9b698e3f92dcb5140cc25a9c1cf4963e5ae7541",
)
def get_chaos_data(
    path: Union[os.PathLike, str], split: Literal['train', 'test'] = "train", download: bool = False
) -> str:
    """Download the CHAOS dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to download. Only 'train' is supported,
            since it is the only split with ground truth annotations.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.

    Raises:
        AssertionError: If `split` is not 'train'.
    """
    # 'test' is listed in URL / CHECKSUM, but it has no labels and is therefore rejected here.
    assert split == "train", "'train' is the only split with ground truth annotations."

    data_dir = os.path.join(path, "data", "Train_Sets" if split == "train" else "Test_Sets")
    if os.path.exists(data_dir):
        # Already downloaded and extracted: nothing to do.
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, f"chaos_{split}.zip")
    util.download_source(path=zip_path, url=URL[split], download=download, checksum=CHECKSUM[split])
    util.unzip(zip_path=zip_path, dst=os.path.join(path, "data"))

    return data_dir

Download the CHAOS dataset.

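Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The data split to download. Only 'train' is supported, since it is the only split with ground truth annotations.
  • download: Whether to download the data if it is not present.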
Returns:

Filepath where the data is downloaded.

def get_chaos_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'test'] = "train",
    modality: Optional[Literal['CT', 'MRI']] = None,
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the CHAOS data.

    The DICOM / PNG slices are converted to NIfTI volumes the first time they are requested;
    later calls reuse the converted files.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
        modality: The choice of modality. Either 'CT' or 'MRI'. By default, both modalities are used.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    data_dir = get_chaos_data(path=path, split=split, download=download)

    if modality is None:
        modality = ["CT", "MRI"]
    elif isinstance(modality, str):
        modality = [modality]

    image_paths, gt_paths = _preprocess_inputs(data_dir, modality)
    return image_paths, gt_paths

Get paths to the CHAOS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
  • modality: The choice of modality. Either 'CT' or 'MRI'.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data. List of filepaths for the label data.

def get_chaos_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'test'] = "train",
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CHAOS dataset for abdominal organ segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
        modality: The choice of modality. Either 'CT' or 'MRI'. By default, both modalities are used.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_chaos_paths(path=path, split=split, modality=modality, download=download)

    if resize_inputs:
        # The resize transform is configured from the originally requested patch shape.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are both read with the key "data".
    segmentation_dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        **kwargs,
    )
    # Allow up to 5000 attempts when sampling patches.
    segmentation_dataset.max_sampling_attempts = 5000
    return segmentation_dataset

Get the CHAOS dataset for abdominal organ segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
  • modality: The choice of modality. Either 'CT' or 'MRI'.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_chaos_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'test'] = "train",
    modality: Optional[Literal['CT', 'MRI']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the CHAOS dataloader for abdominal organ segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
        modality: The choice of modality. Either 'CT' or 'MRI'. By default, both modalities are used.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset or to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_chaos_dataset(path, patch_shape, split, modality, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)

Get the CHAOS dataloader for abdominal organ segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Only 'train' is supported, since it is the only split with annotations.
  • modality: The choice of modality. Either 'CT' or 'MRI'.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.