torch_em.data.datasets.medical.drive

The DRIVE dataset contains annotations for retinal vessel segmentation in fundus images.

This dataset is from the "DRIVE" challenge (https://drive.grand-challenge.org/) and is described in the publication https://doi.org/10.1109/TMI.2004.825627. Please cite the publication if you use this dataset for your research.

  1"""The DRIVE dataset contains annotations for retinal vessel segmentation in
  2fundus images.
  3
  4This dataset is from the "DRIVE" challenge: https://drive.grand-challenge.org/.
  5The dataset is from the publication https://doi.org/10.1109/TMI.2004.825627.
  6Please cite them if you use this dataset for your research.
  7"""
  8
  9import os
 10from glob import glob
 11from pathlib import Path
 12from typing import Union, Tuple, Literal, List
 13
 14import imageio.v3 as imageio
 15
 16from torch.utils.data import Dataset, DataLoader
 17
 18import torch_em
 19
 20from .. import util
 21
 22
# Download links for the DRIVE archives ("training.zip" and "test.zip"), keyed by archive.
# NOTE: only the "train" entry is referenced by the functions visible in this module.
URL = {
    "train": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1",
    "test": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1"
}

# Expected checksums of the archives, using the same keys as `URL`.
# The value is handed to the download helper through its `checksum` argument.
# NOTE(review): 64 hex characters, so presumably SHA-256 digests - confirm against `util`.
CHECKSUM = {
    "train": "7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209",
    "test": "d76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731"
}
 32
 33
 34def get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 35    """Download the DRIVE dataset.
 36
 37    Args:
 38        path: Filepath to a folder where the data is downloaded for further processing.
 39        download: Whether to download the data if it is not present.
 40
 41    Returns:
 42        Filepath where the data is downloaded.
 43    """
 44    data_dir = os.path.join(path, "training")
 45    if os.path.exists(data_dir):
 46        return data_dir
 47
 48    os.makedirs(path, exist_ok=True)
 49
 50    zip_path = os.path.join(path, "training.zip")
 51    util.download_source_gdrive(
 52        path=zip_path, url=URL["train"], download=download, checksum=CHECKSUM["train"], download_type="zip",
 53    )
 54    util.unzip(zip_path=zip_path, dst=path)
 55
 56    return data_dir
 57
 58
 59def _get_drive_ground_truth(data_dir):
 60    gt_paths = sorted(glob(os.path.join(data_dir, "1st_manual", "*.gif")))
 61
 62    neu_gt_dir = os.path.join(data_dir, "gt")
 63    if os.path.exists(neu_gt_dir):
 64        return sorted(glob(os.path.join(neu_gt_dir, "*.tif")))
 65    else:
 66        os.makedirs(neu_gt_dir, exist_ok=True)
 67
 68    neu_gt_paths = []
 69    for gt_path in gt_paths:
 70        gt = imageio.imread(gt_path).squeeze()
 71        neu_gt_path = os.path.join(
 72            neu_gt_dir, Path(os.path.split(gt_path)[-1]).with_suffix(".tif")
 73        )
 74        imageio.imwrite(neu_gt_path, (gt > 0).astype("uint8"))
 75        neu_gt_paths.append(neu_gt_path)
 76
 77    return neu_gt_paths
 78
 79
 80def get_drive_paths(
 81    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 82) -> Tuple[List[str], List[str]]:
 83    """Get paths to the DRIVE data.
 84
 85    Args:
 86        path: Filepath to a folder where the data is downloaded for further processing.
 87        split: The choice of data split.
 88        download: Whether to download the data if it is not present.
 89
 90    Returns:
 91        List of filepaths for the image data.
 92        List of filepaths for the label data.
 93    """
 94    data_dir = get_drive_data(path=path, download=download)
 95
 96    image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif")))
 97    gt_paths = _get_drive_ground_truth(data_dir)
 98
 99    if split == "train":
100        image_paths, gt_paths = image_paths[:10], gt_paths[:10]
101    elif split == "val":
102        image_paths, gt_paths = image_paths[10:14], gt_paths[10:14]
103    elif split == "test":
104        image_paths, gt_paths = image_paths[14:], gt_paths[14:]
105    else:
106        raise ValueError(f"'{split}' is not a valid split.")
107
108    return image_paths, gt_paths
109
110
def get_drive_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DRIVE dataset for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
            If set, `kwargs` and `patch_shape` are replaced by the values returned
            from `util.update_kwargs_for_resize_trafo`; the images are declared as RGB there.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_drive_paths(path=path, split=split, download=download)

    if resize_inputs:
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )
    return dataset
150
def get_drive_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DRIVE dataloader for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments understood by the dataset from the ones meant for the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    # Forward by keyword so that the call does not depend on the parameter order of `get_drive_dataset`.
    dataset = get_drive_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Download links for the DRIVE archives ("training.zip" and "test.zip"), keyed by archive.
URL = {'train': 'https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1', 'test': 'https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1'}
# Expected checksums of the archives, using the same keys as `URL`.
# NOTE(review): 64 hex characters, so presumably SHA-256 digests - confirm.
CHECKSUM = {'train': '7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209', 'test': 'd76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731'}
def get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Make the DRIVE training data available and return its location.

    If the `training` folder already exists inside `path`, nothing is downloaded.
    Otherwise the `training.zip` archive is fetched into `path` and unzipped there.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the `training` folder inside `path`.
    """
    extracted_dir = os.path.join(path, "training")

    if not os.path.exists(extracted_dir):
        os.makedirs(path, exist_ok=True)

        archive_path = os.path.join(path, "training.zip")
        util.download_source_gdrive(
            path=archive_path,
            url=URL["train"],
            download=download,
            checksum=CHECKSUM["train"],
            download_type="zip",
        )
        # The archive is expected to unpack into the `training` folder returned below.
        util.unzip(zip_path=archive_path, dst=path)

    return extracted_dir

Download the DRIVE dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath to the extracted `training` folder inside `path`, where the data is located.

def get_drive_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
 81def get_drive_paths(
 82    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 83) -> Tuple[List[str], List[str]]:
 84    """Get paths to the DRIVE data.
 85
 86    Args:
 87        path: Filepath to a folder where the data is downloaded for further processing.
 88        split: The choice of data split.
 89        download: Whether to download the data if it is not present.
 90
 91    Returns:
 92        List of filepaths for the image data.
 93        List of filepaths for the label data.
 94    """
 95    data_dir = get_drive_data(path=path, download=download)
 96
 97    image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif")))
 98    gt_paths = _get_drive_ground_truth(data_dir)
 99
100    if split == "train":
101        image_paths, gt_paths = image_paths[:10], gt_paths[:10]
102    elif split == "val":
103        image_paths, gt_paths = image_paths[10:14], gt_paths[10:14]
104    elif split == "test":
105        image_paths, gt_paths = image_paths[14:], gt_paths[14:]
106    else:
107        raise ValueError(f"'{split}' is not a valid split.")
108
109    return image_paths, gt_paths

Get paths to the DRIVE data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data. List of filepaths for the label data.

def get_drive_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_drive_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DRIVE dataset for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
            If set, `kwargs` and `patch_shape` are replaced by the values returned
            from `util.update_kwargs_for_resize_trafo`; the images are declared as RGB there.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_drive_paths(path=path, split=split, download=download)

    if resize_inputs:
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )
    return dataset

Get the DRIVE dataset for segmentation of retinal blood vessels in fundus images.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize the inputs to the expected patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_drive_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: str, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_drive_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DRIVE dataloader for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments understood by the dataset from the ones meant for the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    # Forward by keyword so that the call does not depend on the parameter order of `get_drive_dataset`.
    dataset = get_drive_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get the DRIVE dataloader for segmentation of retinal blood vessels in fundus images.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize the inputs to the expected patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.