torch_em.data.datasets.medical.curvas
The CURVAS dataset contains annotations for pancreas, kidney and liver in abdominal CT scans.
This dataset is from the challenge: https://curvas.grand-challenge.org. The dataset is located at: https://zenodo.org/records/12687192. Please cite them if you use this dataset for your research.
"""The CURVAS dataset contains annotations for pancreas, kidney and liver
in abdominal CT scans.

This dataset is from the challenge: https://curvas.grand-challenge.org.
The dataset is located at: https://zenodo.org/records/12687192.
Please cite them if you use this dataset for your research.
"""

import os
import shutil
import subprocess
from tqdm import tqdm
from glob import glob
from natsort import natsorted
from typing import Tuple, Union, Literal, List

import numpy as np

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = "https://zenodo.org/records/12687192/files/training_set.zip"
CHECKSUM = "1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89"


def _preprocess_data(data_dir):
    """Convert the extracted NIfTI volumes to one HDF5 file per case.

    For every `<data_dir>/<case>/image.nii.gz` the image and the three rater annotations
    (`annotation_{1,2,3}.nii.gz`) are written to `<parent of data_dir>/data/<case>.h5` under the keys
    `raw` and `labels/rater_{1,2,3}`. Afterwards `data_dir` is deleted.

    Args:
        data_dir: Folder that contains one sub-folder per case with the NIfTI files.

    Raises:
        FileNotFoundError: If one of the three annotation files is missing for a case.
    """
    import h5py
    import nibabel as nib

    def _load(nifti_path, dtype, round_values=False):
        # Move the last axis to the front (presumably the slice axis - TODO confirm).
        volume = nib.load(nifti_path).get_fdata()
        if round_values:
            volume = np.rint(volume)
        return volume.astype(dtype).transpose(2, 0, 1)

    def _chunks_for(shape):
        # h5py rejects chunks that exceed the data shape, so clip the default chunking per axis.
        return tuple(min(chunk, size) for chunk, size in zip((8, 512, 512), shape))

    h5_dir = os.path.join(os.path.dirname(data_dir), "data")
    os.makedirs(h5_dir, exist_ok=True)

    image_paths = natsorted(glob(os.path.join(data_dir, "*", "image.nii.gz")))
    for image_path in tqdm(image_paths, desc="Processing data"):
        case_dir = os.path.dirname(image_path)
        rater_paths = [os.path.join(case_dir, f"annotation_{i}.nii.gz") for i in (1, 2, 3)]

        # Explicit check instead of `assert`, which is stripped when running with `-O`.
        missing = [rater_path for rater_path in rater_paths if not os.path.exists(rater_path)]
        if missing:
            raise FileNotFoundError(f"Missing annotation file(s) for '{case_dir}': {missing}")

        image = _load(image_path, "float32")
        labels = [_load(rater_path, "uint8", round_values=True) for rater_path in rater_paths]

        fname = os.path.basename(case_dir)
        with h5py.File(os.path.join(h5_dir, f"{fname}.h5"), "w") as f:
            f.create_dataset("raw", data=image, compression="gzip", chunks=_chunks_for(image.shape))
            for rater_id, label in enumerate(labels, start=1):
                f.create_dataset(
                    f"labels/rater_{rater_id}", data=label, compression="gzip", chunks=_chunks_for(label.shape)
                )

    # Remove the nifti files as we don't need them anymore!
    shutil.rmtree(data_dir)


def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the CURVAS dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the `data` folder inside `path` where the preprocessed data is stored.

    Raises:
        RuntimeError: If the archive could not be repaired and extracted.
    """
    data_dir = os.path.join(path, "data")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "training_set.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)

    # HACK: The zip file is broken. We fix it using the following script.
    # The exit codes are deliberately not enforced, since the repair / extraction of a damaged
    # archive may emit warnings; the result is validated below instead.
    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
    subprocess.run(["unzip", fixed_zip_path, "-d", path])

    # Fail before preprocessing: otherwise an empty 'data' folder would be created and
    # every later call would return it as if the dataset was available.
    extracted_dir = os.path.join(path, "training_set")
    if not os.path.isdir(extracted_dir):
        raise RuntimeError(
            f"Extracting '{zip_path}' failed: '{extracted_dir}' does not exist. "
            "Please make sure that the 'zip' and 'unzip' command line tools are working."
        )

    _preprocess_data(extracted_dir)

    # Remove the zip files as we don't need them anymore.
    os.remove(zip_path)
    os.remove(fixed_zip_path)

    return data_dir


def get_curvas_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> List[str]:
    """Get paths to the CURVAS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the volumetric data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Positions of each split within the naturally sorted list of volumes.
    split_slices = {"train": slice(None, 10), "val": slice(10, 13), "test": slice(13, None)}

    # Validate before touching the data, so that a typo does not trigger a download first.
    if split not in tuple(split_slices):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_curvas_data(path, download)
    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
    return volume_paths[split_slices[split]]


def get_curvas_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `rater` is not one of "1", "2" or "3".
    """
    # The preprocessed files only contain 'labels/rater_1', 'labels/rater_2' and 'labels/rater_3'.
    if str(rater) not in ("1", "2", "3"):
        raise ValueError(f"'{rater}' is not a valid rater. Choose from '1', '2' or '3'.")

    volume_paths = get_curvas_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=f"labels/rater_{rater}",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )


def get_curvas_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the CURVAS dataloader for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_curvas_dataset(path, patch_shape, split, rater, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
'https://zenodo.org/records/12687192/files/training_set.zip'
CHECKSUM =
'1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89'
def
get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the CURVAS dataset.

    If the folder `<path>/data` already exists it is returned right away. Otherwise the
    archive is fetched, repaired, extracted and converted by `_preprocess_data`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the `data` folder inside `path` where the preprocessed data is stored.

    Raises:
        RuntimeError: If the archive could not be repaired and extracted.
    """
    data_dir = os.path.join(path, "data")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "training_set.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)

    # HACK: The zip file is broken. We fix it using the following script.
    # The exit codes are deliberately not enforced, since the repair / extraction of a damaged
    # archive may emit warnings; the result is validated below instead.
    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
    subprocess.run(["unzip", fixed_zip_path, "-d", path])

    # Fail before preprocessing: otherwise an empty 'data' folder would be created and
    # every later call would return it as if the dataset was available.
    extracted_dir = os.path.join(path, "training_set")
    if not os.path.isdir(extracted_dir):
        raise RuntimeError(
            f"Extracting '{zip_path}' failed: '{extracted_dir}' does not exist. "
            "Please make sure that the 'zip' and 'unzip' command line tools are working."
        )

    _preprocess_data(extracted_dir)

    # Remove the zip files as we don't need them anymore.
    os.remove(zip_path)
    os.remove(fixed_zip_path)

    return data_dir
Download the CURVAS dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_curvas_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> List[str]:
def get_curvas_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> List[str]:
    """Get paths to the CURVAS data.

    The HDF5 volumes are sorted naturally by filename; the first 10 form the 'train' split,
    the next 3 the 'val' split and all remaining ones the 'test' split.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the volumetric data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Positions of each split within the naturally sorted list of volumes.
    split_slices = {"train": slice(None, 10), "val": slice(10, 13), "test": slice(13, None)}

    # Validate before touching the data, so that a typo does not trigger a download first.
    if split not in tuple(split_slices):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_curvas_data(path, download)
    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
    return volume_paths[split_slices[split]]
Get paths to the CURVAS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the volumetric data.
def
get_curvas_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Literal['1', '2', '3'] = '1', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_curvas_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `rater` is not one of "1", "2" or "3".
    """
    # Only annotations of three raters exist (label keys 'labels/rater_1' to 'labels/rater_3').
    # Reject anything else up front instead of failing later on a missing label key.
    if str(rater) not in ("1", "2", "3"):
        raise ValueError(f"'{rater}' is not a valid rater. Choose from '1', '2' or '3'.")

    volume_paths = get_curvas_paths(path, split, download)

    if resize_inputs:
        # The helper returns the updated dataset kwargs together with the patch shape to use.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # Raw data and labels are stored in the same HDF5 file, hence the identical path lists.
    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=f"labels/rater_{rater}",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
Get the CURVAS dataset for pancreas, kidney and liver segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- rater: The choice of rater providing the annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset
.
Returns:
The segmentation dataset.
def
get_curvas_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Literal['1', '2', '3'] = '1', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_curvas_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    rater: Literal["1", "2", "3"] = "1",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the CURVAS dataloader for pancreas, kidney and liver segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        rater: The choice of rater providing the annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    curvas_dataset = get_curvas_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        rater=rater,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )

    return torch_em.get_data_loader(curvas_dataset, batch_size, **dataloader_kwargs)
Get the CURVAS dataloader for pancreas, kidney and liver segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- rater: The choice of rater providing the annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset
or for the PyTorch DataLoader.
Returns:
The DataLoader.