torch_em.data.datasets.medical.drive

  1import os
  2from glob import glob
  3from pathlib import Path
  4from typing import Union, Tuple
  5
  6import imageio.v3 as imageio
  7
  8import torch_em
  9from torch_em.transform.generic import ResizeInputs
 10
 11from .. import util
 12from ... import ImageCollectionDataset
 13
 14
 15URL = {
 16    "train": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1",
 17    "test": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1"
 18}
 19
 20CHECKSUM = {
 21    "train": "7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209",
 22    "test": "d76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731"
 23}
 24
 25
 26def get_drive_data(path, download):
 27    os.makedirs(path, exist_ok=True)
 28
 29    data_dir = os.path.join(path, "training")
 30    if os.path.exists(data_dir):
 31        return data_dir
 32
 33    zip_path = os.path.join(path, "training.zip")
 34    util.download_source_gdrive(
 35        path=zip_path, url=URL["train"], download=download, checksum=CHECKSUM["train"], download_type="zip",
 36    )
 37    util.unzip(zip_path=zip_path, dst=path)
 38
 39    return data_dir
 40
 41
 42def _get_drive_ground_truth(data_dir):
 43    gt_paths = sorted(glob(os.path.join(data_dir, "1st_manual", "*.gif")))
 44
 45    neu_gt_dir = os.path.join(data_dir, "gt")
 46    if os.path.exists(neu_gt_dir):
 47        return sorted(glob(os.path.join(neu_gt_dir, "*.tif")))
 48    else:
 49        os.makedirs(neu_gt_dir, exist_ok=True)
 50
 51    neu_gt_paths = []
 52    for gt_path in gt_paths:
 53        gt = imageio.imread(gt_path).squeeze()
 54        neu_gt_path = os.path.join(
 55            neu_gt_dir, Path(os.path.split(gt_path)[-1]).with_suffix(".tif")
 56        )
 57        imageio.imwrite(neu_gt_path, (gt > 0).astype("uint8"))
 58        neu_gt_paths.append(neu_gt_path)
 59
 60    return neu_gt_paths
 61
 62
 63def _get_drive_paths(path, download):
 64    data_dir = get_drive_data(path=path, download=download)
 65
 66    image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif")))
 67    gt_paths = _get_drive_ground_truth(data_dir)
 68
 69    return image_paths, gt_paths
 70
 71
 72def get_drive_dataset(
 73    path: Union[os.PathLike, str],
 74    patch_shape: Tuple[int, int],
 75    resize_inputs: bool = False,
 76    download: bool = False,
 77    **kwargs
 78):
 79    """Dataset for segmentation of retinal blood vessels in fundus images.
 80
 81    This dataset is from the "DRIVE" challenge:
 82    - https://drive.grand-challenge.org/
 83    - https://doi.org/10.1109/TMI.2004.825627
 84
 85    Please cite it if you use this dataset for a publication.
 86    """
 87    image_paths, gt_paths = _get_drive_paths(path=path, download=download)
 88
 89    if resize_inputs:
 90        raw_trafo = ResizeInputs(target_shape=patch_shape, is_rgb=True)
 91        label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True)
 92        patch_shape = None
 93    else:
 94        patch_shape = patch_shape
 95        raw_trafo, label_trafo = None, None
 96
 97    dataset = ImageCollectionDataset(
 98        raw_image_paths=image_paths,
 99        label_image_paths=gt_paths,
100        patch_shape=patch_shape,
101        raw_transform=raw_trafo,
102        label_transform=label_trafo,
103        **kwargs
104    )
105
106    return dataset
107
108
def get_drive_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    batch_size: int,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataloader for segmentation of retinal blood vessels in fundus images. See `get_drive_dataset` for details.

    Args:
        path: Forwarded to `get_drive_dataset`.
        patch_shape: Forwarded to `get_drive_dataset`.
        batch_size: Forwarded to `torch_em.get_data_loader`.
        resize_inputs: Forwarded to `get_drive_dataset`.
        download: Forwarded to `get_drive_dataset`.
        kwargs: Split by `util.split_kwargs` into keyword arguments for the
            dataset and keyword arguments for the loader.

    Returns:
        The result of `torch_em.get_data_loader` for the DRIVE dataset.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    drive_dataset = get_drive_dataset(
        path=path,
        patch_shape=patch_shape,
        resize_inputs=resize_inputs,
        download=download,
        **ds_kwargs
    )
    return torch_em.get_data_loader(dataset=drive_dataset, batch_size=batch_size, **loader_kwargs)
# Download locations of the DRIVE archives, keyed by split name.
URL = {
    "train": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1",
    "test": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1",
}

# Expected checksum of each archive, keyed like URL. The values are 64 hex
# digits (presumably SHA-256 -- confirm against the download helper).
CHECKSUM = {
    "train": "7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209",
    "test": "d76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731",
}
def get_drive_data(path, download):
    """Make sure the DRIVE training data exists below `path` and return its folder.

    If `<path>/training` is already present it is returned right away.
    Otherwise the training archive is requested at `<path>/training.zip`
    through `util.download_source_gdrive` and extracted into `path` with
    `util.unzip`.

    Args:
        path: Root folder for the dataset; created if it does not exist.
        download: Forwarded to `util.download_source_gdrive`.

    Returns:
        The path `<path>/training`.
    """
    os.makedirs(path, exist_ok=True)

    training_dir = os.path.join(path, "training")
    if not os.path.exists(training_dir):
        archive = os.path.join(path, "training.zip")
        util.download_source_gdrive(
            path=archive,
            url=URL["train"],
            download=download,
            checksum=CHECKSUM["train"],
            download_type="zip",
        )
        util.unzip(zip_path=archive, dst=path)

    return training_dir
def get_drive_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataset for segmentation of retinal blood vessels in fundus images.

    This dataset is from the "DRIVE" challenge:
    - https://drive.grand-challenge.org/
    - https://doi.org/10.1109/TMI.2004.825627

    Please cite it if you use this dataset for a publication.

    Args:
        path: Root folder of the dataset.
        patch_shape: Patch shape handed to the dataset. With
            `resize_inputs=True` it is used as the `target_shape` of the
            resize transforms instead, and the dataset gets `patch_shape=None`.
        resize_inputs: Whether to attach `ResizeInputs` transforms to the raw
            images and the labels.
        download: Forwarded to the step that fetches the data.
        kwargs: Additional keyword arguments for `ImageCollectionDataset`.

    Returns:
        The `ImageCollectionDataset` built from the image and label paths.
    """
    image_paths, gt_paths = _get_drive_paths(path=path, download=download)

    # No transforms unless resizing is requested.
    raw_trafo = label_trafo = None
    if resize_inputs:
        raw_trafo = ResizeInputs(target_shape=patch_shape, is_rgb=True)
        label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True)
        # The shape went into the transforms, so the dataset gets none.
        patch_shape = None

    return ImageCollectionDataset(
        raw_image_paths=image_paths,
        label_image_paths=gt_paths,
        patch_shape=patch_shape,
        raw_transform=raw_trafo,
        label_transform=label_trafo,
        **kwargs
    )

Dataset for segmentation of retinal blood vessels in fundus images.

This dataset is from the "DRIVE" challenge (https://drive.grand-challenge.org/, https://doi.org/10.1109/TMI.2004.825627).

Please cite it if you use this dataset for a publication.

def get_drive_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    batch_size: int,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataloader for segmentation of retinal blood vessels in fundus images. See `get_drive_dataset` for details.

    Args:
        path: Forwarded to `get_drive_dataset`.
        patch_shape: Forwarded to `get_drive_dataset`.
        batch_size: Forwarded to `torch_em.get_data_loader`.
        resize_inputs: Forwarded to `get_drive_dataset`.
        download: Forwarded to `get_drive_dataset`.
        kwargs: Split by `util.split_kwargs` into keyword arguments for the
            dataset and keyword arguments for the loader.

    Returns:
        The result of `torch_em.get_data_loader` for the DRIVE dataset.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    drive_dataset = get_drive_dataset(
        path=path,
        patch_shape=patch_shape,
        resize_inputs=resize_inputs,
        download=download,
        **ds_kwargs
    )
    return torch_em.get_data_loader(dataset=drive_dataset, batch_size=batch_size, **loader_kwargs)

Dataloader for segmentation of retinal blood vessels in fundus images. See `get_drive_dataset` for details.