torch_em.data.datasets.medical.ircadb
The IRCADb dataset contains annotations for liver segmentation (and several other organs and structures) in 3D CT scans.
The dataset is located at https://www.ircad.fr/research/data-sets/liver-segmentation-3d-ircadb-01/. This dataset is from the publication referenced in the dataset link above. Please cite it if you use this dataset for your research.
1"""The IRCADb dataset contains annotations for liver segmentation (and several other organs and structures) 2in 3D CT scans. 3 4The dataset is located at https://www.ircad.fr/research/data-sets/liver-segmentation-3d-ircadb-01/. 5This dataset is from the publication, referenced in the dataset link above. 6Please cite it if you use this dataset for your research. 7""" 8 9import os 10from glob import glob 11from tqdm import tqdm 12from natsort import natsorted 13from typing import Union, Tuple, List, Literal, Optional 14 15import numpy as np 16 17from torch.utils.data import Dataset, DataLoader 18 19import torch_em 20 21from .. import util 22 23 24URL = "https://cloud.ircad.fr/index.php/s/JN3z7EynBiwYyjy/download" 25CHECKSUM = None # NOTE: checksums are mismatching for some reason with every new download instance :/ 26 27 28def _preprocess_inputs(path): 29 data_dir = os.path.join(path, "3Dircadb1") 30 patient_dirs = glob(os.path.join(data_dir, "*")) 31 32 # Store all preprocessed images in one place 33 preprocessed_dir = os.path.join(path, "data") 34 os.makedirs(preprocessed_dir, exist_ok=True) 35 36 # Let's extract all files per patient, preprocess them, store the final version and remove the zip files. 37 for pdir in tqdm(patient_dirs, desc="Preprocessing files"): 38 39 patient_name = os.path.basename(pdir) 40 41 # Get all zipfiles 42 masks_file = os.path.join(pdir, "MASKS_DICOM.zip") 43 patient_file = os.path.join(pdir, "PATIENT_DICOM.zip") 44 45 # Unzip all. 46 util.unzip(masks_file, pdir, remove=False) 47 util.unzip(patient_file, pdir, remove=False) 48 49 # Get all files and stack each slice together. 50 import pydicom as dicom 51 images = [dicom.dcmread(p).pixel_array for p in natsorted(glob(os.path.join(pdir, "PATIENT_DICOM", "*")))] 52 images = np.stack(images, axis=0) 53 54 # Get masks per slice per class. 55 masks, mask_names = [], [] 56 for mask_dir in glob(os.path.join(pdir, "MASKS_DICOM", "*")): 57 mask_names.append(os.path.basename(mask_dir)) 58 curr_mask = np.stack( 59 [dicom.dcmread(p).pixel_array for p in natsorted(glob(os.path.join(mask_dir, "*")))], axis=0, 60 ) 61 assert curr_mask.shape == images.shape, "The shapes for images and labels don't match." 62 masks.append(curr_mask) 63 64 # Store them in one place 65 import h5py 66 with h5py.File(os.path.join(preprocessed_dir, f"{patient_name}.h5"), "a") as f: 67 f.create_dataset("raw", data=images, compression="gzip") 68 # Add labels one by one 69 for name, _mask in zip(mask_names, masks): 70 f.create_dataset(f"labels/{name}", data=_mask, compression="gzip") 71 72 73def get_ircadb_data(path: Union[os.PathLike, str], download: bool = False) -> str: 74 """Download the IRCADb dataset. 75 76 Args: 77 path: Filepath to a folder where the data is downloaded for further processing. 78 download: Whether to download the data if it is not present. 79 80 Returns: 81 Filepath where the data is downloaded. 82 """ 83 data_dir = os.path.join(path, "data") 84 if os.path.exists(data_dir): 85 return data_dir 86 87 os.makedirs(path, exist_ok=True) 88 89 zip_path = os.path.join(path, "data.zip") 90 util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) 91 util.unzip(zip_path=zip_path, dst=path, remove=True) 92 93 _preprocess_inputs(path) 94 95 return data_dir 96 97 98def get_ircadb_paths( 99 path: Union[os.PathLike, str], split: Optional[Literal["train", "val", "test"]] = None, download: bool = False, 100) -> List[str]: 101 """Get paths to the IRCADb data. 
102 103 Args: 104 path: Filepath to a folder where the data is downloaded for further processing. 105 download: Whether to download the data if it is not present. 106 107 Returns: 108 List of filepaths for the volumetric data. 109 """ 110 111 data_dir = get_ircadb_data(path, download) 112 volume_paths = natsorted(glob(os.path.join(data_dir, "*.h5"))) 113 114 # Create splits on-the-fly, if desired. 115 if split is not None: 116 if split == "train": 117 volume_paths = volume_paths[:12] 118 elif split == "val": 119 volume_paths = volume_paths[12:15] 120 elif split == "test": 121 volume_paths = volume_paths[15:] 122 else: 123 raise ValueError(f"'{split}' is not a valid split.") 124 125 return volume_paths 126 127 128def get_ircadb_dataset( 129 path: Union[os.PathLike, str], 130 patch_shape: Tuple[int, ...], 131 label_choice: str, 132 split: Optional[Literal["train", "val", "test"]] = None, 133 resize_inputs: bool = False, 134 download: bool = False, 135 **kwargs 136) -> Dataset: 137 """Get the IRCADb dataset for liver (and other organ) segmentation. 138 139 Args: 140 path: Filepath to a folder where the data is downloaded for further processing. 141 patch_shape: The patch shape to use for training. 142 label_choice: The choice of labelled organs. 143 split: The choice of data split. 144 resize_inputs: Whether to resize the inputs to the expected patch shape. 145 download: Whether to download the data if it is not present. 146 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 147 148 Returns: 149 The segmentation dataset. 150 """ 151 volume_paths = get_ircadb_paths(path, split, download) 152 153 # Get the labels in the expected hierarchy name. 154 assert isinstance(label_choice, str) 155 label_choice = f"labels/{label_choice}" 156 157 # Get the parameters for resizing inputs 158 if resize_inputs: 159 resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} 160 kwargs, patch_shape = util.update_kwargs_for_resize_trafo( 161 kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs 162 ) 163 164 return torch_em.default_segmentation_dataset( 165 raw_paths=volume_paths, 166 raw_key="raw", 167 label_paths=volume_paths, 168 label_key=label_choice, 169 patch_shape=patch_shape, 170 **kwargs 171 ) 172 173 174def get_ircadb_loader( 175 path: Union[os.PathLike, str], 176 batch_size: int, 177 patch_shape: Tuple[int, ...], 178 label_choice: str, 179 split: Optional[Literal["train", "val", "test"]] = None, 180 resize_inputs: bool = False, 181 download: bool = False, 182 **kwargs 183) -> DataLoader: 184 """Get the IRCADb dataloader for liver (and other organ) segmentation. 185 186 Args: 187 path: Filepath to a folder where the data is downloaded for further processing. 188 batch_size: The batch size for training. 189 patch_shape: The patch shape to use for training. 190 label_choice: The choice of labelled organs. 191 split: The choice of data split. 192 resize_inputs: Whether to resize the inputs to the expected patch shape. 193 download: Whether to download the data if it is not present. 194 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 195 196 Returns: 197 The DataLoader. 
198 """ 199 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 200 dataset = get_ircadb_dataset(path, patch_shape, label_choice, split, resize_inputs, download, **ds_kwargs) 201 return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
URL = 'https://cloud.ircad.fr/index.php/s/JN3z7EynBiwYyjy/download'
CHECKSUM = None
def get_ircadb_data(path: Union[os.PathLike, str], download: bool = False) -> str:
Download and preprocess the IRCADb dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded and preprocessed.
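A minimal usage sketch; the target folder "./data/ircadb" is an example value, not a requirement:

from torch_em.data.datasets.medical.ircadb import get_ircadb_data

# The first call downloads and preprocesses the data; later calls return the cached folder.
data_dir = get_ircadb_data("./data/ircadb", download=True)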
def get_ircadb_paths(path: Union[os.PathLike, str], split: Optional[Literal['train', 'val', 'test']] = None, download: bool = False) -> List[str]:
Get paths to the IRCADb data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the volumetric data.
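The splits are created on-the-fly over the naturally sorted volumes: the first 12 go to "train", the next 3 to "val" and the remainder to "test". A short sketch with an illustrative data folder:

from torch_em.data.datasets.medical.ircadb import get_ircadb_paths

train_paths = get_ircadb_paths("./data/ircadb", split="train", download=True)
val_paths = get_ircadb_paths("./data/ircadb", split="val")
assert not set(train_paths) & set(val_paths)  # the splits are disjoint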
def get_ircadb_dataset(path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], label_choice: str, split: Optional[Literal['train', 'val', 'test']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
Get the IRCADb dataset for liver (and other organ) segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- label_choice: The choice of labelled organs.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
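A minimal sketch for building a liver segmentation dataset. The patch shape is an example value, and "liver" assumes a matching MASKS_DICOM folder name (i.e. a group under "labels/" in the preprocessed HDF5 files):

from torch_em.data.datasets.medical.ircadb import get_ircadb_dataset

dataset = get_ircadb_dataset(
    path="./data/ircadb",
    patch_shape=(32, 256, 256),  # 3D patches as (slices, height, width)
    label_choice="liver",        # must match a group under "labels/"
    split="train",
    download=True,
)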
def get_ircadb_loader(path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], label_choice: str, split: Optional[Literal['train', 'val', 'test']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
Get the IRCADb dataloader for liver (and other organ) segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- label_choice: The choice of labelled organs.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
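A minimal sketch for building the loader. DataLoader keywords such as num_workers are separated out by util.split_kwargs and forwarded to the PyTorch DataLoader; the remaining values below are illustrative:

from torch_em.data.datasets.medical.ircadb import get_ircadb_loader

loader = get_ircadb_loader(
    path="./data/ircadb",
    batch_size=2,
    patch_shape=(32, 256, 256),
    label_choice="liver",
    split="train",
    download=True,
    num_workers=4,  # forwarded to the DataLoader, not the dataset
)
x, y = next(iter(loader))  # a batch of raw patches and the matching label patches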