torch_em.data.datasets.medical.palm

The PALM dataset contains annotations for optic disc and lesion segmentation in fundus images.

The dataset is from the publication https://doi.org/10.1038/s41597-024-02911-2. Please cite it if you use this dataset for your research.

  1"""The PALM dataset contains annotations for optic disc and lesion segmentation in Fundus images.
  2
  3The dataset is from the publication https://doi.org/10.1038/s41597-024-02911-2.
  4Please cite it if you use this dataset for your research.
  5"""
  6
  7import os
  8import shutil
  9from glob import glob
 10from natsort import natsorted
 11from typing import Union, Tuple, Literal, List
 12
 13import imageio.v3 as imageio
 14
 15from torch.utils.data import Dataset, DataLoader
 16
 17import torch_em
 18
 19from .. import util
 20
 21
 22URL = "https://springernature.figshare.com/ndownloader/files/37786152"
 23CHECKSUM = "21cd568a00a50287370572ea81b50847085819bd2f732331ee9cdc6367e6cd1f"
 24
 25
 26def get_palm_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 27    """Download the PALM data.
 28
 29    Args:
 30        path: Filepath to a folder where the data is downloaded for further processing.
 31        download: Whether to download the data if it is not present.
 32
 33    Returns:
 34        Filepath where the data is downloaded.
 35    """
 36    data_dir = os.path.join(path, "PALM")
 37    if os.path.exists(data_dir):
 38        return data_dir
 39
 40    os.makedirs(path, exist_ok=True)
 41
 42    zip_path = os.path.join(path, "data.zip")
 43    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 44    util.unzip(zip_path=zip_path, dst=path)
 45
 46    shutil.rmtree(os.path.join(path, "__MACOSX"))
 47
 48    return data_dir
 49
 50
 51def _preprocess_labels(label_paths):
 52    neu_label_paths = [p.replace(".bmp", "_preprocessed.tif") for p in label_paths]
 53    for lpath, neu_lpath in zip(label_paths, neu_label_paths):
 54        if os.path.exists(neu_lpath):
 55            continue
 56
 57        label = imageio.imread(lpath)
 58        imageio.imwrite(neu_lpath, (label == 0).astype(int), compression="zlib")
 59
 60    return neu_label_paths
 61
 62
 63def get_palm_paths(
 64    path: Union[os.PathLike, str],
 65    split: Literal["Training", "Validation", "Testing"],
 66    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
 67    download: bool = False
 68) -> Tuple[List[str], List[str]]:
 69    """Get paths to the PALM data.
 70
 71    Args:
 72        path: Filepath to a folder where the downloaded data will be saved.
 73        split: The choice of data split.
 74        label_choice: The choice of label masks.
 75        download: Whether to download the data if it is not present.
 76
 77    Returns:
 78        List of filepaths for the image data.
 79        List of filepaths for the label data.
 80    """
 81    data_dir = get_palm_data(path, download)
 82
 83    assert split in ["Training", "Validation", "Testing"], f"'{split}' is not a valid split."
 84
 85    if label_choice == "disc":
 86        ldir = "Disc Masks"
 87    elif label_choice == "atrophy_lesion":
 88        ldir = "Lesion Masks/Atrophy"
 89    elif label_choice == "detachment_lesion":
 90        ldir = "Lesion Masks/Detachment"
 91    else:
 92        raise ValueError(f"'{label_choice}' is not a valid choice of labels.")
 93
 94    label_paths = natsorted(glob(os.path.join(data_dir, split, ldir, "*.bmp")))
 95    label_paths = _preprocess_labels(label_paths)
 96
 97    raw_paths = [p.replace(ldir, "Images") for p in label_paths]
 98    raw_paths = [p.replace("_preprocessed.tif", ".jpg") for p in raw_paths]
 99
100    assert len(label_paths) == len(raw_paths)
101
102    return raw_paths, label_paths
103
104
def get_palm_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the PALM dataset for disc and lesion segmentation.

    The images are treated as RGB inputs when resizing is requested. Each image/label pair
    is loaded as an individual 2d sample (`is_seg_dataset=False`).

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, mask_paths = get_palm_paths(path, split, label_choice, download)

    if resize_inputs:
        # The resize transform is configured with the requested patch shape; the helper
        # returns the updated kwargs together with the patch shape to use afterwards.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=True),
        )

    dataset_config = dict(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=mask_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
    )
    return torch_em.default_segmentation_dataset(**dataset_config, **kwargs)
145
146
def get_palm_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the PALM dataloader for disc and lesion segmentation.

    The keyword arguments are split into those accepted by
    `torch_em.default_segmentation_dataset` (forwarded to the dataset) and the remainder
    (forwarded to the loader).

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    palm_dataset = get_palm_dataset(
        path, patch_shape, split, label_choice, resize_inputs, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(palm_dataset, batch_size, **dataloader_kwargs)
# Download location of the PALM archive on figshare, and the checksum that is passed
# to `util.download_source` to verify the downloaded file.
URL = "https://springernature.figshare.com/ndownloader/files/37786152"
CHECKSUM = "21cd568a00a50287370572ea81b50847085819bd2f732331ee9cdc6367e6cd1f"
def get_palm_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the PALM data.

    The archive is fetched to `<path>/data.zip` and extracted into `path`; the data is
    expected to end up in `<path>/PALM`. If that folder already exists, nothing is done.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "PALM")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "data.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    # The archive may ship macOS metadata in a '__MACOSX' folder. Remove it when present,
    # but do not fail if it is absent (rmtree raises FileNotFoundError on a missing path).
    macosx_dir = os.path.join(path, "__MACOSX")
    if os.path.isdir(macosx_dir):
        shutil.rmtree(macosx_dir)

    return data_dir

Download the PALM data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_palm_paths(
    path: Union[os.PathLike, str],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the PALM data.

    The `.bmp` label masks of the chosen split are converted into binary tif files via
    `_preprocess_labels`. The matching images are the `.jpg` files with the same file stem
    in the split's 'Images' folder.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The choice of data split.
        label_choice: The choice of label masks.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        AssertionError: If `split` is invalid, or if an image for a label mask is missing.
        ValueError: If `label_choice` is invalid.
    """
    # Validate the arguments before (potentially) triggering a large download.
    assert split in ["Training", "Validation", "Testing"], f"'{split}' is not a valid split."

    label_dirs = {
        "disc": "Disc Masks",
        "atrophy_lesion": os.path.join("Lesion Masks", "Atrophy"),
        "detachment_lesion": os.path.join("Lesion Masks", "Detachment"),
    }
    if label_choice not in label_dirs:
        raise ValueError(f"'{label_choice}' is not a valid choice of labels.")
    ldir = label_dirs[label_choice]

    data_dir = get_palm_data(path, download)

    bmp_paths = natsorted(glob(os.path.join(data_dir, split, ldir, "*.bmp")))
    label_paths = _preprocess_labels(bmp_paths)

    # Build each image path from the mask's file stem, rather than string-replacing inside
    # the full path, which would break if the root path itself contained the label folder name.
    image_dir = os.path.join(data_dir, split, "Images")
    raw_paths = [
        os.path.join(image_dir, os.path.splitext(os.path.basename(bmp_path))[0] + ".jpg")
        for bmp_path in bmp_paths
    ]

    missing = [p for p in raw_paths if not os.path.exists(p)]
    assert not missing, f"Could not find the images for {len(missing)} label masks, e.g. '{missing[0]}'."

    return raw_paths, label_paths

Get paths to the PALM data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The choice of data split.
  • label_choice: The choice of label masks.
  • download: Whether to download the data if it is not present.
Returns:

Two lists: the filepaths for the image data, and the filepaths for the label data.

def get_palm_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the PALM dataset for disc and lesion segmentation.

    The images are treated as RGB inputs when resizing is requested. Each image/label pair
    is loaded as an individual 2d sample (`is_seg_dataset=False`).

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, mask_paths = get_palm_paths(path, split, label_choice, download)

    if resize_inputs:
        # The resize transform is configured with the requested patch shape; the helper
        # returns the updated kwargs together with the patch shape to use afterwards.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=True),
        )

    dataset_config = dict(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=mask_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
    )
    return torch_em.default_segmentation_dataset(**dataset_config, **kwargs)

Get the PALM dataset for disc and lesion segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • label_choice: The choice of label masks.
  • resize_inputs: Whether to resize the inputs to the expected patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_palm_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["Training", "Validation", "Testing"],
    label_choice: Literal["disc", "atrophy_lesion", "detachment_lesion"] = "disc",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the PALM dataloader for disc and lesion segmentation.

    The keyword arguments are split into those accepted by
    `torch_em.default_segmentation_dataset` (forwarded to the dataset) and the remainder
    (forwarded to the loader).

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        label_choice: The choice of label masks.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    palm_dataset = get_palm_dataset(
        path, patch_shape, split, label_choice, resize_inputs, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(palm_dataset, batch_size, **dataloader_kwargs)

Get the PALM dataloader for disc and lesion segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • label_choice: The choice of label masks.
  • resize_inputs: Whether to resize the inputs to the expected patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.