torch_em.data.datasets.medical.kits
The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans. NOTE: All patients have kidney and tumor annotations; however, cysts are not always annotated.
The label ids are - kidney: 1, tumor: 2, cyst: 3
This dataset is from the KiTS23 Challenge: https://kits-challenge.org/kits23/. Please cite it if you use this dataset for your research.
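A minimal usage sketch (the download folder and patch shape below are illustrative placeholders, not fixed values):

from torch_em.data.datasets.medical.kits import get_kits_loader

# Downloads and preprocesses the data on first use; without a rater / annotation choice
# the combined "labels/all" masks (kidney: 1, tumor: 2, cyst: 3) are used.
loader = get_kits_loader(
    path="./data/kits23",        # illustrative download folder
    batch_size=1,
    patch_shape=(32, 256, 256),  # illustrative 3D patch shape
    download=True,
)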
1"""The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans. 2NOTE: All patients have kidney and tumor annotations (however, not always have cysts annotated). 3 4The label ids are - kidney: 1, tumor: 2, cyst: 3 5 6This dataset is from the KiTS2 Challenge: https://kits-challenge.org/kits23/. 7Please cite it if you use this dataset for your research. 8""" 9 10import os 11import subprocess 12from glob import glob 13from tqdm import tqdm 14from pathlib import Path 15from natsort import natsorted 16from typing import Union, Tuple, List, Optional, Literal 17 18import numpy as np 19 20from torch.utils.data import Dataset, DataLoader 21 22import torch_em 23 24from .. import util 25 26 27URL = "https://github.com/neheller/kits23" 28 29 30def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str: 31 """Download the KiTS data. 32 33 Args: 34 path: Filepath to a folder where the data is downloaded for further processing. 35 download: Whether to download the data if it is not present. 36 37 Returns: 38 The folder where the dataset is downloaded and preprocessed. 39 """ 40 data_dir = os.path.join(path, "preprocessed") 41 if os.path.exists(data_dir): 42 return data_dir 43 44 os.makedirs(path, exist_ok=True) 45 46 if not download: 47 raise RuntimeError("The dataset is not found and download is set to False.") 48 49 # We clone the environment. 50 if not os.path.exists(os.path.join(path, "kits23")): 51 subprocess.run(["git", "clone", URL, os.path.join(path, "kits23")]) 52 53 # We install the package-only (with the assumption that the other necessary packages already exists). 54 chosen_patient_dir = natsorted(glob(os.path.join(path, "kits23", "dataset", "case*")))[-1] 55 if not os.path.exists(os.path.join(chosen_patient_dir, "imaging.nii.gz")): 56 subprocess.run(["pip", "install", "-e", os.path.join(path, "kits23"), "--no-deps"]) 57 58 print("The download might take several hours. Make sure you have consistent internet connection.") 59 60 # Run the CLI to download the input images. 61 subprocess.run(["kits23_download_data"]) 62 63 # Preprocess the images. 64 _preprocess_inputs(path) 65 66 return data_dir 67 68 69def _preprocess_inputs(path): 70 patient_dirs = glob(os.path.join(path, "kits23", "dataset", "case*")) 71 72 preprocessed_dir = os.path.join(path, "preprocessed") 73 os.makedirs(preprocessed_dir, exist_ok=True) 74 75 for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"): 76 patient_id = os.path.basename(patient_dir) 77 patient_path = os.path.join(preprocessed_dir, Path(patient_id).with_suffix(".h5")) 78 79 if os.path.exists(patient_path): 80 continue 81 82 # Next, we find all rater annotations. 83 kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-1*"))) 84 tumor_anns = natsorted(glob(os.path.join(patient_dir, "instances", "tumor_instance*"))) 85 cyst_anns = natsorted(glob(os.path.join(patient_dir, "instances", "cyst_instance*"))) 86 87 import h5py 88 import nibabel as nib 89 90 with h5py.File(patient_path, "w") as f: 91 # Input image. 92 raw = nib.load(os.path.join(patient_dir, "imaging.nii.gz")).get_fdata() 93 f.create_dataset("raw", data=raw, compression="gzip") 94 95 # Valid segmentation masks for all classes. 96 labels = nib.load(os.path.join(patient_dir, "segmentation.nii.gz")).get_fdata() 97 assert raw.shape == labels.shape, "The shape of inputs and corresponding segmentation does not match." 
98 99 f.create_dataset("labels/all", data=labels, compression="gzip") 100 101 # Add annotations for kidneys per rater. 102 _k_exclusive = False 103 if not kidney_anns: 104 _k_exclusive = True 105 kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-2*"))) 106 107 assert kidney_anns, f"There must be kidney annotations for '{patient_id}'." 108 for p in kidney_anns: 109 masks = np.zeros_like(raw) 110 rater_id = p[-8] # The rater count 111 112 # Get the other kidney instance. 113 if _k_exclusive: 114 print("The kidney annotations are numbered strangely.") 115 other_p = p.replace("instance-2", "instance-3") 116 else: 117 other_p = p.replace("instance-1", "instance-2") 118 119 # Merge both left and right kidney as one semantic id. 120 masks[nib.load(p).get_fdata() > 0] = 1 121 if os.path.exists(other_p): 122 masks[nib.load(other_p).get_fdata() > 0] = 1 123 else: 124 print(f"The second kidney instance does not exist for patient: '{patient_id}'.") 125 126 # Create a hierarchy for the particular rater's kidney annotations. 127 f.create_dataset(f"labels/kidney/rater_{rater_id}", data=masks, compression="gzip") 128 129 # Add annotations for tumor per rater. 130 assert tumor_anns, f"There must be tumor annotations for '{patient_id}'." 131 # Find the raters. 132 raters = [p[-8] for p in tumor_anns] 133 # Get masks per rater 134 unique_raters = np.unique(raters) 135 for rater in unique_raters: 136 masks = np.zeros_like(raw) 137 for p in glob(os.path.join(patient_dir, "instances", f"tumor_instance*-{rater}.nii.gz")): 138 masks[nib.load(p).get_fdata() > 0] = 1 139 140 f.create_dataset(f"labels/tumor/rater_{rater}", data=masks, compression="gzip") 141 142 # Add annotations for cysts per rater. 143 if cyst_anns: 144 # Find the raters first 145 raters = [p[-8] for p in cyst_anns] 146 # Get masks per rater 147 unique_raters = np.unique(raters) 148 for rater in unique_raters: 149 masks = np.zeros_like(raw) 150 for p in glob(os.path.join(patient_dir, "instances", f"cyst_instance*-{rater}.nii.gz")): 151 masks[nib.load(p).get_fdata() > 0] = 1 152 153 f.create_dataset(f"labels/cyst/rater_{rater}", data=masks, compression="gzip") 154 155 156def get_kits_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]: 157 """Get paths to the KiTS data. 158 159 Args: 160 path: Filepath to a folder where the data is downloaded for further processing. 161 download: Whether to download the data if it is not present. 162 163 Returns: 164 List of filepaths for the input data. 165 """ 166 data_dir = get_kits_data(path, download) 167 volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5"))) 168 return volume_paths 169 170 171def get_kits_dataset( 172 path: Union[os.PathLike, str], 173 patch_shape: Tuple[int, ...], 174 rater: Optional[Literal[1, 2, 3]] = None, 175 annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None, 176 resize_inputs: bool = False, 177 download: bool = False, 178 **kwargs 179) -> Dataset: 180 """Get the KiTS dataset for kidney, tumor and cyst segmentation. 181 182 Args: 183 path: Filepath to a folder where the data is downloaded for further processing. 184 patch_shape: The patch shape to use for training. 185 rater: The choice of rater. 186 annotation_choice: The choice of annotations. 187 resize_inputs: Whether to resize inputs to the desired patch shape. 188 download: Whether to download the data if it is not present. 189 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 
190 191 Returns: 192 The segmentation dataset. 193 """ 194 volume_paths = get_kits_paths(path, download) 195 196 if resize_inputs: 197 resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} 198 kwargs, patch_shape = util.update_kwargs_for_resize_trafo( 199 kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs 200 ) 201 202 # TODO: simplify the design below later, to allow: 203 # - multi-rater label loading. 204 # - multi-annotation label loading. 205 # (for now, only 1v1 annotation-rater loading is supported). 206 if rater is None and annotation_choice is None: 207 label_key = "labels/all" 208 else: 209 assert rater is not None and annotation_choice is not None, \ 210 "Both rater and annotation_choice must be specified together." 211 212 label_key = f"labels/{annotation_choice}/rater_{rater}" 213 214 return torch_em.default_segmentation_dataset( 215 raw_paths=volume_paths, 216 raw_key="raw", 217 label_paths=volume_paths, 218 label_key=label_key, 219 patch_shape=patch_shape, 220 **kwargs 221 ) 222 223 224def get_kits_loader( 225 path: Union[os.PathLike, str], 226 batch_size: int, 227 patch_shape: Tuple[int, ...], 228 rater: Optional[Literal[1, 2, 3]] = None, 229 annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None, 230 resize_inputs: bool = False, 231 download: bool = False, 232 **kwargs 233) -> DataLoader: 234 """Get the KiTS dataloader for kidney, tumor and cyst segmentation. 235 236 Args: 237 path: Filepath to a folder where the data is downloaded for further processing. 238 batch_size: The batch size for training. 239 patch_shape: The patch shape to use for training. 240 rater: The choice of rater. 241 annotation_choice: The choice of annotations. 242 resize_inputs: Whether to resize inputs to the desired patch shape. 243 download: Whether to download the data if it is not present. 244 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 245 246 Returns: 247 The DataLoader. 248 """ 249 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 250 dataset = get_kits_dataset(path, patch_shape, rater, annotation_choice, resize_inputs, download, **ds_kwargs) 251 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
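Based on the preprocessing above, each case is written to a single HDF5 file. A quick way to inspect the resulting label hierarchy (the file path below is illustrative):

import h5py

# File names follow the KiTS case ids, e.g. case_00000.h5 (illustrative path below).
with h5py.File("./data/kits23/preprocessed/case_00000.h5", "r") as f:
    f.visit(print)
# Expected entries (the available raters vary per case):
#   raw
#   labels/all
#   labels/kidney/rater_<i>
#   labels/tumor/rater_<i>
#   labels/cyst/rater_<i>   (only if cyst annotations exist)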
URL = 'https://github.com/neheller/kits23'
def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the KiTS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        The folder where the dataset is downloaded and preprocessed.
    """
    data_dir = os.path.join(path, "preprocessed")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    if not download:
        raise RuntimeError("The dataset is not found and download is set to False.")

    # We clone the repository.
    if not os.path.exists(os.path.join(path, "kits23")):
        subprocess.run(["git", "clone", URL, os.path.join(path, "kits23")])

    # We install the package only (assuming the other necessary packages already exist).
    chosen_patient_dir = natsorted(glob(os.path.join(path, "kits23", "dataset", "case*")))[-1]
    if not os.path.exists(os.path.join(chosen_patient_dir, "imaging.nii.gz")):
        subprocess.run(["pip", "install", "-e", os.path.join(path, "kits23"), "--no-deps"])

        print("The download might take several hours. Make sure you have consistent internet connection.")

        # Run the CLI to download the input images.
        subprocess.run(["kits23_download_data"])

    # Preprocess the images.
    _preprocess_inputs(path)

    return data_dir
Download the KiTS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
The folder where the dataset is downloaded and preprocessed.
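For example (the target folder is an illustrative placeholder):

from torch_em.data.datasets.medical.kits import get_kits_data

data_dir = get_kits_data(path="./data/kits23", download=True)
print(data_dir)  # "./data/kits23/preprocessed", with one HDF5 file per case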
def get_kits_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]:
def get_kits_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]:
    """Get paths to the KiTS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the input data.
    """
    data_dir = get_kits_data(path, download)
    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
    return volume_paths
Get paths to the KiTS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the input data.
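For example (the target folder is an illustrative placeholder):

from torch_em.data.datasets.medical.kits import get_kits_paths

volume_paths = get_kits_paths(path="./data/kits23", download=True)
print(len(volume_paths), volume_paths[0])  # naturally sorted paths to the preprocessed .h5 volumes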
def get_kits_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> torch.utils.data.dataset.Dataset:
def get_kits_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the KiTS dataset for kidney, tumor and cyst segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        rater: The choice of rater.
        annotation_choice: The choice of annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    volume_paths = get_kits_paths(path, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # TODO: simplify the design below later, to allow:
    # - multi-rater label loading.
    # - multi-annotation label loading.
    # (for now, only 1v1 annotation-rater loading is supported).
    if rater is None and annotation_choice is None:
        label_key = "labels/all"
    else:
        assert rater is not None and annotation_choice is not None, \
            "Both rater and annotation_choice must be specified together."

        label_key = f"labels/{annotation_choice}/rater_{rater}"

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        **kwargs
    )
Get the KiTS dataset for kidney, tumor and cyst segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- rater: The choice of rater.
- annotation_choice: The choice of annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
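For example, to train on rater 1's tumor annotations (folder and patch shape are illustrative placeholders):

from torch_em.data.datasets.medical.kits import get_kits_dataset

dataset = get_kits_dataset(
    path="./data/kits23",
    patch_shape=(32, 256, 256),
    rater=1,
    annotation_choice="tumor",
    download=True,
)
# Omitting both rater and annotation_choice falls back to the combined "labels/all" masks.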
def get_kits_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> torch.utils.data.dataloader.DataLoader:
def get_kits_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the KiTS dataloader for kidney, tumor and cyst segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        rater: The choice of rater.
        annotation_choice: The choice of annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_kits_dataset(path, patch_shape, rater, annotation_choice, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the KiTS dataloader for kidney, tumor and cyst segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- rater: The choice of rater.
- annotation_choice: The choice of annotations.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
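For example (folder, patch shape and batch size are illustrative placeholders):

from torch_em.data.datasets.medical.kits import get_kits_loader

loader = get_kits_loader(
    path="./data/kits23",
    batch_size=1,
    patch_shape=(32, 256, 256),
    rater=1,
    annotation_choice="kidney",
    download=True,
)
x, y = next(iter(loader))  # one batch of raw patches and the matching label patches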