torch_em.data.datasets.medical.han_seg

The HaN-Seg dataset contains annotations for head and neck organs in CT scans.

This dataset is from Podobnik et al. - https://doi.org/10.1002/mp.16197
Please cite it if you use it in a publication.

  1"""The HaN-Seg dataset contains annotations for head and neck organs in CT scans.
  2
  3This dataset is from Podobnik et al. - https://doi.org/10.1002/mp.16197
  4Please cite it if you use it in a publication.
  5"""
  6
  7import os
  8from glob import glob
  9from tqdm import tqdm
 10from pathlib import Path
 11from natsort import natsorted
 12from typing import Union, Tuple, List
 13
 14from torch.utils.data import Dataset, DataLoader
 15
 16import torch_em
 17
 18from .. import util
 19
 20
 21URL = "https://zenodo.org/records/7442914/files/HaN-Seg.zip"
 22CHECKSUM = "20226dd717f334dc1b1afe961b3375f946fa56b64a80bf5349128f90c0bbfa5f"
 23
 24
def get_han_seg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the HaN-Seg dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "HaN-Seg")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "HaN-Seg.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path, remove=False)

    return data_dir


def get_han_seg_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]:
    """Get paths to the HaN-Seg data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    import nrrd
    import numpy as np
    import nibabel as nib

    data_dir = get_han_seg_data(path=path, download=download)

    image_dir = os.path.join(data_dir, "set_1", "preprocessed", "images")
    gt_dir = os.path.join(data_dir, "set_1", "preprocessed", "ground_truth")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    image_paths, gt_paths = [], []
    all_case_dirs = natsorted(glob(os.path.join(data_dir, "set_1", "case_*")))
    for case_dir in tqdm(all_case_dirs):
        image_path = os.path.join(image_dir, f"{os.path.split(case_dir)[-1]}_ct.nii.gz")
        gt_path = os.path.join(gt_dir, f"{os.path.split(case_dir)[-1]}.nii.gz")
        image_paths.append(image_path)
        gt_paths.append(gt_path)
        if os.path.exists(image_path) and os.path.exists(gt_path):
            continue

        all_nrrd_paths = natsorted(glob(os.path.join(case_dir, "*.nrrd")))
        all_volumes, all_volume_ids = [], []
        for nrrd_path in all_nrrd_paths:
            image_id = Path(nrrd_path).stem

            # We skip the MR volumes and keep only the CT scan and the organ annotations.
            if image_id.endswith("_MR_T1"):
                continue

            data, header = nrrd.read(nrrd_path)
            all_volumes.append(data)
            all_volume_ids.append(image_id)

        # The first volume is the CT scan, the remaining ones are the per-organ annotations.
        raw = all_volumes[0]
        raw = nib.Nifti2Image(raw, np.eye(4))
        nib.save(raw, image_path)

        # Merge the binary per-organ annotations into one label volume with consecutive ids.
        gt = np.zeros(raw.shape)
        for idx, per_organ in enumerate(all_volumes[1:], 1):
            gt[per_organ > 0] = idx
        gt = nib.Nifti2Image(gt, np.eye(4))
        nib.save(gt, gt_path)

    return image_paths, gt_paths


def get_han_seg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the HaN-Seg dataset for head and neck organ segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_han_seg_paths(path, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs,
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        **kwargs
    )


def get_han_seg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the HaN-Seg dataloader for head and neck organ segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_han_seg_dataset(path, patch_shape, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL = 'https://zenodo.org/records/7442914/files/HaN-Seg.zip'
CHECKSUM = '20226dd717f334dc1b1afe961b3375f946fa56b64a80bf5349128f90c0bbfa5f'
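
The checksum above has the length of a SHA-256 digest, so a downloaded archive can also be verified by hand if the automatic check is ever in doubt. A minimal sketch, assuming the zip was saved under the example path below (the path is illustrative, not fixed by the library):

import hashlib

from torch_em.data.datasets.medical.han_seg import CHECKSUM

zip_path = "data/han_seg/HaN-Seg.zip"  # example location of the downloaded archive

sha256 = hashlib.sha256()
with open(zip_path, "rb") as f:
    # Hash the archive in chunks, so it does not need to fit into memory at once.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

print(sha256.hexdigest() == CHECKSUM)  # True for an intact download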
def get_han_seg_data(path: Union[os.PathLike, str], download: bool = False) -> str:

Download the HaN-Seg dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.

Returns:
  Filepath where the data is downloaded.
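
A minimal usage sketch; the target folder below is an example choice, not required by the function:

import os
from torch_em.data.datasets.medical.han_seg import get_han_seg_data

# Downloads and unpacks the archive on first use; afterwards the existing folder is returned directly.
data_dir = get_han_seg_data(path=os.path.join("data", "han_seg"), download=True)
print(data_dir)  # data/han_seg/HaN-Seg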

def get_han_seg_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]:

Get paths to the HaN-Seg data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.

Returns:
  List of filepaths for the image data.
  List of filepaths for the label data.
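
A short sketch of how the returned paths could be inspected; it assumes nibabel and numpy are available (the function itself already requires them), and the data folder is an example choice:

import numpy as np
import nibabel as nib
from torch_em.data.datasets.medical.han_seg import get_han_seg_paths

image_paths, gt_paths = get_han_seg_paths(path="data/han_seg", download=True)
print(len(image_paths), len(gt_paths))  # one preprocessed CT volume and one label volume per case

# The preprocessed labels store one consecutive integer id per annotated organ (0 is background).
labels = nib.load(gt_paths[0]).get_fdata()
print(np.unique(labels))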

def get_han_seg_dataset(path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:

Get the HaN-Seg dataset for head and neck organ segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.

Returns:
  The segmentation dataset.
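
For example, a 3D dataset could be instantiated as follows; the patch shape and folder are illustrative values, not recommendations:

from torch_em.data.datasets.medical.han_seg import get_han_seg_dataset

dataset = get_han_seg_dataset(
    path="data/han_seg",
    patch_shape=(32, 256, 256),  # example (z, y, x) patch shape
    download=True,
)
print(len(dataset))

# Indexing is expected to yield a (raw, labels) pair of patch tensors.
raw, labels = dataset[0]
print(raw.shape, labels.shape)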

def get_han_seg_loader(path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:

Get the HaN-Seg dataloader for head and neck organ segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.

Returns:
  The DataLoader.
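
A minimal loader sketch; the batch size, patch shape, and DataLoader settings are example values, and the loop assumes the loader yields (raw, labels) batches, as torch_em data loaders typically do:

from torch_em.data.datasets.medical.han_seg import get_han_seg_loader

loader = get_han_seg_loader(
    path="data/han_seg",
    batch_size=2,
    patch_shape=(32, 256, 256),  # example (z, y, x) patch shape
    download=True,
    num_workers=4,   # forwarded to the PyTorch DataLoader
    shuffle=True,    # forwarded to the PyTorch DataLoader
)

for raw, labels in loader:
    print(raw.shape, labels.shape)
    break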