torch_em.data.datasets.medical.palm
The PALM dataset contains annotations for optic disc and lesion segmentation in fundus images.
The dataset is from the publication https://doi.org/10.1038/s41597-024-02911-2. Please cite it if you use this dataset for your research.
"""The PALM dataset contains annotations for optic disc and lesion segmentation in fundus images.

The dataset is from the publication https://doi.org/10.1038/s41597-024-02911-2.
Please cite it if you use this dataset for your research.
"""

import os
import shutil
from glob import glob
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import imageio.v3 as imageio

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = "https://springernature.figshare.com/ndownloader/files/37786152"
CHECKSUM = "21cd568a00a50287370572ea81b50847085819bd2f732331ee9cdc6367e6cd1f"


def get_palm_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the PALM data.

    If the folder `<path>/PALM` already exists, nothing is downloaded.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "PALM")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "data.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    # Remove the `__MACOSX` folder (macOS archiving metadata) if the archive contained one.
    # Guarded so that an archive without this folder does not raise FileNotFoundError.
    macosx_dir = os.path.join(path, "__MACOSX")
    if os.path.isdir(macosx_dir):
        shutil.rmtree(macosx_dir)

    return data_dir


def _preprocess_labels(label_paths):
    """Convert the `.bmp` label masks into binary `.tif` label images stored next to them.

    Pixels equal to 0 in a mask are mapped to 1 and all other pixels to 0.
    Files that were already converted are not recomputed.

    Args:
        label_paths: Filepaths to the `.bmp` label masks.

    Returns:
        Filepaths to the converted `*_preprocessed.tif` label images, in the same order.
    """
    # Only the file extension is swapped; other occurrences of ".bmp" in the path are left untouched.
    neu_label_paths = [os.path.splitext(p)[0] + "_preprocessed.tif" for p in label_paths]
    for lpath, neu_lpath in zip(label_paths, neu_label_paths):
        if os.path.exists(neu_lpath):
            continue

        label = imageio.imread(lpath)
        imageio.imwrite(neu_lpath, (label == 0).astype(int), compression="zlib")

    return neu_label_paths


def get_palm_paths(
    path: Union[os.PathLike, str],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the PALM data.

    The arguments are validated before any data is downloaded. The `.bmp` label masks of the chosen
    split are converted to `.tif` label images and paired with the `.jpg` images of the same name
    in the split's `Images` folder.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The choice of data split. One of "Training", "Validation" or "Testing".
        label_choice: The choice of label masks. One of "disc", "atrophy_lesion" or "detachment_lesion".
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        AssertionError: If `split` is not valid, or if an image is missing for one of the label masks.
        ValueError: If `label_choice` is not valid.
    """
    # Validate the arguments first, so that invalid input fails fast before a potentially large download.
    assert split in ["Training", "Validation", "Testing"], f"'{split}' is not a valid split."

    if label_choice == "disc":
        ldir = "Disc Masks"
    elif label_choice == "atrophy_lesion":
        ldir = "Lesion Masks/Atrophy"
    elif label_choice == "detachment_lesion":
        ldir = "Lesion Masks/Detachment"
    else:
        raise ValueError(f"'{label_choice}' is not a valid choice of labels.")

    data_dir = get_palm_data(path, download)
    split_dir = os.path.join(data_dir, split)

    mask_paths = natsorted(glob(os.path.join(split_dir, ldir, "*.bmp")))

    # Derive each image path from the mask's file name instead of substring replacement on the full
    # path, so that folder names elsewhere in `path` cannot be rewritten by accident.
    raw_paths = [
        os.path.join(split_dir, "Images", os.path.splitext(os.path.basename(p))[0] + ".jpg")
        for p in mask_paths
    ]
    label_paths = _preprocess_labels(mask_paths)

    # Every label mask needs a matching image; report missing files here rather than at load time.
    missing = [p for p in raw_paths if not os.path.exists(p)]
    assert not missing, f"Could not find the images for {len(missing)} label masks, e.g. '{missing[0]}'."

    return raw_paths, label_paths


def get_palm_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the PALM dataset for disc and lesion segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_palm_paths(path, split, label_choice, download)

    if resize_inputs:
        # The inputs are treated as RGB images by the resize transform.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )


def get_palm_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the PALM dataloader for disc and lesion segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_palm_dataset(path, patch_shape, split, label_choice, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
'https://springernature.figshare.com/ndownloader/files/37786152'
CHECKSUM =
'21cd568a00a50287370572ea81b50847085819bd2f732331ee9cdc6367e6cd1f'
def
get_palm_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_palm_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the PALM data.

    If the folder `<path>/PALM` already exists, nothing is downloaded.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "PALM")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "data.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    # Remove the `__MACOSX` folder (macOS archiving metadata) if the archive contained one.
    # Guarded so that an archive without this folder does not raise FileNotFoundError.
    macosx_dir = os.path.join(path, "__MACOSX")
    if os.path.isdir(macosx_dir):
        shutil.rmtree(macosx_dir)

    return data_dir
Download the PALM data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_palm_paths( path: Union[os.PathLike, str], split: Literal['Training', 'Validation', 'Testing'], label_choice: Literal['disc', 'atrophy_lesion', 'detachment_lesion'] = 'disc', download: bool = False) -> Tuple[List[str], List[str]]:
def get_palm_paths(
    path: Union[os.PathLike, str],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the PALM data.

    The arguments are validated before any data is downloaded. The `.bmp` label masks of the chosen
    split are converted to `.tif` label images and paired with the `.jpg` images of the same name
    in the split's `Images` folder.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The choice of data split. One of "Training", "Validation" or "Testing".
        label_choice: The choice of label masks. One of "disc", "atrophy_lesion" or "detachment_lesion".
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        AssertionError: If `split` is not valid, or if an image is missing for one of the label masks.
        ValueError: If `label_choice` is not valid.
    """
    # Validate the arguments first, so that invalid input fails fast before a potentially large download.
    assert split in ["Training", "Validation", "Testing"], f"'{split}' is not a valid split."

    if label_choice == "disc":
        ldir = "Disc Masks"
    elif label_choice == "atrophy_lesion":
        ldir = "Lesion Masks/Atrophy"
    elif label_choice == "detachment_lesion":
        ldir = "Lesion Masks/Detachment"
    else:
        raise ValueError(f"'{label_choice}' is not a valid choice of labels.")

    data_dir = get_palm_data(path, download)
    split_dir = os.path.join(data_dir, split)

    mask_paths = natsorted(glob(os.path.join(split_dir, ldir, "*.bmp")))

    # Derive each image path from the mask's file name instead of substring replacement on the full
    # path, so that folder names elsewhere in `path` cannot be rewritten by accident.
    raw_paths = [
        os.path.join(split_dir, "Images", os.path.splitext(os.path.basename(p))[0] + ".jpg")
        for p in mask_paths
    ]
    label_paths = _preprocess_labels(mask_paths)

    # Every label mask needs a matching image; report missing files here rather than at load time.
    missing = [p for p in raw_paths if not os.path.exists(p)]
    assert not missing, f"Could not find the images for {len(missing)} label masks, e.g. '{missing[0]}'."

    return raw_paths, label_paths
Get paths to the PALM data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The choice of data split.
- label_choice: The choice of label masks.
- download: Whether to download the data if it is not present.
Returns:
A tuple of two lists: the filepaths for the image data and the filepaths for the label data.
def
get_palm_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['Training', 'Validation', 'Testing'], label_choice: Literal['disc', 'atrophy_lesion', 'detachment_lesion'] = 'disc', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_palm_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the PALM dataset for disc and lesion segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape. When set, the
            keyword arguments and the patch shape are updated via `util.update_kwargs_for_resize_trafo`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, mask_paths = get_palm_paths(path, split, label_choice, download)

    if resize_inputs:
        # The resize transform is configured for RGB inputs; note that `patch_shape` may be replaced here.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    # The paths point to plain image files, hence no raw / label keys.
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=mask_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs,
    )
Get the PALM dataset for disc and lesion segmentation.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- label_choice: The choice of label masks.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_palm_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['Training', 'Validation', 'Testing'], label_choice: Literal['disc', 'atrophy_lesion', 'detachment_lesion'] = 'disc', resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_palm_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the PALM dataloader for disc and lesion segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments meant for `torch_em.default_segmentation_dataset` from the rest,
    # which are forwarded to the data loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    palm_dataset = get_palm_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        label_choice=label_choice,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(palm_dataset, batch_size, **dataloader_kwargs)
Get the PALM dataloader for disc and lesion segmentation.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- label_choice: The choice of label masks.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.