torch_em.data.datasets.medical.han_seg
The HaN-Seg dataset contains annotations for head and neck organs in CT scans.
This dataset is from Podobnik et al. - https://doi.org/10.1002/mp.16197. Please cite it if you use it in a publication.
1"""The HaN-Seg dataset contains annotations for head and neck organs in CT scans. 2 3This dataset is from Podobnik et al. - https://doi.org/10.1002/mp.16197 4Please cite it if you use it in a publication. 5""" 6 7import os 8from glob import glob 9from tqdm import tqdm 10from pathlib import Path 11from natsort import natsorted 12from typing import Union, Tuple, List 13 14from torch.utils.data import Dataset, DataLoader 15 16import torch_em 17 18from .. import util 19 20 21URL = "https://zenodo.org/records/7442914/files/HaN-Seg.zip" 22CHECKSUM = "20226dd717f334dc1b1afe961b3375f946fa56b64a80bf5349128f90c0bbfa5f" 23 24 25def get_han_seg_data(path: Union[os.PathLike, str], download: bool = False) -> str: 26 """Get the HaN-Seg dataset. 27 28 Args: 29 path: Filepath to a folder where the data is downloaded for further processing. 30 download: Whether to download the data if it is not present. 31 32 Returns: 33 Filepath where the data is downloaded. 34 """ 35 data_dir = os.path.join(path, "HaN-Seg") 36 if os.path.exists(data_dir): 37 return data_dir 38 39 os.makedirs(path, exist_ok=True) 40 41 zip_path = os.path.join(path, "HaN-Seg.zip") 42 util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) 43 util.unzip(zip_path=zip_path, dst=path, remove=False) 44 45 return data_dir 46 47 48def get_han_seg_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]: 49 """Get the HaN-Seg dataset. 50 51 Args: 52 path: Filepath to a folder where the data is downloaded for further processing. 53 download: Whether to download the data if it is not present. 54 55 Returns: 56 List of filepaths for the image data. 57 List of filepaths for the label data. 58 """ 59 import nrrd 60 import numpy as np 61 import nibabel as nib 62 63 data_dir = get_han_seg_data(path=path, download=download) 64 65 image_dir = os.path.join(data_dir, "set_1", "preprocessed", "images") 66 gt_dir = os.path.join(data_dir, "set_1", "preprocessed", "ground_truth") 67 os.makedirs(image_dir, exist_ok=True) 68 os.makedirs(gt_dir, exist_ok=True) 69 70 image_paths, gt_paths = [], [] 71 all_case_dirs = natsorted(glob(os.path.join(data_dir, "set_1", "case_*"))) 72 for case_dir in tqdm(all_case_dirs): 73 image_path = os.path.join(image_dir, f"{os.path.split(case_dir)[-1]}_ct.nii.gz") 74 gt_path = os.path.join(gt_dir, f"{os.path.split(case_dir)[-1]}.nii.gz") 75 image_paths.append(image_path) 76 gt_paths.append(gt_path) 77 if os.path.exists(image_path) and os.path.exists(gt_path): 78 continue 79 80 all_nrrd_paths = natsorted(glob(os.path.join(case_dir, "*.nrrd"))) 81 all_volumes, all_volume_ids = [], [] 82 for nrrd_path in all_nrrd_paths: 83 image_id = Path(nrrd_path).stem 84 85 # we skip the MRI volumes 86 if image_id.endswith("_MR_T1"): 87 continue 88 89 data, header = nrrd.read(nrrd_path) 90 all_volumes.append(data) 91 all_volume_ids.append(image_id) 92 93 raw = all_volumes[0] 94 raw = nib.Nifti2Image(raw, np.eye(4)) 95 nib.save(raw, image_path) 96 97 gt = np.zeros(raw.shape) 98 for idx, per_organ in enumerate(all_volumes[1:], 1): 99 gt[per_organ > 0] = idx 100 gt = nib.Nifti2Image(gt, np.eye(4)) 101 nib.save(gt, gt_path) 102 103 return image_paths, gt_paths 104 105 106def get_han_seg_dataset( 107 path: Union[os.PathLike, str], 108 patch_shape: Tuple[int, ...], 109 resize_inputs: bool = False, 110 download: bool = False, 111 **kwargs 112) -> Dataset: 113 """Get the HaN-Seg dataset for head and neck organ segmentation. 
114 115 Args: 116 path: Filepath to a folder where the data is downloaded for further processing. 117 patch_shape: The patch shape to use for training. 118 resize_inputs: Whether to resize inputs to the desired patch shape. 119 download: Whether to download the data if it is not present. 120 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 121 122 Returns: 123 The segmentation dataset.. 124 """ 125 image_paths, gt_paths = get_han_seg_paths(path, download) 126 127 if resize_inputs: 128 resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} 129 kwargs, patch_shape = util.update_kwargs_for_resize_trafo( 130 kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs, 131 ) 132 133 return torch_em.default_segmentation_dataset( 134 raw_paths=image_paths, 135 raw_key="data", 136 label_paths=gt_paths, 137 label_key="data", 138 patch_shape=patch_shape, 139 **kwargs 140 ) 141 142 143def get_han_seg_loader( 144 path: Union[os.PathLike, str], 145 batch_size: int, 146 patch_shape: Tuple[int, ...], 147 resize_inputs: bool = False, 148 download: bool = False, 149 **kwargs 150) -> DataLoader: 151 """Get the HaN-Seg dataloader for head and neck organ segmentation. 152 153 Args: 154 path: Filepath to a folder where the data is downloaded for further processing. 155 batch_size: The batch size for training. 156 patch_shape: The patch shape to use for training. 157 resize_inputs: Whether to resize inputs to the desired patch shape. 158 download: Whether to download the data if it is not present. 159 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 160 161 Returns: 162 The DataLoader. 163 """ 164 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 165 dataset = get_han_seg_dataset(path, patch_shape, resize_inputs, download, **ds_kwargs) 166 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
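The preprocessing in `get_han_seg_paths` merges the per-organ binary masks into a single integer label volume before saving it as NIfTI. A minimal sketch of that merging rule on synthetic arrays (shapes, organ count, and mask regions are made up for illustration):

```python
import numpy as np

# Synthetic stand-ins for a CT volume and two per-organ binary masks.
ct = np.random.rand(8, 8, 4)
organ_masks = [np.zeros((8, 8, 4), dtype=np.uint8) for _ in range(2)]
organ_masks[0][2:4, 2:4, :] = 1  # a small "organ 1" region
organ_masks[1][5:7, 5:7, :] = 1  # a small "organ 2" region

# Same rule as in get_han_seg_paths: the i-th organ mask gets label id i.
labels = np.zeros(ct.shape)
for idx, per_organ in enumerate(organ_masks, 1):
    labels[per_organ > 0] = idx

print(np.unique(labels))  # [0. 1. 2.]
```

Note that where masks overlap, later organs overwrite earlier ones, which matches the behavior of the loop in the source above.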
URL = 'https://zenodo.org/records/7442914/files/HaN-Seg.zip'
CHECKSUM = '20226dd717f334dc1b1afe961b3375f946fa56b64a80bf5349128f90c0bbfa5f'
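The CHECKSUM is passed to `util.download_source` to validate the downloaded archive. As a rough sketch of what such a check looks like, assuming the 64-character hex digest is SHA-256 (an assumption, not something stated in this module):

```python
import hashlib

def sha256sum(file_path: str, chunk_size: int = 1024 * 1024) -> str:
    """Compute the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# e.g. compare a manually downloaded archive against the constant above:
# assert sha256sum("HaN-Seg.zip") == CHECKSUM
```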
def get_han_seg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
Get the HaN-Seg dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
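A minimal usage sketch; the target folder below is an arbitrary choice, not a location required by the function:

```python
from torch_em.data.datasets.medical import han_seg

# Downloads and unpacks the archive on first use, otherwise reuses the existing folder.
data_dir = han_seg.get_han_seg_data("./data/han_seg", download=True)
print(data_dir)  # ./data/han_seg/HaN-Seg
```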
def get_han_seg_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]:
Get paths to the HaN-Seg data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data.
List of filepaths for the label data.
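A usage sketch; on the first call this also runs the NRRD-to-NIfTI preprocessing over all cases, so it can take a while (the folder name is again an arbitrary choice):

```python
from torch_em.data.datasets.medical.han_seg import get_han_seg_paths

image_paths, gt_paths = get_han_seg_paths("./data/han_seg", download=True)
print(len(image_paths), len(gt_paths))  # one CT volume and one label volume per case
print(image_paths[0], gt_paths[0])      # e.g. .../preprocessed/images/case_01_ct.nii.gz, ...
```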
def get_han_seg_dataset(path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
Get the HaN-Seg dataset for head and neck organ segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
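A usage sketch; the patch shape is an illustrative choice and may need to be adapted to the actual volume sizes:

```python
from torch_em.data.datasets.medical.han_seg import get_han_seg_dataset

dataset = get_han_seg_dataset("./data/han_seg", patch_shape=(64, 64, 32), download=True)
raw, labels = dataset[0]  # a single raw / label patch pair
print(raw.shape, labels.shape)
```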
def get_han_seg_loader(path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
Get the HaN-Seg dataloader for head and neck organ segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
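A usage sketch; `num_workers` and `shuffle` are standard PyTorch DataLoader arguments and end up on the loader side of the kwargs split, while dataset-specific keywords would go to `torch_em.default_segmentation_dataset` (all values below are illustrative):

```python
from torch_em.data.datasets.medical.han_seg import get_han_seg_loader

loader = get_han_seg_loader(
    "./data/han_seg", batch_size=2, patch_shape=(64, 64, 32),
    download=True, num_workers=4, shuffle=True,
)
raw_batch, label_batch = next(iter(loader))
print(raw_batch.shape, label_batch.shape)  # batches of raw and label patches
```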