torch_em.data.datasets.medical.kits
The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans. NOTE: All patients have kidney and tumor annotations (however, not always have cysts annotated).
The label ids are - kidney: 1, tumor: 2, cyst: 3
This dataset is from the KiTS23 Challenge: https://kits-challenge.org/kits23/. Please cite it if you use this dataset for your research.
"""The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans.
NOTE: All patients have kidney and tumor annotations (however, not always have cysts annotated).

The label ids are - kidney: 1, tumor: 2, cyst: 3

This dataset is from the KiTS23 Challenge: https://kits-challenge.org/kits23/.
Please cite it if you use this dataset for your research.
"""

import os
import json
import subprocess
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, List, Optional, Literal

import numpy as np

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = "https://github.com/neheller/kits23"
VALID_SPLITS = ("train", "val", "test")


def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the KiTS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        The folder where the dataset is downloaded and preprocessed.

    Raises:
        RuntimeError: If the data is not present and `download` is False, or if the
            cloned repository does not contain the expected case folders.
    """
    data_dir = os.path.join(path, "preprocessed")
    if os.path.exists(data_dir) and all(os.path.exists(os.path.join(data_dir, s)) for s in VALID_SPLITS):
        return data_dir

    os.makedirs(path, exist_ok=True)

    if not download:
        raise RuntimeError("The dataset is not found and download is set to False.")

    # We clone the official KiTS23 repository, which ships the download CLI and the label files.
    # 'check=True' so that a failed clone / install / download aborts instead of failing later.
    if not os.path.exists(os.path.join(path, "kits23")):
        subprocess.run(["git", "clone", URL, os.path.join(path, "kits23")], check=True)

    # We install the package-only (with the assumption that the other necessary packages already exist).
    # The last case folder serves as a proxy: if its image is missing, the download has not run yet.
    patient_dirs = natsorted(glob(os.path.join(path, "kits23", "dataset", "case*")))
    if not patient_dirs:
        raise RuntimeError(f"The cloned repository at '{os.path.join(path, 'kits23')}' contains no case folders.")

    chosen_patient_dir = patient_dirs[-1]
    if not os.path.exists(os.path.join(chosen_patient_dir, "imaging.nii.gz")):
        subprocess.run(["pip", "install", "-e", os.path.join(path, "kits23"), "--no-deps"], check=True)

        print("The download might take several hours. Make sure you have consistent internet connection.")

        # Run the CLI to download the input images.
        subprocess.run(["kits23_download_data"], check=True)

    # Preprocess the images.
    _preprocess_inputs(path)

    return data_dir


def _preprocess_inputs(path):
    """Convert each downloaded patient case into one HDF5 file per split.

    Each output file stores the raw CT volume under 'raw', the consolidated
    segmentation under 'labels/all' and per-rater binary masks under
    'labels/<class>/rater_<id>'. The train/val/test assignment is read from an
    existing 'splits_kits.json' if available, otherwise it is created
    (25% test, then 10% of the remainder as val) and stored there.
    """
    # h5py and nibabel are imported lazily so that importing this module does not
    # require them; hoisted here so they are not re-imported per patient.
    import h5py
    import nibabel as nib

    patient_dirs = glob(os.path.join(path, "kits23", "dataset", "case*"))

    preprocessed_dir = os.path.join(path, "preprocessed")

    for split in VALID_SPLITS:
        os.makedirs(os.path.join(preprocessed_dir, split), exist_ok=True)

    json_path = os.path.join(path, "splits_kits.json")

    if os.path.exists(json_path):
        # Reuse the previously stored split assignment.
        with open(json_path) as f:
            split_info = json.load(f)
        split_map = {
            os.path.join(path, "kits23", "dataset", Path(fname).stem): split
            for split, fnames in split_info.items()
            for fname in fnames
        }
    else:
        # Create a fresh, reproducible split and remember it for serialization below.
        train_dirs, test_dirs = train_test_split(patient_dirs, test_size=0.25, random_state=42)
        train_dirs, val_dirs = train_test_split(train_dirs, test_size=0.1, random_state=42)
        split_map = {
            **{d: "train" for d in train_dirs},
            **{d: "val" for d in val_dirs},
            **{d: "test" for d in test_dirs},
        }
        split_info = {"train": [], "val": [], "test": []}

    def _rater_union_mask(instance_pattern, shape):
        # Binary union of all instance masks matching the glob pattern.
        mask = np.zeros(shape)
        for p in glob(instance_pattern):
            mask[nib.load(p).get_fdata() > 0] = 1
        return mask

    for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"):
        patient_id = os.path.basename(patient_dir)
        split = split_map[patient_dir]
        patient_fname = Path(patient_id).with_suffix(".h5")
        patient_path = os.path.join(preprocessed_dir, split, patient_fname)

        if not os.path.exists(json_path):
            split_info[split].append(str(patient_fname))

        if os.path.exists(patient_path):
            continue

        # Next, we find all rater annotations.
        kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-1*")))
        tumor_anns = natsorted(glob(os.path.join(patient_dir, "instances", "tumor_instance*")))
        cyst_anns = natsorted(glob(os.path.join(patient_dir, "instances", "cyst_instance*")))

        with h5py.File(patient_path, "w") as f:
            # Input image.
            raw = nib.load(os.path.join(patient_dir, "imaging.nii.gz")).get_fdata()
            f.create_dataset("raw", data=raw, compression="gzip")

            # Valid segmentation masks for all classes.
            labels = nib.load(os.path.join(patient_dir, "segmentation.nii.gz")).get_fdata()
            assert raw.shape == labels.shape, "The shape of inputs and corresponding segmentation does not match."
            f.create_dataset("labels/all", data=labels, compression="gzip")

            # Add annotations for kidneys per rater.
            # Some cases number the kidney instances 2/3 instead of 1/2.
            _k_exclusive = False
            if not kidney_anns:
                _k_exclusive = True
                kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-2*")))

            assert kidney_anns, f"There must be kidney annotations for '{patient_id}'."
            for p in kidney_anns:
                masks = np.zeros_like(raw)
                rater_id = p[-8]  # The rater count, encoded in the filename before '.nii.gz'.

                # Get the other kidney instance.
                if _k_exclusive:
                    print("The kidney annotations are numbered strangely.")
                    other_p = p.replace("instance-2", "instance-3")
                else:
                    other_p = p.replace("instance-1", "instance-2")

                # Merge both left and right kidney as one semantic id.
                masks[nib.load(p).get_fdata() > 0] = 1
                if os.path.exists(other_p):
                    masks[nib.load(other_p).get_fdata() > 0] = 1
                else:
                    print(f"The second kidney instance does not exist for patient: '{patient_id}'.")

                # Create a hierarchy for the particular rater's kidney annotations.
                f.create_dataset(f"labels/kidney/rater_{rater_id}", data=masks, compression="gzip")

            # Add annotations for tumor per rater.
            assert tumor_anns, f"There must be tumor annotations for '{patient_id}'."
            for rater in np.unique([p[-8] for p in tumor_anns]):
                masks = _rater_union_mask(
                    os.path.join(patient_dir, "instances", f"tumor_instance*-{rater}.nii.gz"), raw.shape
                )
                f.create_dataset(f"labels/tumor/rater_{rater}", data=masks, compression="gzip")

            # Add annotations for cysts per rater (not all patients have cysts annotated;
            # iterating an empty rater list is simply a no-op).
            for rater in np.unique([p[-8] for p in cyst_anns]):
                masks = _rater_union_mask(
                    os.path.join(patient_dir, "instances", f"cyst_instance*-{rater}.nii.gz"), raw.shape
                )
                f.create_dataset(f"labels/cyst/rater_{rater}", data=masks, compression="gzip")

    # Persist the freshly created split so future runs reuse it.
    if not os.path.exists(json_path):
        with open(json_path, "w") as f:
            json.dump(split_info, f, indent=2)


def get_kits_paths(
    path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
) -> List[str]:
    """Get paths to the KiTS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: Which data split to use.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the input data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val', 'test'.
        RuntimeError: If the split folder or its volumes are missing.
    """
    if split not in VALID_SPLITS:
        raise ValueError(f"Invalid split '{split}'. Must be one of {VALID_SPLITS}.")

    get_kits_data(path, download)

    split_dir = os.path.join(path, "preprocessed", split)
    if not os.path.exists(split_dir):
        raise RuntimeError(f"Split folder '{split_dir}' does not exist.")

    volume_paths = natsorted(glob(os.path.join(split_dir, "*.h5")))
    if not volume_paths:
        raise RuntimeError(f"No .h5 files found in split folder '{split_dir}'.")

    return volume_paths


def get_kits_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal["train", "val", "test"],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the KiTS dataset for kidney, tumor and cyst segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: Which data split to use.
        rater: The choice of rater.
        annotation_choice: The choice of annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If only one of `rater` and `annotation_choice` is specified.
    """
    volume_paths = get_kits_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # TODO: simplify the design below later, to allow:
    # - multi-rater label loading.
    # - multi-annotation label loading.
    # (for now, only 1v1 annotation-rater loading is supported).
    if rater is None and annotation_choice is None:
        label_key = "labels/all"
    elif rater is not None and annotation_choice is not None:
        label_key = f"labels/{annotation_choice}/rater_{rater}"
    else:
        # Raise (not assert) so the check survives 'python -O'.
        raise ValueError("Both rater and annotation_choice must be specified together.")

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        **kwargs
    )


def get_kits_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal["train", "val", "test"],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the KiTS dataloader for kidney, tumor and cyst segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: Which data split to use.
        rater: The choice of rater.
        annotation_choice: The choice of annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_kits_dataset(path, patch_shape, split, rater, annotation_choice, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
'https://github.com/neheller/kits23'
VALID_SPLITS =
('train', 'val', 'test')
def
get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
34def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str: 35 """Download the KiTS data. 36 37 Args: 38 path: Filepath to a folder where the data is downloaded for further processing. 39 download: Whether to download the data if it is not present. 40 41 Returns: 42 The folder where the dataset is downloaded and preprocessed. 43 """ 44 data_dir = os.path.join(path, "preprocessed") 45 if os.path.exists(data_dir) and all(os.path.exists(os.path.join(data_dir, s)) for s in VALID_SPLITS): 46 return data_dir 47 48 os.makedirs(path, exist_ok=True) 49 50 if not download: 51 raise RuntimeError("The dataset is not found and download is set to False.") 52 53 # We clone the environment. 54 if not os.path.exists(os.path.join(path, "kits23")): 55 subprocess.run(["git", "clone", URL, os.path.join(path, "kits23")]) 56 57 # We install the package-only (with the assumption that the other necessary packages already exists). 58 chosen_patient_dir = natsorted(glob(os.path.join(path, "kits23", "dataset", "case*")))[-1] 59 if not os.path.exists(os.path.join(chosen_patient_dir, "imaging.nii.gz")): 60 subprocess.run(["pip", "install", "-e", os.path.join(path, "kits23"), "--no-deps"]) 61 62 print("The download might take several hours. Make sure you have consistent internet connection.") 63 64 # Run the CLI to download the input images. 65 subprocess.run(["kits23_download_data"]) 66 67 # Preprocess the images. 68 _preprocess_inputs(path) 69 70 return data_dir
Download the KiTS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
The folder where the dataset is downloaded and preprocessed.
def
get_kits_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> List[str]:
190def get_kits_paths( 191 path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False 192) -> List[str]: 193 """Get paths to the KiTS data. 194 195 Args: 196 path: Filepath to a folder where the data is downloaded for further processing. 197 split: Which data split to use. 198 download: Whether to download the data if it is not present. 199 200 Returns: 201 List of filepaths for the input data. 202 """ 203 204 if split not in VALID_SPLITS: 205 raise ValueError(f"Invalid split '{split}'. Must be one of {VALID_SPLITS}.") 206 207 get_kits_data(path, download) 208 209 split_dir = os.path.join(path, "preprocessed", split) 210 if not os.path.exists(split_dir): 211 raise RuntimeError(f"Split folder '{split_dir}' does not exist.") 212 213 volume_paths = natsorted(glob(os.path.join(split_dir, "*.h5"))) 214 if not volume_paths: 215 raise RuntimeError(f"No .h5 files found in split folder '{split_dir}'.") 216 217 return volume_paths
Get paths to the KiTS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: Which data split to use.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the input data.
def
get_kits_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Optional[Literal[1, 2, 3]] = None, annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
220def get_kits_dataset( 221 path: Union[os.PathLike, str], 222 patch_shape: Tuple[int, ...], 223 split: Literal["train", "val", "test"], 224 rater: Optional[Literal[1, 2, 3]] = None, 225 annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None, 226 resize_inputs: bool = False, 227 download: bool = False, 228 **kwargs 229) -> Dataset: 230 """Get the KiTS dataset for kidney, tumor and cyst segmentation. 231 232 Args: 233 path: Filepath to a folder where the data is downloaded for further processing. 234 patch_shape: The patch shape to use for training. 235 split: Which data split to use. 236 rater: The choice of rater. 237 annotation_choice: The choice of annotations. 238 resize_inputs: Whether to resize inputs to the desired patch shape. 239 download: Whether to download the data if it is not present. 240 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 241 242 Returns: 243 The segmentation dataset. 244 """ 245 volume_paths = get_kits_paths(path, split, download) 246 247 if resize_inputs: 248 resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} 249 kwargs, patch_shape = util.update_kwargs_for_resize_trafo( 250 kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs 251 ) 252 253 # TODO: simplify the design below later, to allow: 254 # - multi-rater label loading. 255 # - multi-annotation label loading. 256 # (for now, only 1v1 annotation-rater loading is supported). 257 if rater is None and annotation_choice is None: 258 label_key = "labels/all" 259 else: 260 assert rater is not None and annotation_choice is not None, \ 261 "Both rater and annotation_choice must be specified together." 262 263 label_key = f"labels/{annotation_choice}/rater_{rater}" 264 265 return torch_em.default_segmentation_dataset( 266 raw_paths=volume_paths, 267 raw_key="raw", 268 label_paths=volume_paths, 269 label_key=label_key, 270 patch_shape=patch_shape, 271 **kwargs 272 )
Get the KiTS dataset for kidney, tumor and cyst segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: Which data split to use.
- rater: The choice of rater.
- annotation_choice: The choice of annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
def
get_kits_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Optional[Literal[1, 2, 3]] = None, annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
275def get_kits_loader( 276 path: Union[os.PathLike, str], 277 batch_size: int, 278 patch_shape: Tuple[int, ...], 279 split: Literal["train", "val", "test"], 280 rater: Optional[Literal[1, 2, 3]] = None, 281 annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None, 282 resize_inputs: bool = False, 283 download: bool = False, 284 **kwargs 285) -> DataLoader: 286 """Get the KiTS dataloader for kidney, tumor and cyst segmentation. 287 288 Args: 289 path: Filepath to a folder where the data is downloaded for further processing. 290 batch_size: The batch size for training. 291 patch_shape: The patch shape to use for training. 292 split: Which data split to use. 293 rater: The choice of rater. 294 annotation_choice: The choice of annotations. 295 resize_inputs: Whether to resize inputs to the desired patch shape. 296 download: Whether to download the data if it is not present. 297 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 298 299 Returns: 300 The DataLoader. 301 """ 302 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 303 dataset = get_kits_dataset(path, patch_shape, split, rater, annotation_choice, resize_inputs, download, **ds_kwargs) 304 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the KiTS dataloader for kidney, tumor and cyst segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: Which data split to use.
- rater: The choice of rater.
- annotation_choice: The choice of annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.