torch_em.data.datasets.medical.oimhs

The OIMHS dataset contains annotations for macular hole and retinal region segmentation in OCT images.

The dataset is from the publication https://doi.org/10.1038/s41597-023-02675-1. Please cite it if you use this dataset for your research.

  1"""The OIMHS dataset contains annotations for macular hole and retinal region segmentation in OCT images.
  2
  3The dataset is from the publication https://doi.org/10.1038/s41597-023-02675-1.
  4Please cite it if you use this dataset for your research.
  5"""
  6
  7import os
  8from glob import glob
  9from tqdm import tqdm
 10from pathlib import Path
 11from natsort import natsorted
 12from typing import Union, Tuple, Literal, List
 13
 14import json
 15import numpy as np
 16import imageio.v3 as imageio
 17from sklearn.model_selection import train_test_split
 18
 19from torch.utils.data import Dataset, DataLoader
 20
 21import torch_em
 22
 23from .. import util
 24
 25
 26URL = "https://springernature.figshare.com/ndownloader/files/42522673"
 27CHECKSUM = "d93ba18964614eb9b0ba4b8dfee269efbb94ff27142e4b5ecf7cc86f3a1f9d80"
 28
 29LABEL_MAPS = {
 30    (255, 255, 0): 1,  # choroid
 31    (0, 255, 0): 2,  # retina
 32    (0, 0, 255): 3,  # intrarentinal cysts
 33    (255, 0, 0): 4  # macular hole
 34}
 35
 36
 37def get_oimhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 38    """Download the OIMHS data.
 39
 40    Args:
 41        path: Filepath to a folder where the data is downloaded for further processing.
 42        download: Whether to download the data if it is not present.
 43
 44    Returns:
 45        Filepath where the data is downloaded.
 46    """
 47    data_dir = os.path.join(path, "data")
 48    if os.path.exists(data_dir):
 49        return data_dir
 50
 51    os.makedirs(path, exist_ok=True)
 52
 53    zip_path = os.path.join(path, "oimhs_dataset.zip")
 54    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 55    util.unzip(zip_path=zip_path, dst=data_dir)
 56
 57    return data_dir
 58
 59
 60def _create_splits(data_dir, split_file, test_fraction=0.2):
 61    eye_dirs = [Path(edir).stem for edir in natsorted(glob(os.path.join(data_dir, "Images", "*")))]
 62
 63    # let's split the data
 64    main_split, test_split = train_test_split(eye_dirs, test_size=test_fraction)
 65    train_split, val_split = train_test_split(main_split, test_size=0.1)
 66
 67    decided_splits = {"train": train_split, "val": val_split, "test": test_split}
 68
 69    with open(split_file, "w") as f:
 70        json.dump(decided_splits, f)
 71
 72
 73def _get_per_split_dirs(data_dir, split_file, split):
 74    with open(split_file, "r") as f:
 75        data = json.load(f)
 76
 77    split_data = data[split]
 78    split_data = [os.path.join(data_dir, "Images", sdata) for sdata in split_data]
 79    return split_data
 80
 81
 82def get_oimhs_paths(
 83    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 84) -> Tuple[List[str], List[str]]:
 85    """Get paths to the OIMHS data.
 86
 87    Args:
 88        path: Filepath to a folder where the downloaded data will be saved.
 89        split: The choice of data split.
 90        download: Whether to download the data if it is not present.
 91
 92    Returns:
 93        List of filepaths for the image data.
 94        List of filepaths for the label data.
 95    """
 96    data_dir = get_oimhs_data(path=path, download=download)
 97
 98    image_dir = os.path.join(data_dir, "preprocessed", "images")
 99    gt_dir = os.path.join(data_dir, "preprocessed", "gt")
100
101    os.makedirs(image_dir, exist_ok=True)
102    os.makedirs(gt_dir, exist_ok=True)
103
104    split_file = os.path.join(path, "split_file.json")
105    if not os.path.exists(split_file):
106        _create_splits(data_dir, split_file)
107
108    eye_dirs = _get_per_split_dirs(data_dir=data_dir, split_file=split_file, split=split)
109
110    image_paths, gt_paths = [], []
111    for eye_dir in tqdm(eye_dirs, desc="Preprocessing inputs"):
112        eye_id = os.path.split(eye_dir)[-1]
113        all_oct_scan_paths = natsorted(glob(os.path.join(eye_dir, "*.png")))
114        for per_scan_path in all_oct_scan_paths:
115            scan_id = Path(per_scan_path).stem
116
117            image_path = os.path.join(image_dir, f"{eye_id}_{scan_id}.tif")
118            gt_path = os.path.join(gt_dir, f"{eye_id}_{scan_id}.tif")
119
120            image_paths.append(image_path)
121            gt_paths.append(gt_path)
122
123            if os.path.exists(image_path) and os.path.exists(gt_path):
124                continue
125
126            scan = imageio.imread(per_scan_path)
127            image, gt = scan[:, :512, :], scan[:, 512:, :]
128
129            instances = np.zeros(image.shape[:2])
130            for lmap in LABEL_MAPS:
131                binary_map = (gt == lmap).all(axis=2)
132                instances[binary_map > 0] = LABEL_MAPS[lmap]
133
134            imageio.imwrite(image_path, image, compression="zlib")
135            imageio.imwrite(gt_path, instances, compression="zlib")
136
137    return image_paths, gt_paths
138
139
def get_oimhs_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the OIMHS segmentation dataset for macular hole and retinal regions in OCT scans.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_oimhs_paths(path, split, download)

    if resize_inputs:
        # The OCT images are stored with a channel axis, hence is_rgb=True for the resize transform.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=True),
        )

    # Each path is a standalone 2D TIFF without an internal key, hence raw_key/label_key=None.
    dataset_config = dict(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
    )
    return torch_em.default_segmentation_dataset(**dataset_config, **kwargs)
178
179
def get_oimhs_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the OIMHS data for macular hole and retinal region segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # util.split_kwargs presumably separates the dataset arguments from the DataLoader arguments.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    oimhs_dataset = get_oimhs_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(oimhs_dataset, batch_size, **loader_kwargs)
# Download location of the OIMHS archive on figshare and its expected checksum.
URL = "https://springernature.figshare.com/ndownloader/files/42522673"
CHECKSUM = "d93ba18964614eb9b0ba4b8dfee269efbb94ff27142e4b5ecf7cc86f3a1f9d80"

# (R, G, B) colour used in the annotation part of each scan -> integer class id.
LABEL_MAPS = {
    (255, 255, 0): 1,  # choroid
    (0, 255, 0): 2,  # retina
    (0, 0, 255): 3,  # intraretinal cysts
    (255, 0, 0): 4,  # macular hole
}
def get_oimhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download and extract the OIMHS archive unless it has already been extracted.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded, i.e. the `data` subfolder of `path`.
    """
    extracted_dir = os.path.join(path, "data")

    # The existence of the extracted folder alone is treated as "already available".
    if not os.path.exists(extracted_dir):
        os.makedirs(path, exist_ok=True)
        archive_path = os.path.join(path, "oimhs_dataset.zip")
        util.download_source(path=archive_path, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=archive_path, dst=extracted_dir)

    return extracted_dir

Download the OIMHS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def _scan_to_image_and_labels(scan):
    """Split a side-by-side OIMHS scan into the OCT image and an integer label map.

    Args:
        scan: The scan as read from its PNG file. The first 512 columns are taken as the image and
            the remaining columns as the colour-coded annotation (assumes a trailing channel axis).

    Returns:
        The image crop, unchanged.
        A uint8 label map holding the class ids from `LABEL_MAPS`; pixels matching no colour are 0.
    """
    image, gt = scan[:, :512, :], scan[:, 512:, :]

    # Only the RGB channels encode the class colours; this also tolerates an extra alpha channel.
    gt_rgb = gt[..., :3]

    # Class ids are small integers, so store them as uint8 instead of numpy's default float64.
    labels = np.zeros(image.shape[:2], dtype="uint8")
    for color, class_id in LABEL_MAPS.items():
        labels[(gt_rgb == color).all(axis=2)] = class_id

    return image, labels


def get_oimhs_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the OIMHS data.

    Each PNG scan of the requested split is converted once into an image TIFF and a label TIFF
    under `<path>/data/preprocessed/`. Scans whose two TIFF files already exist are skipped.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    data_dir = get_oimhs_data(path=path, download=download)

    image_dir = os.path.join(data_dir, "preprocessed", "images")
    gt_dir = os.path.join(data_dir, "preprocessed", "gt")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    # The eye-level split is created once and reused, so repeated calls see the same split.
    split_file = os.path.join(path, "split_file.json")
    if not os.path.exists(split_file):
        _create_splits(data_dir, split_file)

    eye_dirs = _get_per_split_dirs(data_dir=data_dir, split_file=split_file, split=split)

    image_paths, gt_paths = [], []
    for eye_dir in tqdm(eye_dirs, desc="Preprocessing inputs"):
        eye_id = os.path.split(eye_dir)[-1]
        for per_scan_path in natsorted(glob(os.path.join(eye_dir, "*.png"))):
            scan_id = Path(per_scan_path).stem
            image_path = os.path.join(image_dir, f"{eye_id}_{scan_id}.tif")
            gt_path = os.path.join(gt_dir, f"{eye_id}_{scan_id}.tif")

            image_paths.append(image_path)
            gt_paths.append(gt_path)

            # Reuse previously converted scans.
            if os.path.exists(image_path) and os.path.exists(gt_path):
                continue

            image, labels = _scan_to_image_and_labels(imageio.imread(per_scan_path))
            imageio.imwrite(image_path, image, compression="zlib")
            imageio.imwrite(gt_path, labels, compression="zlib")

    return image_paths, gt_paths

Get paths to the OIMHS data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

A list of filepaths for the image data and a list of filepaths for the label data.

def get_oimhs_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the OIMHS segmentation dataset for macular hole and retinal regions in OCT scans.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_oimhs_paths(path, split, download)

    if resize_inputs:
        # The OCT images are stored with a channel axis, hence is_rgb=True for the resize transform.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=True),
        )

    # Each path is a standalone 2D TIFF without an internal key, hence raw_key/label_key=None.
    dataset_config = dict(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
    )
    return torch_em.default_segmentation_dataset(**dataset_config, **kwargs)

Get the OIMHS dataset for segmentation of macular hole and retinal regions in OCT scans.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize the inputs to the expected patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_oimhs_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the OIMHS data for macular hole and retinal region segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # util.split_kwargs presumably separates the dataset arguments from the DataLoader arguments.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    oimhs_dataset = get_oimhs_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(oimhs_dataset, batch_size, **loader_kwargs)

Get the OIMHS dataloader for segmentation of macular hole and retinal regions in OCT scans.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize the inputs to the expected patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.