torch_em.data.datasets.medical.amd_sd
The AMD-SD dataset contains annotations for lesion segmentation.
This dataset is from the publication https://doi.org/10.1038/s41597-024-03844-6. Please cite it if you use this dataset for your research.
"""The AMD-SD dataset contains annotations for lesion segmentation.

This dataset is from the publication https://doi.org/10.1038/s41597-024-03844-6.
Please cite it if you use this dataset for your research.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import numpy as np
import imageio.v3 as imageio

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = "https://springernature.figshare.com/ndownloader/files/48777037"
CHECKSUM = "16793aac36d814e2858362b4a3b9608e6f57120cf2227a81220407571b8fb359"

# Maps each RGB color used in the annotation images to an integer class id (0 is background).
MAPPING_IDS = {
    (255, 0, 255): 1,  # Pink: Ellipsoid zone (IS/OS) junction disruption
    (0, 255, 0): 2,  # Green: intraretinal fluid (IRF)
    (255, 0, 0): 3,  # Red: subretinal fluid (SRF)
    (255, 255, 0): 4,  # Yellow: subretinal hyperreflective material (SHRM)
    (0, 0, 255): 5,  # Blue: pigment epithelial detachment (PED)
}


def _split_image_and_label(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split a side-by-side array into its left (image) and right (label) halves.

    For an odd width the last column is dropped, so both halves always have the same shape.
    """
    half = image.shape[1] // 2
    return image[:, :half], image[:, half:2 * half]


def _rgb_to_segmentation(label: np.ndarray) -> np.ndarray:
    """Map an RGB annotation to a single-channel uint8 segmentation using `MAPPING_IDS`.

    Pixels whose color is not listed in `MAPPING_IDS` are assigned the background id 0.
    """
    # Snap every channel to 0 or 255 so that nearly-pure annotation colors match the MAPPING_IDS keys.
    label = (label / 255).round() * 255

    segmentation = np.zeros(label.shape[:2], dtype=np.uint8)
    for rgb, label_id in MAPPING_IDS.items():
        mask = np.all(label == np.array(rgb), axis=-1)
        segmentation[mask] = label_id
    return segmentation


def _preprocess_data(data_dir):
    """Convert the side-by-side PNGs below `data_dir/images` into separate image and label tifs.

    Each `data_dir/images/<name>/*.png` holds the image in its left half and the colored annotation
    in its right half. The outputs are written with matching filenames to
    `data_dir/preprocessed/<name>/images` and `data_dir/preprocessed/<name>/labels`.

    Args:
        data_dir: The `AMD-SD` folder containing the unzipped data.
    """
    patient_dirs = glob(os.path.join(data_dir, "images", "*"))
    for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"):
        dname = os.path.basename(patient_dir)

        image_dir = os.path.join(data_dir, "preprocessed", dname, "images")
        label_dir = os.path.join(data_dir, "preprocessed", dname, "labels")
        os.makedirs(image_dir, exist_ok=True)
        os.makedirs(label_dir, exist_ok=True)

        for ipath in natsorted(glob(os.path.join(patient_dir, "*.png"))):
            image, label = _split_image_and_label(imageio.imread(ipath))
            segmentation = _rgb_to_segmentation(label)

            fname = Path(ipath).with_suffix(".tif").name
            imageio.imwrite(os.path.join(image_dir, fname), image)
            imageio.imwrite(os.path.join(label_dir, fname), segmentation)


def get_amd_sd_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download and preprocess the AMD-SD dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the `AMD-SD` folder containing the raw and preprocessed data.
    """
    data_dir = os.path.join(path, "AMD-SD")
    if os.path.exists(os.path.join(data_dir, "preprocessed")):
        return data_dir

    if not os.path.exists(data_dir):
        os.makedirs(path, exist_ok=True)
        zip_path = os.path.join(path, "AMD-SD.zip")
        util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=zip_path, dst=path)

    # Also runs when an earlier call unzipped the data but stopped before preprocessing.
    # NOTE: a partially written `preprocessed` folder is not detected; delete it to redo preprocessing.
    _preprocess_data(data_dir)

    return data_dir


def get_amd_sd_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the AMD-SD data.

    The splits are assigned per patient folder (in natural sort order): the first 100 folders
    form 'train', the next 15 form 'val' and the remaining ones form 'test'.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    split_slices = {"train": slice(None, 100), "val": slice(100, 115), "test": slice(115, None)}
    # Validate the split before touching the data, so an invalid choice does not trigger a download.
    if split not in split_slices:
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_amd_sd_data(path, download)

    patient_dirs = natsorted(glob(os.path.join(data_dir, "preprocessed", "*")))[split_slices[split]]

    raw_paths, label_paths = [], []
    for patient_dir in patient_dirs:
        raw_paths.extend(natsorted(glob(os.path.join(patient_dir, "images", "*.tif"))))
        label_paths.extend(natsorted(glob(os.path.join(patient_dir, "labels", "*.tif"))))

    assert len(raw_paths) > 0, f"No preprocessed data found for the '{split}' split in '{data_dir}'."
    # Images and labels are written with identical filenames, so the sorted lists must pair up one-to-one.
    assert [os.path.basename(p) for p in raw_paths] == [os.path.basename(p) for p in label_paths], \
        "The preprocessed images and labels do not match."

    return raw_paths, label_paths


def get_amd_sd_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the AMD-SD dataset for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_amd_sd_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        is_seg_dataset=False,
        patch_shape=patch_shape,
        **kwargs
    )


def get_amd_sd_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the AMD-SD dataloader for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_amd_sd_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Download location and SHA256 checksum of the AMD-SD archive.
URL = 'https://springernature.figshare.com/ndownloader/files/48777037'
CHECKSUM = '16793aac36d814e2858362b4a3b9608e6f57120cf2227a81220407571b8fb359'
# Maps each RGB annotation color to an integer class id (0 is background).
MAPPING_IDS = {(255, 0, 255): 1, (0, 255, 0): 2, (255, 0, 0): 3, (255, 255, 0): 4, (0, 0, 255): 5}
def
get_amd_sd_data(path: Union[os.PathLike, str], download: bool = False):
def get_amd_sd_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download and preprocess the AMD-SD dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the `AMD-SD` folder containing the raw and preprocessed data.
    """
    data_dir = os.path.join(path, "AMD-SD")
    if os.path.exists(os.path.join(data_dir, "preprocessed")):
        return data_dir

    if not os.path.exists(data_dir):
        os.makedirs(path, exist_ok=True)
        zip_path = os.path.join(path, "AMD-SD.zip")
        util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=zip_path, dst=path)

    # Also runs when an earlier call unzipped the data but stopped before preprocessing.
    # NOTE: a partially written `preprocessed` folder is not detected; delete it to redo preprocessing.
    _preprocess_data(data_dir)

    return data_dir
Download and preprocess the AMD-SD dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_amd_sd_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_amd_sd_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the AMD-SD data.

    The splits are assigned per patient folder (in natural sort order): the first 100 folders
    form 'train', the next 15 form 'val' and the remaining ones form 'test'.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    split_slices = {"train": slice(None, 100), "val": slice(100, 115), "test": slice(115, None)}
    # Validate the split before touching the data, so an invalid choice does not trigger a download.
    if split not in split_slices:
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_amd_sd_data(path, download)

    patient_dirs = natsorted(glob(os.path.join(data_dir, "preprocessed", "*")))[split_slices[split]]

    raw_paths, label_paths = [], []
    for patient_dir in patient_dirs:
        raw_paths.extend(natsorted(glob(os.path.join(patient_dir, "images", "*.tif"))))
        label_paths.extend(natsorted(glob(os.path.join(patient_dir, "labels", "*.tif"))))

    assert len(raw_paths) > 0, f"No preprocessed data found for the '{split}' split in '{data_dir}'."
    # Images and labels are written with identical filenames, so the sorted lists must pair up one-to-one.
    assert [os.path.basename(p) for p in raw_paths] == [os.path.basename(p) for p in label_paths], \
        "The preprocessed images and labels do not match."

    return raw_paths, label_paths
Get paths to the AMD-SD data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
A list of filepaths for the image data and a list of filepaths for the label data.
def
get_amd_sd_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_amd_sd_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the AMD-SD segmentation dataset for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_files, mask_files = get_amd_sd_paths(path, split, download)

    if resize_inputs:
        # The resize transform is told that the inputs carry RGB channels (is_rgb=True).
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=True),
        )

    # The images and labels are individual 2D files, hence no dataset keys and is_seg_dataset=False.
    fixed_kwargs = dict(
        raw_paths=image_files,
        raw_key=None,
        label_paths=mask_files,
        label_key=None,
        is_seg_dataset=False,
        patch_shape=patch_shape,
    )
    return torch_em.default_segmentation_dataset(**fixed_kwargs, **kwargs)
Get the AMD-SD dataset for lesion segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_amd_sd_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_amd_sd_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the AMD-SD dataset for lesion segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each extra keyword argument either to the dataset constructor or to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    segmentation_ds = get_amd_sd_dataset(
        path, patch_shape, split, resize_inputs, download, **dataset_kwargs
    )
    loader = torch_em.get_data_loader(segmentation_ds, batch_size, **dataloader_kwargs)
    return loader
Get the AMD-SD dataloader for lesion segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.