torch_em.data.datasets.medical.curvas

The CURVAS dataset contains annotations for pancreas, kidney and liver in abdominal CT scans.

This dataset is from the challenge: https://curvas.grand-challenge.org. The dataset is located at: https://zenodo.org/records/12687192. Please cite them if you use this dataset for your research.

  1"""The CURVAS dataset contains annotations for pancreas, kidney and liver
  2in abdominal CT scans.
  3
  4This dataset is from the challenge: https://curvas.grand-challenge.org.
  5The dataset is located at: https://zenodo.org/records/12687192.
  6Please cite them if you use this dataset for your research.
  7"""
  8
  9import os
 10import shutil
 11import subprocess
 12from tqdm import tqdm
 13from glob import glob
 14from natsort import natsorted
 15from typing import Tuple, Union, Literal, List
 16
 17import numpy as np
 18
 19from torch.utils.data import Dataset, DataLoader
 20
 21import torch_em
 22
 23from .. import util
 24
 25
# Download link of the `training_set.zip` archive hosted on Zenodo.
 26URL = "https://zenodo.org/records/12687192/files/training_set.zip"
# Passed as `checksum` to `util.download_source` together with `URL`
# (64 hex digits, presumably a SHA-256 digest - TODO confirm).
 27CHECKSUM = "1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89"
 28
 29
 30def _preprocess_data(data_dir):
 31    import h5py
 32    import nibabel as nib
 33
 34    h5_dir = os.path.join(os.path.dirname(data_dir), "data")
 35    os.makedirs(h5_dir, exist_ok=True)
 36
 37    image_paths = natsorted(glob(os.path.join(data_dir, "*", "image.nii.gz")))
 38    for image_path in tqdm(image_paths, desc="Processing data"):
 39        rater1_path = os.path.join(os.path.dirname(image_path), "annotation_1.nii.gz")
 40        rater2_path = os.path.join(os.path.dirname(image_path), "annotation_2.nii.gz")
 41        rater3_path = os.path.join(os.path.dirname(image_path), "annotation_3.nii.gz")
 42
 43        assert os.path.exists(rater1_path) and os.path.exists(rater2_path) and os.path.exists(rater3_path)
 44
 45        image = nib.load(image_path).get_fdata().astype("float32").transpose(2, 0, 1)
 46
 47        label_r1 = np.rint(nib.load(rater1_path).get_fdata()).astype("uint8").transpose(2, 0, 1)
 48        label_r2 = np.rint(nib.load(rater2_path).get_fdata()).astype("uint8").transpose(2, 0, 1)
 49        label_r3 = np.rint(nib.load(rater3_path).get_fdata()).astype("uint8").transpose(2, 0, 1)
 50
 51        fname = os.path.basename(os.path.dirname(image_path))
 52        chunks = (8, 512, 512)
 53        with h5py.File(os.path.join(h5_dir, f"{fname}.h5"), "w") as f:
 54            f.create_dataset("raw", data=image, compression="gzip", chunks=chunks)
 55            f.create_dataset("labels/rater_1", data=label_r1, compression="gzip", chunks=chunks)
 56            f.create_dataset("labels/rater_2", data=label_r2, compression="gzip", chunks=chunks)
 57            f.create_dataset("labels/rater_3", data=label_r3, compression="gzip", chunks=chunks)
 58
 59    # Remove the nifti files as we don't need them anymore!
 60    shutil.rmtree(data_dir)
 61
 62
 63def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 64    """Download the CURVAS dataset.
 65
 66    Args:
 67        path: Filepath to a folder where the data is downloaded for further processing.
 68        download: Whether to download the data if it is not present.
 69
 70    Returns:
 71        Filepath where the data is downloaded.
 72    """
 73    data_dir = os.path.join(path, "data")
 74    if os.path.exists(data_dir):
 75        return data_dir
 76
 77    os.makedirs(path, exist_ok=True)
 78
 79    zip_path = os.path.join(path, "training_set.zip")
 80    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 81
 82    # HACK: The zip file is broken. We fix it using the following script.
 83    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
 84    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
 85    subprocess.run(["unzip", fixed_zip_path, "-d", path])
 86
 87    _preprocess_data(os.path.join(path, "training_set"))
 88
 89    # Remove the zip files as we don't need them anymore.
 90    os.remove(zip_path)
 91    os.remove(fixed_zip_path)
 92
 93    return data_dir
 94
 95
 96def get_curvas_paths(
 97    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 98) -> List[str]:
 99    """Get paths to the CURVAS data.
100
101    Args:
102        path: Filepath to a folder where the data is downloaded for further processing.
103        split: The choice of data split.
104        download: Whether to download the data if it is not present.
105
106    Returns:
107        List of filepaths for the volumetric data.
108    """
109    data_dir = get_curvas_data(path, download)
110    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
111
112    if split == "train":
113        volume_paths = volume_paths[:10]
114    elif split == "val":
115        volume_paths = volume_paths[10:13]
116    elif split == "test":
117        volume_paths = volume_paths[13:]
118    else:
119        raise ValueError(f"'{split}' is not a valid split.")
120
121    return volume_paths
122
123
def get_curvas_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `rater` is not one of "1", "2" or "3".
    """
    # Fail early with a clear message; otherwise an invalid rater only shows up
    # later as a missing label key, possibly after a full download.
    if str(rater) not in ("1", "2", "3"):
        raise ValueError(f"'{rater}' is not a valid rater. Choose one of '1', '2' or '3'.")

    volume_paths = get_curvas_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # Raw data and labels live in the same HDF5 file, only the keys differ.
    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=f"labels/rater_{rater}",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
164
165
def get_curvas_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the CURVAS dataloader for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments understood by the dataset factory from the loader arguments.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    curvas_dataset = get_curvas_dataset(
        path,
        patch_shape,
        split,
        rater,
        resize_inputs,
        download,
        **dataset_kwargs,
    )

    loader = torch_em.get_data_loader(curvas_dataset, batch_size, **dataloader_kwargs)
    return loader
# Download link of the `training_set.zip` archive hosted on Zenodo.
URL = 'https://zenodo.org/records/12687192/files/training_set.zip'
# Passed as `checksum` to `util.download_source` together with `URL`
# (64 hex digits, presumably a SHA-256 digest - TODO confirm).
CHECKSUM = '1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89'
def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the CURVAS dataset and convert it to HDF5.

    If `<path>/data` already exists it is returned right away. Otherwise the training set
    archive is downloaded, repaired with `zip -FF`, extracted with `unzip` and converted
    by `_preprocess_data`. The two zip files are removed afterwards.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.

    Raises:
        RuntimeError: If the archive could not be extracted to `<path>/training_set`.
    """
    data_dir = os.path.join(path, "data")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "training_set.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)

    # HACK: The zip file is broken. We fix it using the following script.
    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
    if os.path.exists(fixed_zip_path):
        # Left over from an earlier, interrupted run; start from a clean state.
        os.remove(fixed_zip_path)
    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
    # `-o` overwrites files from an earlier, interrupted run instead of prompting.
    subprocess.run(["unzip", "-o", fixed_zip_path, "-d", path])

    # The exit codes of zip / unzip are not reliable for the repaired archive,
    # so the outcome is verified through the expected output folder instead.
    extracted_dir = os.path.join(path, "training_set")
    if not os.path.isdir(extracted_dir):
        raise RuntimeError(
            f"Extracting '{zip_path}' did not create '{extracted_dir}'. "
            "Please make sure that the 'zip' and 'unzip' command line tools are installed."
        )

    try:
        _preprocess_data(extracted_dir)
    except BaseException:
        # A partially filled output folder would be mistaken for a finished
        # conversion by the early return above, so it is removed on failure.
        shutil.rmtree(data_dir, ignore_errors=True)
        raise

    # Remove the zip files as we don't need them anymore.
    os.remove(zip_path)
    os.remove(fixed_zip_path)

    return data_dir

Download the CURVAS dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_curvas_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> List[str]:
 97def get_curvas_paths(
 98    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 99) -> List[str]:
100    """Get paths to the CURVAS data.
101
102    Args:
103        path: Filepath to a folder where the data is downloaded for further processing.
104        split: The choice of data split.
105        download: Whether to download the data if it is not present.
106
107    Returns:
108        List of filepaths for the volumetric data.
109    """
110    data_dir = get_curvas_data(path, download)
111    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
112
113    if split == "train":
114        volume_paths = volume_paths[:10]
115    elif split == "val":
116        volume_paths = volume_paths[10:13]
117    elif split == "test":
118        volume_paths = volume_paths[13:]
119    else:
120        raise ValueError(f"'{split}' is not a valid split.")
121
122    return volume_paths

Get paths to the CURVAS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the volumetric data.

def get_curvas_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Literal['1', '2', '3'] = '1', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_curvas_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `rater` is not one of "1", "2" or "3".
    """
    # Fail early with a clear message; otherwise an invalid rater only shows up
    # later as a missing label key, possibly after a full download.
    if str(rater) not in ("1", "2", "3"):
        raise ValueError(f"'{rater}' is not a valid rater. Choose one of '1', '2' or '3'.")

    volume_paths = get_curvas_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # Raw data and labels live in the same HDF5 file, only the keys differ.
    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=f"labels/rater_{rater}",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )

Get the CURVAS dataset for pancreas, kidney and liver segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • rater: The choice of rater providing the annotations.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_curvas_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Literal['1', '2', '3'] = '1', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_curvas_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the CURVAS dataloader for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments understood by the dataset factory from the loader arguments.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    curvas_dataset = get_curvas_dataset(
        path,
        patch_shape,
        split,
        rater,
        resize_inputs,
        download,
        **dataset_kwargs,
    )

    loader = torch_em.get_data_loader(curvas_dataset, batch_size, **dataloader_kwargs)
    return loader

Get the CURVAS dataloader for pancreas, kidney and liver segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • rater: The choice of rater providing the annotations.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.