torch_em.data.datasets.medical.oimhs
The OIMHS dataset contains annotations for macular hole and retinal region segmentation in OCT images.
The dataset is from the publication https://doi.org/10.1038/s41597-023-02675-1. Please cite it if you use this dataset for your research.
"""The OIMHS dataset contains annotations for macular hole and retinal region segmentation in OCT images.

The dataset is from the publication https://doi.org/10.1038/s41597-023-02675-1.
Please cite it if you use this dataset for your research.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import json
import numpy as np
import imageio.v3 as imageio
from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = "https://springernature.figshare.com/ndownloader/files/42522673"
CHECKSUM = "d93ba18964614eb9b0ba4b8dfee269efbb94ff27142e4b5ecf7cc86f3a1f9d80"

# Maps the RGB color used in the annotation half of each scan to its class id.
LABEL_MAPS = {
    (255, 255, 0): 1,  # choroid
    (0, 255, 0): 2,  # retina
    (0, 0, 255): 3,  # intraretinal cysts
    (255, 0, 0): 4  # macular hole
}


def get_oimhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the OIMHS data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "data")
    # An existing `data` folder is taken to mean the archive was already extracted.
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "oimhs_dataset.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=data_dir)

    return data_dir


def _create_splits(data_dir, split_file, test_fraction=0.2):
    """Partition the eye folders into train / val / test and store the result as JSON.

    The folder names found under `<data_dir>/Images` are split at random: `test_fraction`
    of them go to the test split and 10% of the remainder to the validation split.

    NOTE: No `random_state` is passed to `train_test_split`, so the partition is not
    reproducible across fresh runs; it only stays fixed because it is saved to `split_file`.

    Args:
        data_dir: Folder containing the extracted dataset.
        split_file: Filepath of the JSON file the splits are written to.
        test_fraction: Fraction of the eye folders assigned to the test split.
    """
    eye_dirs = [Path(edir).stem for edir in natsorted(glob(os.path.join(data_dir, "Images", "*")))]

    # let's split the data
    main_split, test_split = train_test_split(eye_dirs, test_size=test_fraction)
    train_split, val_split = train_test_split(main_split, test_size=0.1)

    decided_splits = {"train": train_split, "val": val_split, "test": test_split}

    with open(split_file, "w") as f:
        json.dump(decided_splits, f)


def _get_per_split_dirs(data_dir, split_file, split):
    """Read the stored splits and return the eye folders that belong to `split`.

    Args:
        data_dir: Folder containing the extracted dataset.
        split_file: Filepath of the JSON file written by `_create_splits`.
        split: Key of the requested split in the JSON file.

    Returns:
        List of paths of the form `<data_dir>/Images/<eye folder name>`.
    """
    with open(split_file, "r") as f:
        data = json.load(f)

    return [os.path.join(data_dir, "Images", sdata) for sdata in data[split]]


def get_oimhs_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the OIMHS data.

    Each source PNG is cut into an image part (columns 0-511) and an annotation part
    (columns 512 onwards). The color-coded annotation is converted to a class-id map
    via `LABEL_MAPS`, and both parts are stored as zlib-compressed tif files under
    `<data folder>/preprocessed`. Scans whose two outputs already exist are skipped.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    data_dir = get_oimhs_data(path=path, download=download)

    image_dir = os.path.join(data_dir, "preprocessed", "images")
    gt_dir = os.path.join(data_dir, "preprocessed", "gt")

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    # The split is created once and then reused, so later calls see the same partition.
    split_file = os.path.join(path, "split_file.json")
    if not os.path.exists(split_file):
        _create_splits(data_dir, split_file)

    eye_dirs = _get_per_split_dirs(data_dir=data_dir, split_file=split_file, split=split)

    image_paths, gt_paths = [], []
    for eye_dir in tqdm(eye_dirs, desc="Preprocessing inputs"):
        eye_id = os.path.split(eye_dir)[-1]
        for per_scan_path in natsorted(glob(os.path.join(eye_dir, "*.png"))):
            scan_id = Path(per_scan_path).stem

            image_path = os.path.join(image_dir, f"{eye_id}_{scan_id}.tif")
            gt_path = os.path.join(gt_dir, f"{eye_id}_{scan_id}.tif")

            image_paths.append(image_path)
            gt_paths.append(gt_path)

            if os.path.exists(image_path) and os.path.exists(gt_path):
                continue

            scan = imageio.imread(per_scan_path)
            image, gt = scan[:, :512, :], scan[:, 512:, :]

            # Class ids are small integers (0-4); store them as uint8 instead of
            # the float64 default of `np.zeros`.
            instances = np.zeros(image.shape[:2], dtype="uint8")
            for color, label_id in LABEL_MAPS.items():
                # A pixel belongs to a class only if all channels match its color.
                instances[(gt == color).all(axis=2)] = label_id

            imageio.imwrite(image_path, image, compression="zlib")
            imageio.imwrite(gt_path, instances, compression="zlib")

    return image_paths, gt_paths


def get_oimhs_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the OIMHS dataset for segmentation of macular hole and retinal regions in OCT scans.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_oimhs_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=gt_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )


def get_oimhs_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the OIMHS dataloader for segmentation of macular hole and retinal regions in OCT scans.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_oimhs_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
'https://springernature.figshare.com/ndownloader/files/42522673'
CHECKSUM =
'd93ba18964614eb9b0ba4b8dfee269efbb94ff27142e4b5ecf7cc86f3a1f9d80'
LABEL_MAPS =
{(255, 255, 0): 1, (0, 255, 0): 2, (0, 0, 255): 3, (255, 0, 0): 4}
def
get_oimhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_oimhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Fetch and extract the OIMHS archive unless the `data` folder already exists.

    Args:
        path: Folder in which the zip archive is stored and the `data` subfolder is created.
        download: Whether to download the archive if it is not present.

    Returns:
        Filepath of the `data` subfolder the archive is extracted to.
    """
    data_dir = os.path.join(path, "data")

    # An existing `data` folder is taken to mean the archive was already extracted.
    if not os.path.exists(data_dir):
        os.makedirs(path, exist_ok=True)

        archive = os.path.join(path, "oimhs_dataset.zip")
        util.download_source(path=archive, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=archive, dst=data_dir)

    return data_dir
Download the OIMHS data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_oimhs_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_oimhs_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the OIMHS data.

    Each source PNG is cut into an image part (columns 0-511) and an annotation part
    (columns 512 onwards). The color-coded annotation is converted to a class-id map
    via `LABEL_MAPS`, and both parts are stored as zlib-compressed tif files under
    `<data folder>/preprocessed`. Scans whose two outputs already exist are skipped.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    data_dir = get_oimhs_data(path=path, download=download)

    image_dir = os.path.join(data_dir, "preprocessed", "images")
    gt_dir = os.path.join(data_dir, "preprocessed", "gt")

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    # The split is created once and then reused, so later calls see the same partition.
    split_file = os.path.join(path, "split_file.json")
    if not os.path.exists(split_file):
        _create_splits(data_dir, split_file)

    eye_dirs = _get_per_split_dirs(data_dir=data_dir, split_file=split_file, split=split)

    image_paths, gt_paths = [], []
    for eye_dir in tqdm(eye_dirs, desc="Preprocessing inputs"):
        eye_id = os.path.split(eye_dir)[-1]
        for per_scan_path in natsorted(glob(os.path.join(eye_dir, "*.png"))):
            scan_id = Path(per_scan_path).stem

            image_path = os.path.join(image_dir, f"{eye_id}_{scan_id}.tif")
            gt_path = os.path.join(gt_dir, f"{eye_id}_{scan_id}.tif")

            image_paths.append(image_path)
            gt_paths.append(gt_path)

            if os.path.exists(image_path) and os.path.exists(gt_path):
                continue

            scan = imageio.imread(per_scan_path)
            image, gt = scan[:, :512, :], scan[:, 512:, :]

            # Class ids are small integers (0-4); store them as uint8 instead of
            # the float64 default of `np.zeros`.
            instances = np.zeros(image.shape[:2], dtype="uint8")
            for color, label_id in LABEL_MAPS.items():
                # A pixel belongs to a class only if all channels match its color.
                instances[(gt == color).all(axis=2)] = label_id

            imageio.imwrite(image_path, image, compression="zlib")
            imageio.imwrite(gt_path, instances, compression="zlib")

    return image_paths, gt_paths
Get paths to the OIMHS data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data. List of filepaths for the label data.
def
get_oimhs_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_oimhs_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the OIMHS segmentation dataset (macular hole and retinal regions in OCT scans).

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_oimhs_paths(path=path, split=split, download=download)

    if resize_inputs:
        # The stored images are treated as RGB by the resize transform.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )
    return dataset
Get the OIMHS dataset for segmentation of macular hole and retinal regions in OCT scans.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_oimhs_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_oimhs_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the OIMHS dataloader (macular hole and retinal regions in OCT scans).

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    oimhs_dataset = get_oimhs_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )

    return torch_em.get_data_loader(oimhs_dataset, batch_size, **dataloader_kwargs)
Get the OIMHS dataloader for segmentation of macular hole and retinal regions in OCT scans.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.