torch_em.data.datasets.medical.curvas
The CURVAS dataset contains annotations for pancreas, kidney and liver in abdominal CT scans.
This dataset is from the challenge: https://curvas.grand-challenge.org. The dataset is located at: https://zenodo.org/records/12687192, and is from the publication https://doi.org/10.48550/arXiv.2505.08685. Please cite them if you use this dataset for your research.
1"""The CURVAS dataset contains annotations for pancreas, kidney and liver 2in abdominal CT scans. 3 4This dataset is from the challenge: https://curvas.grand-challenge.org. 5The dataset is located at: https://zenodo.org/records/12687192, 6and is from the publication https://doi.org/10.48550/arXiv.2505.08685 7Please cite tem if you use this dataset for your research. 8""" 9 10import os 11import shutil 12import subprocess 13from tqdm import tqdm 14from glob import glob 15from natsort import natsorted 16from typing import Tuple, Union, Literal, List 17 18import numpy as np 19 20from torch.utils.data import Dataset, DataLoader 21 22import torch_em 23 24from .. import util 25 26 27URL = "https://zenodo.org/records/12687192/files/training_set.zip" 28CHECKSUM = "1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89" 29 30 31def _preprocess_data(data_dir): 32 import h5py 33 import nibabel as nib 34 35 h5_dir = os.path.join(os.path.dirname(data_dir), "data") 36 os.makedirs(h5_dir, exist_ok=True) 37 38 image_paths = natsorted(glob(os.path.join(data_dir, "*", "image.nii.gz"))) 39 for image_path in tqdm(image_paths, desc="Processing data"): 40 rater1_path = os.path.join(os.path.dirname(image_path), "annotation_1.nii.gz") 41 rater2_path = os.path.join(os.path.dirname(image_path), "annotation_2.nii.gz") 42 rater3_path = os.path.join(os.path.dirname(image_path), "annotation_3.nii.gz") 43 44 assert os.path.exists(rater1_path) and os.path.exists(rater2_path) and os.path.exists(rater3_path) 45 46 image = nib.load(image_path).get_fdata().astype("float32").transpose(2, 0, 1) 47 48 label_r1 = np.rint(nib.load(rater1_path).get_fdata()).astype("uint8").transpose(2, 0, 1) 49 label_r2 = np.rint(nib.load(rater2_path).get_fdata()).astype("uint8").transpose(2, 0, 1) 50 label_r3 = np.rint(nib.load(rater3_path).get_fdata()).astype("uint8").transpose(2, 0, 1) 51 52 fname = os.path.basename(os.path.dirname(image_path)) 53 chunks = (8, 512, 512) 54 with h5py.File(os.path.join(h5_dir, f"{fname}.h5"), "w") as f: 55 f.create_dataset("raw", data=image, compression="gzip", chunks=chunks) 56 f.create_dataset("labels/rater_1", data=label_r1, compression="gzip", chunks=chunks) 57 f.create_dataset("labels/rater_2", data=label_r2, compression="gzip", chunks=chunks) 58 f.create_dataset("labels/rater_3", data=label_r3, compression="gzip", chunks=chunks) 59 60 # Remove the nifti files as we don't need them anymore! 61 shutil.rmtree(data_dir) 62 63 64def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str: 65 """Download the CURVAS dataset. 66 67 Args: 68 path: Filepath to a folder where the data is downloaded for further processing. 69 download: Whether to download the data if it is not present. 70 71 Returns: 72 Filepath where the data is downloaded. 73 """ 74 data_dir = os.path.join(path, "data") 75 if os.path.exists(data_dir): 76 return data_dir 77 78 os.makedirs(path, exist_ok=True) 79 80 zip_path = os.path.join(path, "training_set.zip") 81 util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) 82 83 # HACK: The zip file is broken. We fix it using the following script. 84 fixed_zip_path = os.path.join(path, "training_set_fixed.zip") 85 subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path]) 86 subprocess.run(["unzip", fixed_zip_path, "-d", path]) 87 88 _preprocess_data(os.path.join(path, "training_set")) 89 90 # Remove the zip files as we don't need them anymore. 
91 os.remove(zip_path) 92 os.remove(fixed_zip_path) 93 94 return data_dir 95 96 97def get_curvas_paths( 98 path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False 99) -> List[str]: 100 """Get paths to the CURVAS data. 101 102 Args: 103 path: Filepath to a folder where the data is downloaded for further processing. 104 split: The choice of data split. 105 download: Whether to download the data if it is not present. 106 107 Returns: 108 List of filepaths for the volumetric data. 109 """ 110 data_dir = get_curvas_data(path, download) 111 volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5"))) 112 113 if split == "train": 114 volume_paths = volume_paths[:10] 115 elif split == "val": 116 volume_paths = volume_paths[10:13] 117 elif split == "test": 118 volume_paths = volume_paths[13:] 119 else: 120 raise ValueError(f"'{split}' is not a valid split.") 121 122 return volume_paths 123 124 125def get_curvas_dataset( 126 path: Union[os.PathLike, str], 127 patch_shape: Tuple[int, ...], 128 split: Literal['train', 'val', 'test'], 129 rater: Literal["1", "2", "3"] = "1", 130 resize_inputs: bool = False, 131 download: bool = False, 132 **kwargs 133) -> Dataset: 134 """Get the CURVAS dataset for pancreas, kidney and liver segmentation. 135 136 Args: 137 path: Filepath to a folder where the data is downloaded for further processing. 138 patch_shape: The patch shape to use for training. 139 split: The choice of data split. 140 rater: The choice of rater providing the annotations. 141 resize_inputs: Whether to resize inputs to the desired patch shape. 142 download: Whether to download the data if it is not present. 143 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 144 145 Returns: 146 The segmentation dataset. 147 """ 148 volume_paths = get_curvas_paths(path, split, download) 149 150 if resize_inputs: 151 resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} 152 kwargs, patch_shape = util.update_kwargs_for_resize_trafo( 153 kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs 154 ) 155 156 return torch_em.default_segmentation_dataset( 157 raw_paths=volume_paths, 158 raw_key="raw", 159 label_paths=volume_paths, 160 label_key=f"labels/rater_{rater}", 161 patch_shape=patch_shape, 162 is_seg_dataset=True, 163 **kwargs, 164 ) 165 166 167def get_curvas_loader( 168 path: Union[os.PathLike, str], 169 batch_size: int, 170 patch_shape: Tuple[int, ...], 171 split: Literal['train', 'val', 'test'], 172 rater: Literal["1", "2", "3"] = "1", 173 resize_inputs: bool = False, 174 download: bool = False, 175 **kwargs 176) -> DataLoader: 177 """Get the CURVAS dataloader for pancreas, kidney and liver segmentation. 178 179 Args: 180 path: Filepath to a folder where the data is downloaded for further processing. 181 batch_size: The batch size for training. 182 patch_shape: The patch shape to use for training. 183 split: The choice of data split. 184 rater: The choice of rater providing the annotations. 185 resize_inputs: Whether to resize inputs to the desired patch shape. 186 download: Whether to download the data if it is not present. 187 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 188 189 Returns: 190 The DataLoader. 
191 """ 192 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 193 dataset = get_curvas_dataset(path, patch_shape, split, rater, resize_inputs, download, **ds_kwargs) 194 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
```python
URL = "https://zenodo.org/records/12687192/files/training_set.zip"

CHECKSUM = "1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89"
```
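`CHECKSUM` is passed to `util.download_source` in `get_curvas_data` below to validate the downloaded archive. The 64-character hex digest is consistent with SHA-256; here is a sketch of verifying a local copy by hand under that assumption (the helper `sha256sum` is hypothetical, not part of this module):

```python
import hashlib

def sha256sum(file_path: str, chunk_size: int = 1024 * 1024) -> str:
    """Compute the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

assert sha256sum("./curvas/training_set.zip") == CHECKSUM
```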
```python
def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the CURVAS dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "data")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "training_set.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)

    # HACK: The zip file is broken. We fix it using the following script.
    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
    subprocess.run(["unzip", fixed_zip_path, "-d", path])

    _preprocess_data(os.path.join(path, "training_set"))

    # Remove the zip files as we don't need them anymore.
    os.remove(zip_path)
    os.remove(fixed_zip_path)

    return data_dir
```
Download the CURVAS dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
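For example, downloading and preprocessing the data into a local root folder `./curvas` (an arbitrary location chosen for this sketch):

```python
from torch_em.data.datasets.medical.curvas import get_curvas_data

# Downloads the zip, repairs and unpacks it, and converts the volumes to HDF5.
data_dir = get_curvas_data("./curvas", download=True)
print(data_dir)  # "./curvas/data", containing one .h5 file per patient
```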
```python
def get_curvas_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> List[str]:
    """Get paths to the CURVAS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the volumetric data.
    """
    data_dir = get_curvas_data(path, download)
    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))

    if split == "train":
        volume_paths = volume_paths[:10]
    elif split == "val":
        volume_paths = volume_paths[10:13]
    elif split == "test":
        volume_paths = volume_paths[13:]
    else:
        raise ValueError(f"'{split}' is not a valid split.")

    return volume_paths
```
Get paths to the CURVAS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the volumetric data.
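The splits are fixed slices over the naturally sorted volumes: the first 10 for training, the next 3 for validation, and the remainder for testing. For example:

```python
from torch_em.data.datasets.medical.curvas import get_curvas_paths

train_paths = get_curvas_paths("./curvas", split="train", download=True)
val_paths = get_curvas_paths("./curvas", split="val")
print(len(train_paths), len(val_paths))  # 10, 3
```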
```python
def get_curvas_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    volume_paths = get_curvas_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=f"labels/rater_{rater}",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
```
Get the CURVAS dataset for pancreas, kidney and liver segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- rater: The choice of rater providing the annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
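A usage sketch; the patch shape `(32, 256, 256)` is an arbitrary choice for illustration:

```python
from torch_em.data.datasets.medical.curvas import get_curvas_dataset

dataset = get_curvas_dataset(
    path="./curvas",
    patch_shape=(32, 256, 256),  # (z, y, x) patches sampled from the volumes
    split="train",
    rater="1",                   # use the annotations of the first rater
    download=True,
)
raw, labels = dataset[0]  # a raw patch and the matching rater-1 label patch
```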
```python
def get_curvas_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the CURVAS dataloader for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_curvas_dataset(path, patch_shape, split, rater, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
```
Get the CURVAS dataloader for pancreas, kidney and liver segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- rater: The choice of rater providing the annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
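A usage sketch, with arbitrary batch size and patch shape; keyword arguments that `torch_em.default_segmentation_dataset` does not consume (e.g. `num_workers`) are forwarded to the PyTorch DataLoader:

```python
from torch_em.data.datasets.medical.curvas import get_curvas_loader

loader = get_curvas_loader(
    path="./curvas",
    batch_size=2,
    patch_shape=(32, 256, 256),
    split="train",
    rater="2",       # train against the second rater's annotations
    download=True,
    num_workers=4,   # forwarded to the PyTorch DataLoader
)
raw, labels = next(iter(loader))
print(raw.shape, labels.shape)
```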