torch_em.data.datasets.medical.kits

The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans. NOTE: All patients have kidney and tumor annotations; however, cysts are not annotated for every patient.

The label ids are - kidney: 1, tumor: 2, cyst: 3

This dataset is from the KiTS23 Challenge: https://kits-challenge.org/kits23/. Please cite it if you use this dataset for your research.
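
A minimal usage sketch (the data folder, patch shape and batch size below are placeholder choices; download=True triggers the lengthy first-time download described under get_kits_data):

    from torch_em.data.datasets.medical.kits import get_kits_loader

    # Placeholder folder and patch shape; adjust to your setup.
    loader = get_kits_loader(
        path="./data/kits",
        batch_size=2,
        patch_shape=(32, 256, 256),
        download=True,
    )

    # Each batch contains raw CT patches and the corresponding label patches
    # (kidney: 1, tumor: 2, cyst: 3).
    for raw, labels in loader:
        print(raw.shape, labels.shape)
        break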

  1"""The KiTS dataset contains annotations for kidney, tumor and cyst segmentation in CT scans.
  2NOTE: All patients have kidney and tumor annotations; however, cysts are not annotated for every patient.
  3
  4The label ids are - kidney: 1, tumor: 2, cyst: 3
  5
  6This dataset is from the KiTS23 Challenge: https://kits-challenge.org/kits23/.
  7Please cite it if you use this dataset for your research.
  8"""
  9
 10import os
 11import subprocess
 12from glob import glob
 13from tqdm import tqdm
 14from pathlib import Path
 15from natsort import natsorted
 16from typing import Union, Tuple, List, Optional, Literal
 17
 18import numpy as np
 19
 20from torch.utils.data import Dataset, DataLoader
 21
 22import torch_em
 23
 24from .. import util
 25
 26
 27URL = "https://github.com/neheller/kits23"
 28
 29
 30def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 31    """Download the KiTS data.
 32
 33    Args:
 34        path: Filepath to a folder where the data is downloaded for further processing.
 35        download: Whether to download the data if it is not present.
 36
 37    Returns:
 38        The folder where the dataset is downloaded and preprocessed.
 39    """
 40    data_dir = os.path.join(path, "preprocessed")
 41    if os.path.exists(data_dir):
 42        return data_dir
 43
 44    os.makedirs(path, exist_ok=True)
 45
 46    if not download:
 47        raise RuntimeError("The dataset is not found and download is set to False.")
 48
 49    # We clone the 'kits23' repository.
 50    if not os.path.exists(os.path.join(path, "kits23")):
 51        subprocess.run(["git", "clone", URL, os.path.join(path, "kits23")])
 52
 53    # We install only the package (assuming that the other required dependencies are already installed).
 54    chosen_patient_dir = natsorted(glob(os.path.join(path, "kits23", "dataset", "case*")))[-1]
 55    if not os.path.exists(os.path.join(chosen_patient_dir, "imaging.nii.gz")):
 56        subprocess.run(["pip", "install", "-e", os.path.join(path, "kits23"), "--no-deps"])
 57
 58        print("The download might take several hours. Make sure you have a stable internet connection.")
 59
 60        # Run the CLI to download the input images.
 61        subprocess.run(["kits23_download_data"])
 62
 63    # Preprocess the images.
 64    _preprocess_inputs(path)
 65
 66    return data_dir
 67
 68
 69def _preprocess_inputs(path):
 70    patient_dirs = glob(os.path.join(path, "kits23", "dataset", "case*"))
 71
 72    preprocessed_dir = os.path.join(path, "preprocessed")
 73    os.makedirs(preprocessed_dir, exist_ok=True)
 74
 75    for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"):
 76        patient_id = os.path.basename(patient_dir)
 77        patient_path = os.path.join(preprocessed_dir, Path(patient_id).with_suffix(".h5"))
 78
 79        if os.path.exists(patient_path):
 80            continue
 81
 82        # Next, we find all rater annotations.
 83        kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-1*")))
 84        tumor_anns = natsorted(glob(os.path.join(patient_dir, "instances", "tumor_instance*")))
 85        cyst_anns = natsorted(glob(os.path.join(patient_dir, "instances", "cyst_instance*")))
 86
 87        import h5py
 88        import nibabel as nib
 89
 90        with h5py.File(patient_path, "w") as f:
 91            # Input image.
 92            raw = nib.load(os.path.join(patient_dir, "imaging.nii.gz")).get_fdata()
 93            f.create_dataset("raw", data=raw, compression="gzip")
 94
 95            # Valid segmentation masks for all classes.
 96            labels = nib.load(os.path.join(patient_dir, "segmentation.nii.gz")).get_fdata()
 97            assert raw.shape == labels.shape, "The shape of inputs and corresponding segmentation does not match."
 98
 99            f.create_dataset("labels/all", data=labels, compression="gzip")
100
101            # Add annotations for kidneys per rater.
102            _k_exclusive = False
103            if not kidney_anns:
104                _k_exclusive = True
105                kidney_anns = natsorted(glob(os.path.join(patient_dir, "instances", "kidney_instance-2*")))
106
107            assert kidney_anns, f"There must be kidney annotations for '{patient_id}'."
108            for p in kidney_anns:
109                masks = np.zeros_like(raw)
110                rater_id = p[-8]  # The rater id, parsed from the filename.
111
112                # Get the other kidney instance.
113                if _k_exclusive:
114                    print("The kidney annotations are numbered strangely.")
115                    other_p = p.replace("instance-2", "instance-3")
116                else:
117                    other_p = p.replace("instance-1", "instance-2")
118
119                # Merge both left and right kidney as one semantic id.
120                masks[nib.load(p).get_fdata() > 0] = 1
121                if os.path.exists(other_p):
122                    masks[nib.load(other_p).get_fdata() > 0] = 1
123                else:
124                    print(f"The second kidney instance does not exist for patient: '{patient_id}'.")
125
126                # Create a hierarchy for the particular rater's kidney annotations.
127                f.create_dataset(f"labels/kidney/rater_{rater_id}", data=masks, compression="gzip")
128
129            # Add annotations for tumor per rater.
130            assert tumor_anns, f"There must be tumor annotations for '{patient_id}'."
131            # Find the raters.
132            raters = [p[-8] for p in tumor_anns]
133            # Get masks per rater
134            unique_raters = np.unique(raters)
135            for rater in unique_raters:
136                masks = np.zeros_like(raw)
137                for p in glob(os.path.join(patient_dir, "instances", f"tumor_instance*-{rater}.nii.gz")):
138                    masks[nib.load(p).get_fdata() > 0] = 1
139
140                f.create_dataset(f"labels/tumor/rater_{rater}", data=masks, compression="gzip")
141
142            # Add annotations for cysts per rater.
143            if cyst_anns:
144                # Find the raters first
145                raters = [p[-8] for p in cyst_anns]
146                # Get masks per rater
147                unique_raters = np.unique(raters)
148                for rater in unique_raters:
149                    masks = np.zeros_like(raw)
150                    for p in glob(os.path.join(patient_dir, "instances", f"cyst_instance*-{rater}.nii.gz")):
151                        masks[nib.load(p).get_fdata() > 0] = 1
152
153                    f.create_dataset(f"labels/cyst/rater_{rater}", data=masks, compression="gzip")
154
155
156def get_kits_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]:
157    """Get paths to the KiTS data.
158
159    Args:
160        path: Filepath to a folder where the data is downloaded for further processing.
161        download: Whether to download the data if it is not present.
162
163    Returns:
164        List of filepaths for the input data.
165    """
166    data_dir = get_kits_data(path, download)
167    volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
168    return volume_paths
169
170
171def get_kits_dataset(
172    path: Union[os.PathLike, str],
173    patch_shape: Tuple[int, ...],
174    rater: Optional[Literal[1, 2, 3]] = None,
175    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
176    resize_inputs: bool = False,
177    download: bool = False,
178    **kwargs
179) -> Dataset:
180    """Get the KiTS dataset for kidney, tumor and cyst segmentation.
181
182    Args:
183        path: Filepath to a folder where the data is downloaded for further processing.
184        patch_shape: The patch shape to use for training.
185        rater: The choice of rater (1, 2 or 3) whose annotations are loaded.
186        annotation_choice: The choice of annotated structure ('kidney', 'tumor' or 'cyst').
187        resize_inputs: Whether to resize inputs to the desired patch shape.
188        download: Whether to download the data if it is not present.
189        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
190
191    Returns:
192        The segmentation dataset.
193    """
194    volume_paths = get_kits_paths(path, download)
195
196    if resize_inputs:
197        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
198        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
199            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
200        )
201
202    # TODO: simplify the design below later, to allow:
203    # - multi-rater label loading.
204    # - multi-annotation label loading.
205    # (for now, only 1v1 annotation-rater loading is supported).
206    if rater is None and annotation_choice is None:
207        label_key = "labels/all"
208    else:
209        assert rater is not None and annotation_choice is not None, \
210            "Both rater and annotation_choice must be specified together."
211
212        label_key = f"labels/{annotation_choice}/rater_{rater}"
213
214    return torch_em.default_segmentation_dataset(
215        raw_paths=volume_paths,
216        raw_key="raw",
217        label_paths=volume_paths,
218        label_key=label_key,
219        patch_shape=patch_shape,
220        **kwargs
221    )
222
223
224def get_kits_loader(
225    path: Union[os.PathLike, str],
226    batch_size: int,
227    patch_shape: Tuple[int, ...],
228    rater: Optional[Literal[1, 2, 3]] = None,
229    annotation_choice: Optional[Literal["kidney", "tumor", "cyst"]] = None,
230    resize_inputs: bool = False,
231    download: bool = False,
232    **kwargs
233) -> DataLoader:
234    """Get the KiTS dataloader for kidney, tumor and cyst segmentation.
235
236    Args:
237        path: Filepath to a folder where the data is downloaded for further processing.
238        batch_size: The batch size for training.
239        patch_shape: The patch shape to use for training.
240        rater: The choice of rater (1, 2 or 3) whose annotations are loaded.
241        annotation_choice: The choice of annotated structure ('kidney', 'tumor' or 'cyst').
242        resize_inputs: Whether to resize inputs to the desired patch shape.
243        download: Whether to download the data if it is not present.
244        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
245
246    Returns:
247        The DataLoader.
248    """
249    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
250    dataset = get_kits_dataset(path, patch_shape, rater, annotation_choice, resize_inputs, download, **ds_kwargs)
251    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL = 'https://github.com/neheller/kits23'
def get_kits_data(path: Union[os.PathLike, str], download: bool = False) -> str:

Download the KiTS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

The folder where the dataset is downloaded and preprocessed.

def get_kits_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]:

Get paths to the KiTS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the input data.
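
For example (placeholder folder; with download=False the preprocessed data must already exist, otherwise a RuntimeError is raised):

    from torch_em.data.datasets.medical.kits import get_kits_paths

    volume_paths = get_kits_paths(path="./data/kits", download=False)
    print(len(volume_paths))  # number of preprocessed cases
    print(volume_paths[0])    # a case-wise HDF5 file in the 'preprocessed' folder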

def get_kits_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], rater: Optional[Literal[1, 2, 3]] = None, annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:

Get the KiTS dataset for kidney, tumor and cyst segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • rater: The choice of rater (1, 2 or 3) whose annotations are loaded.
  • annotation_choice: The choice of annotated structure ('kidney', 'tumor' or 'cyst').
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
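
For example, to build a dataset over a single rater's tumor annotations (placeholder folder and patch shape; rater and annotation_choice must currently be given together, and omitting both uses the combined 'labels/all' volume):

    from torch_em.data.datasets.medical.kits import get_kits_dataset

    dataset = get_kits_dataset(
        path="./data/kits",
        patch_shape=(32, 256, 256),
        rater=1,                    # reads 'labels/tumor/rater_1' from each HDF5 file
        annotation_choice="tumor",
    )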

def get_kits_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], rater: Optional[Literal[1, 2, 3]] = None, annotation_choice: Optional[Literal['kidney', 'tumor', 'cyst']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:

Get the KiTS dataloader for kidney, tumor and cyst segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • rater: The choice of rater (1, 2 or 3) whose annotations are loaded.
  • annotation_choice: The choice of annotated structure ('kidney', 'tumor' or 'cyst').
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
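
For example (placeholder values; keyword arguments that the dataset does not consume, such as num_workers or shuffle, are passed on to the PyTorch DataLoader):

    from torch_em.data.datasets.medical.kits import get_kits_loader

    loader = get_kits_loader(
        path="./data/kits",
        batch_size=2,
        patch_shape=(32, 256, 256),
        num_workers=4,   # forwarded to the PyTorch DataLoader
        shuffle=True,    # forwarded to the PyTorch DataLoader
    )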