torch_em.data.datasets.medical.drive
The DRIVE dataset contains annotations for retinal vessel segmentation in fundus images.
This dataset is from the "DRIVE" challenge: https://drive.grand-challenge.org/. The dataset is from the publication https://doi.org/10.1109/TMI.2004.825627. Please cite them if you use this dataset for your research.
"""The DRIVE dataset contains annotations for retinal vessel segmentation in
fundus images.

This dataset is from the "DRIVE" challenge: https://drive.grand-challenge.org/.
The dataset is from the publication https://doi.org/10.1109/TMI.2004.825627.
Please cite them if you use this dataset for your research.
"""

import os
from glob import glob
from pathlib import Path
from typing import Union, Tuple, Literal, List

import imageio.v3 as imageio

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = {
    "train": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1",
    "test": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1"
}

CHECKSUM = {
    "train": "7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209",
    "test": "d76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731"
}


def get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the DRIVE dataset.

    Only the training archive (`URL["train"]`) is downloaded and extracted; all data splits
    used by this module are derived from it.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the folder with the extracted training data (`<path>/training`).
    """
    data_dir = os.path.join(path, "training")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "training.zip")
    # NOTE(review): the Dropbox link is fetched via the Google-Drive helper; presumably it performs a plain
    # download for non-Drive URLs — confirm, otherwise `util.download_source` may be the better fit.
    util.download_source_gdrive(
        path=zip_path, url=URL["train"], download=download, checksum=CHECKSUM["train"], download_type="zip",
    )
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir


def _get_drive_ground_truth(data_dir):
    """Convert the manual `.gif` annotations into binary uint8 `.tif` label images.

    The annotations are read from `<data_dir>/1st_manual`; every non-zero pixel becomes 1 and the
    result is written to `<data_dir>/gt` under the same file stem. Already converted files are reused,
    so only missing labels are (re-)written, e.g. after an interrupted earlier run.

    Args:
        data_dir: The folder with the extracted DRIVE training data.

    Returns:
        List of filepaths to the converted label images, in the sorted order of the source annotations.
    """
    gt_paths = sorted(glob(os.path.join(data_dir, "1st_manual", "*.gif")))

    neu_gt_dir = os.path.join(data_dir, "gt")
    os.makedirs(neu_gt_dir, exist_ok=True)

    if not gt_paths:
        # The source annotations are not available: fall back to previously converted labels (if any).
        return sorted(glob(os.path.join(neu_gt_dir, "*.tif")))

    neu_gt_paths = []
    for gt_path in gt_paths:
        neu_gt_path = os.path.join(neu_gt_dir, Path(gt_path).with_suffix(".tif").name)
        # Convert only what is missing; previously, an existing but incomplete `gt` folder was returned as-is.
        if not os.path.exists(neu_gt_path):
            gt = imageio.imread(gt_path).squeeze()
            imageio.imwrite(neu_gt_path, (gt > 0).astype("uint8"))
        neu_gt_paths.append(neu_gt_path)

    return neu_gt_paths


def get_drive_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the DRIVE data.

    All splits are taken from the DRIVE training archive: the first 10 images (in sorted order) form
    the 'train' split, the next 4 the 'val' split and the remaining ones the 'test' split.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
        RuntimeError: If the number of images and labels found on disk differs.
    """
    # Validate before touching the filesystem, so an invalid split does not trigger a download.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_drive_data(path=path, download=download)

    image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif")))
    gt_paths = _get_drive_ground_truth(data_dir)

    # Images and labels are paired by position, so the lists must have the same length.
    if len(image_paths) != len(gt_paths):
        raise RuntimeError(
            f"Found {len(image_paths)} images but {len(gt_paths)} labels in '{data_dir}'; the data may be incomplete."
        )

    selection = {"train": slice(None, 10), "val": slice(10, 14), "test": slice(14, None)}[split]
    return image_paths[selection], gt_paths[selection]


def get_drive_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DRIVE dataset for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_drive_paths(path=path, split=split, download=download)

    if resize_inputs:
        # The raw images are resized as RGB inputs (`is_rgb=True`).
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=gt_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )


def get_drive_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: str,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DRIVE dataloader for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_drive_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **ds_kwargs,
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
{'train': 'https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1', 'test': 'https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1'}
CHECKSUM =
{'train': '7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209', 'test': 'd76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731'}
def
get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the DRIVE dataset.

    The training archive is fetched and extracted into `path` unless `<path>/training` already exists.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the folder with the extracted training data (`<path>/training`).
    """
    training_dir = os.path.join(path, "training")

    if not os.path.exists(training_dir):
        os.makedirs(path, exist_ok=True)
        archive_path = os.path.join(path, "training.zip")
        util.download_source_gdrive(
            path=archive_path,
            url=URL["train"],
            download=download,
            checksum=CHECKSUM["train"],
            download_type="zip",
        )
        util.unzip(zip_path=archive_path, dst=path)

    return training_dir
Download the DRIVE dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath to the folder with the extracted training data (`<path>/training`).
def
get_drive_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_drive_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the DRIVE data.

    All splits are taken from the DRIVE training archive: the first 10 images (in sorted order) form
    the 'train' split, the next 4 the 'val' split and the remaining ones the 'test' split.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
        RuntimeError: If the number of images and labels found on disk differs.
    """
    # Validate before touching the filesystem, so an invalid split does not trigger a download.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_drive_data(path=path, download=download)

    image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif")))
    gt_paths = _get_drive_ground_truth(data_dir)

    # Images and labels are paired by position, so the lists must have the same length.
    if len(image_paths) != len(gt_paths):
        raise RuntimeError(
            f"Found {len(image_paths)} images but {len(gt_paths)} labels in '{data_dir}'; the data may be incomplete."
        )

    selection = {"train": slice(None, 10), "val": slice(10, 14), "test": slice(14, None)}[split]
    return image_paths[selection], gt_paths[selection]
Get paths to the DRIVE data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
- List of filepaths for the image data.
- List of filepaths for the label data.
def
get_drive_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_drive_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DRIVE dataset for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_files, label_files = get_drive_paths(path=path, split=split, download=download)

    if resize_inputs:
        # The resize settings are derived from the requested patch shape; the inputs are RGB images.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_files,
        label_paths=label_files,
        raw_key=None,
        label_key=None,
        is_seg_dataset=False,
        patch_shape=patch_shape,
        **kwargs
    )
Get the DRIVE dataset for segmentation of retinal blood vessels in fundus images.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_drive_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: str, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_drive_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: str,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DRIVE dataloader for segmentation of retinal blood vessels in fundus images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the expected patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    drive_dataset = get_drive_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(drive_dataset, batch_size, **dataloader_kwargs)
Get the DRIVE dataloader for segmentation of retinal blood vessels in fundus images.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize the inputs to the expected patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.