torch_em.data.datasets.medical.curvas

The CURVAS dataset contains annotations for pancreas, kidney and liver in abdominal CT scans.

This dataset is from the challenge: https://curvas.grand-challenge.org. The dataset is located at: https://zenodo.org/records/12687192, and is from the publication https://doi.org/10.48550/arXiv.2505.08685. Please cite them if you use this dataset for your research.

  1"""The CURVAS dataset contains annotations for pancreas, kidney and liver
  2in abdominal CT scans.
  3
  4This dataset is from the challenge: https://curvas.grand-challenge.org.
  5The dataset is located at: https://zenodo.org/records/12687192,
  6and is from the publication https://doi.org/10.48550/arXiv.2505.08685
  7Please cite tem if you use this dataset for your research.
  8"""
  9
 10import os
 11import shutil
 12import subprocess
 13from tqdm import tqdm
 14from glob import glob
 15from natsort import natsorted
 16from typing import Tuple, Union, Literal, List
 17
 18import numpy as np
 19
 20from torch.utils.data import Dataset, DataLoader
 21
 22import torch_em
 23
 24from .. import util
 25
 26
 27URL = "https://zenodo.org/records/12687192/files/training_set.zip"
 28CHECKSUM = "1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89"
 29
 30
 31def _preprocess_data(data_dir):
 32    import h5py
 33    import nibabel as nib
 34
 35    h5_dir = os.path.join(os.path.dirname(data_dir), "data")
 36    os.makedirs(h5_dir, exist_ok=True)
 37
 38    image_paths = natsorted(glob(os.path.join(data_dir, "*", "image.nii.gz")))
 39    for image_path in tqdm(image_paths, desc="Processing data"):
 40        rater1_path = os.path.join(os.path.dirname(image_path), "annotation_1.nii.gz")
 41        rater2_path = os.path.join(os.path.dirname(image_path), "annotation_2.nii.gz")
 42        rater3_path = os.path.join(os.path.dirname(image_path), "annotation_3.nii.gz")
 43
 44        assert os.path.exists(rater1_path) and os.path.exists(rater2_path) and os.path.exists(rater3_path)
 45
 46        image = nib.load(image_path).get_fdata().astype("float32").transpose(2, 0, 1)
 47
 48        label_r1 = np.rint(nib.load(rater1_path).get_fdata()).astype("uint8").transpose(2, 0, 1)
 49        label_r2 = np.rint(nib.load(rater2_path).get_fdata()).astype("uint8").transpose(2, 0, 1)
 50        label_r3 = np.rint(nib.load(rater3_path).get_fdata()).astype("uint8").transpose(2, 0, 1)
 51
 52        fname = os.path.basename(os.path.dirname(image_path))
 53        chunks = (8, 512, 512)
 54        with h5py.File(os.path.join(h5_dir, f"{fname}.h5"), "w") as f:
 55            f.create_dataset("raw", data=image, compression="gzip", chunks=chunks)
 56            f.create_dataset("labels/rater_1", data=label_r1, compression="gzip", chunks=chunks)
 57            f.create_dataset("labels/rater_2", data=label_r2, compression="gzip", chunks=chunks)
 58            f.create_dataset("labels/rater_3", data=label_r3, compression="gzip", chunks=chunks)
 59
 60    # Remove the nifti files as we don't need them anymore!
 61    shutil.rmtree(data_dir)
 62
 63
 64def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 65    """Download the CURVAS dataset.
 66
 67    Args:
 68        path: Filepath to a folder where the data is downloaded for further processing.
 69        download: Whether to download the data if it is not present.
 70
 71    Returns:
 72        Filepath where the data is downloaded.
 73    """
 74    data_dir = os.path.join(path, "data")
 75    if os.path.exists(data_dir):
 76        return data_dir
 77
 78    os.makedirs(path, exist_ok=True)
 79
 80    zip_path = os.path.join(path, "training_set.zip")
 81    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 82
 83    # HACK: The zip file is broken. We fix it using the following script.
 84    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
 85    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
 86    subprocess.run(["unzip", fixed_zip_path, "-d", path])
 87
 88    _preprocess_data(os.path.join(path, "training_set"))
 89
 90    # Remove the zip files as we don't need them anymore.
 91    os.remove(zip_path)
 92    os.remove(fixed_zip_path)
 93
 94    return data_dir
 95
 96
 97def get_curvas_paths(
 98    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 99) -> List[str]:
100    """Get paths to the CURVAS data.
101
102    Args:
103        path: Filepath to a folder where the data is downloaded for further processing.
104        split: The choice of data split.
105        download: Whether to download the data if it is not present.
106
107    Returns:
108        List of filepaths for the volumetric data.
109    """
110    data_dir = get_curvas_data(path, download)
111    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
112
113    if split == "train":
114        volume_paths = volume_paths[:10]
115    elif split == "val":
116        volume_paths = volume_paths[10:13]
117    elif split == "test":
118        volume_paths = volume_paths[13:]
119    else:
120        raise ValueError(f"'{split}' is not a valid split.")
121
122    return volume_paths
123
124
def get_curvas_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations. One of "1", "2" or "3".
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `rater` is not one of "1", "2" or "3", or if `split` is not a valid split.
    """
    # Check the rater up front: only `labels/rater_1` to `labels/rater_3` exist in the preprocessed files,
    # and an invalid choice would presumably only surface after the full download and preprocessing.
    if str(rater) not in ("1", "2", "3"):
        raise ValueError(f"'{rater}' is not a valid rater. Choose one of '1', '2' or '3'.")

    volume_paths = get_curvas_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=f"labels/rater_{rater}",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
165
166
def get_curvas_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader over the CURVAS pancreas, kidney and liver segmentation data.

    Args:
        path: Folder where the CURVAS data is stored, or will be downloaded to.
        batch_size: Number of samples per batch.
        patch_shape: Shape of the patches sampled for training.
        split: Which data split to load: 'train', 'val' or 'test'.
        rater: Which rater's annotations to use as labels.
        resize_inputs: If True, inputs are resized to `patch_shape`.
        download: If True, download the data when it is not already present.
        kwargs: Extra keyword arguments; those accepted by `torch_em.default_segmentation_dataset` go to the
            dataset, all others are forwarded to the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route every extra keyword to its consumer: the dataset constructor or the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    curvas_dataset = get_curvas_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        rater=rater,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    loader = torch_em.get_data_loader(curvas_dataset, batch_size, **dataloader_kwargs)
    return loader
# Zenodo location of the CURVAS training archive (the published zip is broken and gets repaired after download).
URL = (
    "https://zenodo.org/records/12687192"
    "/files/training_set.zip"
)
# Passed as `checksum` to `util.download_source` for the downloaded archive.
CHECKSUM = "1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89"
def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the CURVAS dataset.

    The training archive is downloaded, repaired with `zip -FF` (the published zip is broken), extracted
    with `unzip` and converted to one HDF5 file per case. The `zip` and `unzip` tools must be on the PATH.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.

    Raises:
        RuntimeError: If repairing or extracting the archive does not produce the expected output.
    """
    data_dir = os.path.join(path, "data")
    extracted_dir = os.path.join(path, "training_set")
    # The extracted NIfTI folder is only deleted once preprocessing has finished. If it still exists,
    # a previous run was interrupted and `data_dir` may be incomplete, so the conversion is redone.
    if os.path.exists(data_dir) and not os.path.exists(extracted_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "training_set.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)

    # HACK: The zip file is broken. We fix it using `zip -FF` before extracting it.
    # Success is judged by the produced files rather than by exit codes, since e.g. unzip
    # returns a non-zero status for warnings even when the extraction completed.
    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
    if os.path.exists(fixed_zip_path):  # Drop a stale repaired archive left by an interrupted run.
        os.remove(fixed_zip_path)
    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
    if not os.path.exists(fixed_zip_path):
        raise RuntimeError(f"Repairing '{zip_path}' with 'zip -FF' did not produce '{fixed_zip_path}'.")

    # '-o' overwrites files from an interrupted extraction instead of prompting interactively.
    subprocess.run(["unzip", "-o", fixed_zip_path, "-d", path])
    if not os.path.isdir(extracted_dir):
        raise RuntimeError(f"Extracting '{fixed_zip_path}' did not produce '{extracted_dir}'.")

    _preprocess_data(extracted_dir)

    # Remove the zip files as we don't need them anymore.
    os.remove(zip_path)
    os.remove(fixed_zip_path)

    return data_dir

Download the CURVAS dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_curvas_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> List[str]:
    """Get paths to the CURVAS data.

    The preprocessed HDF5 volumes are sorted naturally by filename and split deterministically:
    the first 10 volumes form 'train', the next 3 form 'val', and all remaining volumes form 'test'.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the volumetric data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    split_slices = {"train": slice(None, 10), "val": slice(10, 13), "test": slice(13, None)}
    # Validate the split before calling `get_curvas_data`, so that a typo fails immediately
    # instead of after downloading and preprocessing the whole dataset.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_curvas_data(path, download)
    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
    return volume_paths[split_slices[split]]

Get paths to the CURVAS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the volumetric data.

def get_curvas_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations. One of "1", "2" or "3".
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `rater` is not one of "1", "2" or "3", or if `split` is not a valid split.
    """
    # Check the rater up front: only `labels/rater_1` to `labels/rater_3` exist in the preprocessed files,
    # and an invalid choice would presumably only surface after the full download and preprocessing.
    if str(rater) not in ("1", "2", "3"):
        raise ValueError(f"'{rater}' is not a valid rater. Choose one of '1', '2' or '3'.")

    volume_paths = get_curvas_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=f"labels/rater_{rater}",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )

Get the CURVAS dataset for pancreas, kidney and liver segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • rater: The choice of rater providing the annotations.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_curvas_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader over the CURVAS pancreas, kidney and liver segmentation data.

    Args:
        path: Folder where the CURVAS data is stored, or will be downloaded to.
        batch_size: Number of samples per batch.
        patch_shape: Shape of the patches sampled for training.
        split: Which data split to load: 'train', 'val' or 'test'.
        rater: Which rater's annotations to use as labels.
        resize_inputs: If True, inputs are resized to `patch_shape`.
        download: If True, download the data when it is not already present.
        kwargs: Extra keyword arguments; those accepted by `torch_em.default_segmentation_dataset` go to the
            dataset, all others are forwarded to the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route every extra keyword to its consumer: the dataset constructor or the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    curvas_dataset = get_curvas_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        rater=rater,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    loader = torch_em.get_data_loader(curvas_dataset, batch_size, **dataloader_kwargs)
    return loader

Get the CURVAS dataloader for pancreas, kidney and liver segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • rater: The choice of rater providing the annotations.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.