torch_em.data.datasets.medical.kits

The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans. NOTE: All patients have kidney and tumor annotations (however, cysts are not always annotated).

The label ids are - kidney: 1, tumor: 2, cyst: 3

This dataset is from the KiTS23 Challenge: https://kits-challenge.org/kits23/. Please cite it if you use this dataset for your research.

  1"""The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans.
  2NOTE: All patients have kidney and tumor annotations (however, cysts are not always annotated).
  3
  4The label ids are - kidney: 1, tumor: 2, cyst: 3
  5
  6This dataset is from the KiTS23 Challenge: https://kits-challenge.org/kits23/.
  7Please cite it if you use this dataset for your research.
  8"""
  9
import os
import sys
import json
import subprocess
from glob import glob
from pathlib import Path
from typing import Union, Tuple, List, Optional, Literal

import numpy as np
from natsort import natsorted
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch_em

from .. import util
 27
 28
# The official KiTS23 challenge repository; it is cloned to obtain the dataset layout and download CLI.
URL = "https://github.com/neheller/kits23"
# The data splits supported by this dataset.
VALID_SPLITS = ("train", "val", "test")
 31
 32
 33def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 34    """Download the KiTS data.
 35
 36    Args:
 37        path: Filepath to a folder where the data is downloaded for further processing.
 38        download: Whether to download the data if it is not present.
 39
 40    Returns:
 41        The folder where the dataset is downloaded and preprocessed.
 42    """
 43    data_dir = os.path.join(path, "preprocessed")
 44    if os.path.exists(data_dir) and all(os.path.exists(os.path.join(data_dir, s)) for s in VALID_SPLITS):
 45        return data_dir
 46
 47    os.makedirs(path, exist_ok=True)
 48
 49    if not download:
 50        raise RuntimeError("The dataset is not found and download is set to False.")
 51
 52    # We clone the environment.
 53    if not os.path.exists(os.path.join(path, "kits23")):
 54        subprocess.run(["git", "clone", URL, os.path.join(path, "kits23")])
 55
 56    # We install the package-only (with the assumption that the other necessary packages already exists).
 57    chosen_patient_dir = natsorted(glob(os.path.join(path, "kits23", "dataset", "case*")))[-1]
 58    if not os.path.exists(os.path.join(chosen_patient_dir, "imaging.nii.gz")):
 59        subprocess.run(["pip", "install", "-e", os.path.join(path, "kits23"), "--no-deps"])
 60
 61        print("The download might take several hours. Make sure you have consistent internet connection.")
 62
 63        # Run the CLI to download the input images.
 64        subprocess.run(["kits23_download_data"])
 65
 66    # Preprocess the images.
 67    _preprocess_inputs(path)
 68
 69    return data_dir
 70
 71
def _preprocess_inputs(path: Union[os.PathLike, str]) -> None:
    """Convert the downloaded KiTS23 NIfTI volumes into per-split HDF5 files.

    For every patient folder under 'kits23/dataset', one '.h5' file is written to
    'preprocessed/<split>' containing the CT volume ('raw'), the combined segmentation
    ('labels/all'), and per-rater binary masks under 'labels/kidney', 'labels/tumor'
    and 'labels/cyst'. The train / val / test assignment is stored in
    'splits_kits.json' so repeated runs reuse the same splits.
    """
    patient_dirs = glob(os.path.join(path, "kits23", "dataset", "case*"))

    preprocessed_dir = os.path.join(path, "preprocessed")

    # Make sure all split output folders exist.
    for split in VALID_SPLITS:
        os.makedirs(os.path.join(preprocessed_dir, split), exist_ok=True)

    # The split assignment is cached on disk for reproducibility across runs.
    json_path = os.path.join(path, "splits_kits.json")

    if os.path.exists(json_path):
        # Reuse the stored splits: map each patient folder to its split name.
        with open(json_path) as f:
            split_info = json.load(f)
        split_map = {
            os.path.join(path, "kits23", "dataset", Path(fname).stem): split
            for split, fnames in split_info.items()
            for fname in fnames
        }
    else:
        # Create fresh splits: 25% test, then 10% of the remaining data as val.
        train_dirs, test_dirs = train_test_split(patient_dirs, test_size=0.25, random_state=42)
        train_dirs, val_dirs = train_test_split(train_dirs, test_size=0.1, random_state=42)
        split_map = {
            **{d: "train" for d in train_dirs},
            **{d: "val" for d in val_dirs},
            **{d: "test" for d in test_dirs},
        }
        # Filled per patient below and written to 'json_path' at the end.
        split_info = {"train": [], "val": [], "test": []}

    for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"):
        patient_id = os.path.basename(patient_dir)
        split = split_map[patient_dir]
        patient_fname = Path(patient_id).with_suffix(".h5")
        patient_path = os.path.join(preprocessed_dir, split, patient_fname)

        # Record the output filename for the split json (only when creating fresh splits).
        if not os.path.exists(json_path):
            split_info[split].append(str(patient_fname))

        # Skip patients that were already preprocessed in a previous run.
        if os.path.exists(patient_path):
            continue

        # Next, we find all rater annotations.
        kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-1*")))
        tumor_anns = natsorted(glob(os.path.join(patient_dir, "instances", "tumor_instance*")))
        cyst_anns = natsorted(glob(os.path.join(patient_dir, "instances", "cyst_instance*")))

        # Imported lazily so the module can be imported without these optional dependencies.
        import h5py
        import nibabel as nib

        with h5py.File(patient_path, "w") as f:
            # Input image.
            raw = nib.load(os.path.join(patient_dir, "imaging.nii.gz")).get_fdata()
            f.create_dataset("raw", data=raw, compression="gzip")

            # Valid segmentation masks for all classes.
            labels = nib.load(os.path.join(patient_dir, "segmentation.nii.gz")).get_fdata()
            assert raw.shape == labels.shape, "The shape of inputs and corresponding segmentation does not match."
            f.create_dataset("labels/all", data=labels, compression="gzip")

            # Add annotations for kidneys per rater.
            # Some patients number their kidney instances 2/3 instead of 1/2;
            # '_k_exclusive' tracks this irregular numbering.
            _k_exclusive = False
            if not kidney_anns:
                _k_exclusive = True
                kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-2*")))

            assert kidney_anns, f"There must be kidney annotations for '{patient_id}'."
            for p in kidney_anns:
                masks = np.zeros_like(raw)
                rater_id = p[-8]  # The rater count (the character right before the '.nii.gz' suffix).

                # Get the other kidney instance.
                if _k_exclusive:
                    print("The kidney annotations are numbered strangely.")
                    other_p = p.replace("instance-2", "instance-3")
                else:
                    other_p = p.replace("instance-1", "instance-2")

                # Merge both left and right kidney as one semantic id.
                masks[nib.load(p).get_fdata() > 0] = 1
                if os.path.exists(other_p):
                    masks[nib.load(other_p).get_fdata() > 0] = 1
                else:
                    print(f"The second kidney instance does not exist for patient: '{patient_id}'.")

                # Create a hierarchy for the particular rater's kidney annotations.
                f.create_dataset(f"labels/kidney/rater_{rater_id}", data=masks, compression="gzip")

            # Add annotations for tumor per rater.
            assert tumor_anns, f"There must be tumor annotations for '{patient_id}'."
            # Find the raters.
            raters = [p[-8] for p in tumor_anns]
            # Get masks per rater
            unique_raters = np.unique(raters)
            for rater in unique_raters:
                masks = np.zeros_like(raw)
                # Merge all tumor instances of this rater into one binary mask.
                for p in glob(os.path.join(patient_dir, "instances", f"tumor_instance*-{rater}.nii.gz")):
                    masks[nib.load(p).get_fdata() > 0] = 1

                f.create_dataset(f"labels/tumor/rater_{rater}", data=masks, compression="gzip")

            # Add annotations for cysts per rater.
            # NOTE: Cyst annotations are optional - not every patient has them.
            if cyst_anns:
                # Find the raters first
                raters = [p[-8] for p in cyst_anns]
                # Get masks per rater
                unique_raters = np.unique(raters)
                for rater in unique_raters:
                    masks = np.zeros_like(raw)
                    for p in glob(os.path.join(patient_dir, "instances", f"cyst_instance*-{rater}.nii.gz")):
                        masks[nib.load(p).get_fdata() > 0] = 1

                    f.create_dataset(f"labels/cyst/rater_{rater}", data=masks, compression="gzip")

    # Persist the freshly created splits so future runs are consistent.
    if not os.path.exists(json_path):
        with open(json_path, "w") as f:
            json.dump(split_info, f, indent=2)
187
188
def get_kits_paths(
    path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
) -> List[str]:
    """Get paths to the KiTS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: Which data split to use.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the input data.

    Raises:
        ValueError: If 'split' is not one of the valid split names.
        RuntimeError: If the split folder is missing or contains no volumes.
    """
    # Validate the requested split before touching the filesystem.
    if split not in VALID_SPLITS:
        raise ValueError(f"Invalid split '{split}'. Must be one of {VALID_SPLITS}.")

    # Ensure the data has been downloaded and preprocessed.
    get_kits_data(path, download)

    split_folder = os.path.join(path, "preprocessed", split)
    if not os.path.exists(split_folder):
        raise RuntimeError(f"Split folder '{split_folder}' does not exist.")

    h5_paths = glob(os.path.join(split_folder, "*.h5"))
    if not h5_paths:
        raise RuntimeError(f"No .h5 files found in split folder '{split_folder}'.")

    return natsorted(h5_paths)
217
218
def get_kits_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal["train", "val", "test"],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the KiTS dataset for kidney, tumor and cyst segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: Which data split to use.
        rater: The choice of rater.
        annotation_choice: The choice of annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If only one of 'rater' and 'annotation_choice' is specified.
    """
    volume_paths = get_kits_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # TODO: simplify the design below later, to allow:
    # - multi-rater label loading.
    # - multi-annotation label loading.
    # (for now, only 1v1 annotation-rater loading is supported).
    if rater is None and annotation_choice is None:
        label_key = "labels/all"
    elif rater is not None and annotation_choice is not None:
        label_key = f"labels/{annotation_choice}/rater_{rater}"
    else:
        # NOTE: This validation used to be an 'assert', which is stripped when Python
        # runs with '-O'; an explicit exception keeps the check reliable.
        raise ValueError("Both rater and annotation_choice must be specified together.")

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        **kwargs
    )
272
273
def get_kits_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal["train", "val", "test"],
    rater: Optional[Literal[1, 2, 3]] = None,
    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the KiTS dataloader for kidney, tumor and cyst segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: Which data split to use.
        rater: The choice of rater.
        annotation_choice: The choice of annotations.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route the keyword arguments to either the dataset or the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds = get_kits_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        rater=rater,
        annotation_choice=annotation_choice,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(ds, batch_size, **dataloader_kwargs)
URL = 'https://github.com/neheller/kits23'
VALID_SPLITS = ('train', 'val', 'test')
def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
34def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
35    """Download the KiTS data.
36
37    Args:
38        path: Filepath to a folder where the data is downloaded for further processing.
39        download: Whether to download the data if it is not present.
40
41    Returns:
42        The folder where the dataset is downloaded and preprocessed.
43    """
44    data_dir = os.path.join(path, "preprocessed")
45    if os.path.exists(data_dir) and all(os.path.exists(os.path.join(data_dir, s)) for s in VALID_SPLITS):
46        return data_dir
47
48    os.makedirs(path, exist_ok=True)
49
50    if not download:
51        raise RuntimeError("The dataset is not found and download is set to False.")
52
53    # We clone the environment.
54    if not os.path.exists(os.path.join(path, "kits23")):
55        subprocess.run(["git", "clone", URL, os.path.join(path, "kits23")])
56
57    # We install the package-only (with the assumption that the other necessary packages already exists).
58    chosen_patient_dir = natsorted(glob(os.path.join(path, "kits23", "dataset", "case*")))[-1]
59    if not os.path.exists(os.path.join(chosen_patient_dir, "imaging.nii.gz")):
60        subprocess.run(["pip", "install", "-e", os.path.join(path, "kits23"), "--no-deps"])
61
62        print("The download might take several hours. Make sure you have consistent internet connection.")
63
64        # Run the CLI to download the input images.
65        subprocess.run(["kits23_download_data"])
66
67    # Preprocess the images.
68    _preprocess_inputs(path)
69
70    return data_dir

Download the KiTS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

The folder where the dataset is downloaded and preprocessed.

def get_kits_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> List[str]:
190def get_kits_paths(
191    path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
192) -> List[str]:
193    """Get paths to the KiTS data.
194
195    Args:
196        path: Filepath to a folder where the data is downloaded for further processing.
197        split: Which data split to use.
198        download: Whether to download the data if it is not present.
199
200    Returns:
201        List of filepaths for the input data.
202    """
203
204    if split not in VALID_SPLITS:
205        raise ValueError(f"Invalid split '{split}'. Must be one of {VALID_SPLITS}.")
206
207    get_kits_data(path, download)
208
209    split_dir = os.path.join(path, "preprocessed", split)
210    if not os.path.exists(split_dir):
211        raise RuntimeError(f"Split folder '{split_dir}' does not exist.")
212
213    volume_paths = natsorted(glob(os.path.join(split_dir, "*.h5")))
214    if not volume_paths:
215        raise RuntimeError(f"No .h5 files found in split folder '{split_dir}'.")
216
217    return volume_paths

Get paths to the KiTS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: Which data split to use.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the input data.

def get_kits_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Optional[Literal[1, 2, 3]] = None, annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
220def get_kits_dataset(
221    path: Union[os.PathLike, str],
222    patch_shape: Tuple[int, ...],
223    split: Literal["train", "val", "test"],
224    rater: Optional[Literal[1, 2, 3]] = None,
225    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
226    resize_inputs: bool = False,
227    download: bool = False,
228    **kwargs
229) -> Dataset:
230    """Get the KiTS dataset for kidney, tumor and cyst segmentation.
231
232    Args:
233        path: Filepath to a folder where the data is downloaded for further processing.
234        patch_shape: The patch shape to use for training.
235        split: Which data split to use.
236        rater: The choice of rater.
237        annotation_choice: The choice of annotations.
238        resize_inputs: Whether to resize inputs to the desired patch shape.
239        download: Whether to download the data if it is not present.
240        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
241
242    Returns:
243        The segmentation dataset.
244    """
245    volume_paths = get_kits_paths(path, split, download)
246
247    if resize_inputs:
248        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
249        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
250            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
251        )
252
253    # TODO: simplify the design below later, to allow:
254    # - multi-rater label loading.
255    # - multi-annotation label loading.
256    # (for now, only 1v1 annotation-rater loading is supported).
257    if rater is None and annotation_choice is None:
258        label_key = "labels/all"
259    else:
260        assert rater is not None and annotation_choice is not None, \
261            "Both rater and annotation_choice must be specified together."
262
263        label_key = f"labels/{annotation_choice}/rater_{rater}"
264
265    return torch_em.default_segmentation_dataset(
266        raw_paths=volume_paths,
267        raw_key="raw",
268        label_paths=volume_paths,
269        label_key=label_key,
270        patch_shape=patch_shape,
271        **kwargs
272    )

Get the KiTS dataset for kidney, tumor and cyst segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: Which data split to use.
  • rater: The choice of rater.
  • annotation_choice: The choice of annotations.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_kits_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], rater: Optional[Literal[1, 2, 3]] = None, annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
275def get_kits_loader(
276    path: Union[os.PathLike, str],
277    batch_size: int,
278    patch_shape: Tuple[int, ...],
279    split: Literal["train", "val", "test"],
280    rater: Optional[Literal[1, 2, 3]] = None,
281    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
282    resize_inputs: bool = False,
283    download: bool = False,
284    **kwargs
285) -> DataLoader:
286    """Get the KiTS dataloader for kidney, tumor and cyst segmentation.
287
288    Args:
289        path: Filepath to a folder where the data is downloaded for further processing.
290        batch_size: The batch size for training.
291        patch_shape: The patch shape to use for training.
292        split: Which data split to use.
293        rater: The choice of rater.
294        annotation_choice: The choice of annotations.
295        resize_inputs: Whether to resize inputs to the desired patch shape.
296        download: Whether to download the data if it is not present.
297        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
298
299    Returns:
300        The DataLoader.
301    """
302    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
303    dataset = get_kits_dataset(path, patch_shape, split, rater, annotation_choice, resize_inputs, download, **ds_kwargs)
304    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get the KiTS dataloader for kidney, tumor and cyst segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: Which data split to use.
  • rater: The choice of rater.
  • annotation_choice: The choice of annotations.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.