torch_em.data.datasets.medical.amd_sd

The AMD-SD dataset contains annotations for lesion segmentation.

This dataset is from the publication https://doi.org/10.1038/s41597-024-03844-6. Please cite it if you use this dataset for your research.

  1"""The AMD-SD dataset contains annotations for lesion segmentation.
  2
  3This dataset is from the publication https://doi.org/10.1038/s41597-024-03844-6.
  4Please cite it if you use this dataset for your research.
  5"""
  6
  7import os
  8from glob import glob
  9from tqdm import tqdm
 10from pathlib import Path
 11from natsort import natsorted
 12from typing import Union, Tuple, Literal, List
 13
 14import numpy as np
 15import imageio.v3 as imageio
 16
 17from torch.utils.data import Dataset, DataLoader
 18
 19import torch_em
 20
 21from .. import util
 22
 23
 24URL = "https://springernature.figshare.com/ndownloader/files/48777037"  # figshare download link for the AMD-SD zip archive
 25CHECKSUM = "16793aac36d814e2858362b4a3b9608e6f57120cf2227a81220407571b8fb359"  # passed to `util.download_source` as `checksum`
 26
 27MAPPING_IDS = {  # RGB annotation color -> integer class id; pixels matching no color stay 0
 28    (255, 0, 255): 1,  # Pink: Ellipsoid zone (IS/OS) junction disruption
 29    (0, 255, 0): 2,    # Green: intraretinal fluid (IRF)
 30    (255, 0, 0): 3,    # Red: subretinal fluid (SRF)
 31    (255, 255, 0): 4,  # Yellow: subretinal hyperreflective material (SHRM)
 32    (0, 0, 255): 5,    # Blue: pigment epithelial detachment (PED)
 33}
 34
 35
 36def _preprocess_data(data_dir):
 37    dirs = glob(os.path.join(data_dir, "images", "*"))
 38    for dir in tqdm(dirs, desc="Preprocessing inputs"):
 39        dname = os.path.basename(dir)
 40
 41        image_dir = os.path.join(data_dir, "preprocessed", dname, "images")
 42        label_dir = os.path.join(data_dir, "preprocessed", dname, "labels")
 43        os.makedirs(image_dir, exist_ok=True)
 44        os.makedirs(label_dir, exist_ok=True)
 45
 46        for ipath in natsorted(glob(os.path.join(dir, "*.png"))):
 47            image = imageio.imread(ipath)
 48            image, label = image[:, :int(image.shape[1]/2), :], image[:, int(image.shape[1]/2):, :]
 49
 50            # Normalize the label intensities
 51            label = (label / 255).round() * 255
 52
 53            # Map all the labels to one channel
 54            segmentation = np.zeros(label.shape[:2], dtype=np.uint8)
 55            for rgb, label_id in MAPPING_IDS.items():
 56                mask = np.all(label == np.array(rgb), axis=-1)
 57                segmentation[mask] = label_id
 58
 59            fname = Path(os.path.basename(ipath)).with_suffix(".tif")
 60            imageio.imwrite(os.path.join(image_dir, fname), image)
 61            imageio.imwrite(os.path.join(label_dir, fname), segmentation)
 62
 63
 64def get_amd_sd_data(path: Union[os.PathLike, str], download: bool = False):
 65    """Download the AMD-SD dataset.
 66
 67    Args:
 68        path: Filepath to a folder where the data is downloaded for further processing.
 69        download: Whether to download the data if it is not present.
 70
 71    Returns:
 72        Filepath where the data is downloaded.
 73    """
 74    data_dir = os.path.join(path, "AMD-SD")
 75    if os.path.exists(data_dir):
 76        return data_dir
 77
 78    os.makedirs(path, exist_ok=True)
 79
 80    zip_path = os.path.join(path, "AMD-SD.zip")
 81    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 82    util.unzip(zip_path=zip_path, dst=path)
 83
 84    _preprocess_data(data_dir)
 85
 86    return data_dir
 87
 88
 89def get_amd_sd_paths(
 90    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 91) -> Tuple[List[str], List[str]]:
 92    """Get paths to the AMD-SD data.
 93
 94    Args:
 95        path: Filepath to a folder where the data is downloaded for further processing.
 96        split: The choice of data split.
 97        download: Whether to download the data if it is not present.
 98
 99    Returns:
100        List of filepaths for the image data.
101        List of filepaths for the label data.
102    """
103    data_dir = get_amd_sd_data(path, download)
104
105    patient_ids = natsorted(glob(os.path.join(data_dir, "preprocessed", "*")))
106    if split == "train":
107        patient_ids = patient_ids[:100]
108    elif split == "val":
109        patient_ids = patient_ids[100:115]
110    elif split == "test":
111        patient_ids = patient_ids[115:]
112    else:
113        raise ValueError(f"'{split}' is not a valid split.")
114
115    raw_paths, label_paths = [], []
116    for id in patient_ids:
117        raw_paths.extend(natsorted(glob(os.path.join(id, "images", "*.tif"))))
118        label_paths.extend(natsorted(glob(os.path.join(id, "labels", "*.tif"))))
119
120    assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0
121
122    return raw_paths, label_paths
123
124
def get_amd_sd_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMD-SD dataset for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs. If set, `kwargs` and `patch_shape` are
            updated by `util.update_kwargs_for_resize_trafo`, with the inputs flagged as RGB.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_files, mask_files = get_amd_sd_paths(path, split, download)

    if resize_inputs:
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=True),
        )

    # The images and labels are stored as individual files, hence no internal keys are required.
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_files,
        raw_key=None,
        label_paths=mask_files,
        label_key=None,
        is_seg_dataset=False,
        patch_shape=patch_shape,
        **kwargs
    )
    return dataset
163
164
def get_amd_sd_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMD-SD dataloader for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments for the dataset from the ones that are forwarded to the data loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_amd_sd_dataset(path, patch_shape, split, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs
    )
# Download link for the AMD-SD zip archive (hosted on figshare).
URL = 'https://springernature.figshare.com/ndownloader/files/48777037'
# Checksum that is passed to `util.download_source` together with the URL.
CHECKSUM = '16793aac36d814e2858362b4a3b9608e6f57120cf2227a81220407571b8fb359'
# Maps the RGB annotation colors to integer class ids: pink (1), green (2), red (3), yellow (4), blue (5).
MAPPING_IDS = {(255, 0, 255): 1, (0, 255, 0): 2, (255, 0, 0): 3, (255, 255, 0): 4, (0, 0, 255): 5}
def get_amd_sd_data(path: Union[os.PathLike, str], download: bool = False):
def get_amd_sd_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download and preprocess the AMD-SD dataset.

    The archive is only downloaded and extracted if the `AMD-SD` folder does not exist yet.
    The preprocessing only runs if the `preprocessed` subfolder does not exist yet, so an
    extraction without a completed preprocessing step is picked up again on the next call.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the `AMD-SD` folder where the data is stored.
    """
    data_dir = os.path.join(path, "AMD-SD")

    # Nothing to do if the preprocessed data is already available.
    if os.path.exists(os.path.join(data_dir, "preprocessed")):
        return data_dir

    if not os.path.exists(data_dir):
        os.makedirs(path, exist_ok=True)

        zip_path = os.path.join(path, "AMD-SD.zip")
        util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=zip_path, dst=path)

    _preprocess_data(data_dir)

    return data_dir

Download the AMD-SD dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath to the `AMD-SD` folder (inside `path`) where the data is stored.

def get_amd_sd_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
 90def get_amd_sd_paths(
 91    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 92) -> Tuple[List[str], List[str]]:
 93    """Get paths to the AMD-SD data.
 94
 95    Args:
 96        path: Filepath to a folder where the data is downloaded for further processing.
 97        split: The choice of data split.
 98        download: Whether to download the data if it is not present.
 99
100    Returns:
101        List of filepaths for the image data.
102        List of filepaths for the label data.
103    """
104    data_dir = get_amd_sd_data(path, download)
105
106    patient_ids = natsorted(glob(os.path.join(data_dir, "preprocessed", "*")))
107    if split == "train":
108        patient_ids = patient_ids[:100]
109    elif split == "val":
110        patient_ids = patient_ids[100:115]
111    elif split == "test":
112        patient_ids = patient_ids[115:]
113    else:
114        raise ValueError(f"'{split}' is not a valid split.")
115
116    raw_paths, label_paths = [], []
117    for id in patient_ids:
118        raw_paths.extend(natsorted(glob(os.path.join(id, "images", "*.tif"))))
119        label_paths.extend(natsorted(glob(os.path.join(id, "labels", "*.tif"))))
120
121    assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0
122
123    return raw_paths, label_paths

Get paths to the AMD-SD data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

A tuple of two lists: the filepaths for the image data and the filepaths for the label data.

def get_amd_sd_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_amd_sd_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMD-SD dataset for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs. If set, `kwargs` and `patch_shape` are
            updated by `util.update_kwargs_for_resize_trafo`, with the inputs flagged as RGB.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_files, mask_files = get_amd_sd_paths(path, split, download)

    if resize_inputs:
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=True),
        )

    # The images and labels are stored as individual files, hence no internal keys are required.
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_files,
        raw_key=None,
        label_paths=mask_files,
        label_key=None,
        is_seg_dataset=False,
        patch_shape=patch_shape,
        **kwargs
    )
    return dataset

Get the AMD-SD dataset for lesion segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_amd_sd_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_amd_sd_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMD-SD dataloader for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments for the dataset from the ones that are forwarded to the data loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_amd_sd_dataset(path, patch_shape, split, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs
    )

Get the AMD-SD dataloader for lesion segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.