torch_em.data.datasets.medical.drive
import os
from glob import glob
from pathlib import Path
from typing import Union, Tuple

import imageio.v3 as imageio

import torch_em
from torch_em.transform.generic import ResizeInputs

from .. import util
from ... import ImageCollectionDataset


URL = {
    "train": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1",
    "test": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1"
}

CHECKSUM = {
    "train": "7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209",
    "test": "d76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731"
}


def get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Make the DRIVE training data available below `path`.

    If the folder '<path>/training' already exists it is returned right away.
    Otherwise the training archive is fetched to '<path>/training.zip' via
    `util.download_source_gdrive` (with the expected checksum) and extracted
    into `path` via `util.unzip`.

    Args:
        path: The folder where the data is stored. It is created if it does not exist.
        download: Forwarded to `util.download_source_gdrive`; presumably controls
            whether missing data may be downloaded (verify against `util`).

    Returns:
        The filepath '<path>/training'.
    """
    os.makedirs(path, exist_ok=True)

    data_dir = os.path.join(path, "training")
    if os.path.exists(data_dir):
        return data_dir

    zip_path = os.path.join(path, "training.zip")
    util.download_source_gdrive(
        path=zip_path, url=URL["train"], download=download, checksum=CHECKSUM["train"], download_type="zip",
    )
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir


def _get_drive_ground_truth(data_dir):
    """Convert the manual vessel annotations to binary tif label images.

    The annotations are read from the gif files in '<data_dir>/1st_manual',
    binarized with `gt > 0` and written as uint8 tif files to '<data_dir>/gt'.
    Only annotations without a corresponding tif are converted, so an already
    complete 'gt' folder is reused as is and a conversion that was interrupted
    earlier is completed instead of returning a partial list of labels.

    Args:
        data_dir: The folder that contains the '1st_manual' annotations.

    Returns:
        The sorted list of filepaths to the tif label images in '<data_dir>/gt'.
    """
    neu_gt_dir = os.path.join(data_dir, "gt")
    os.makedirs(neu_gt_dir, exist_ok=True)

    for gt_path in sorted(glob(os.path.join(data_dir, "1st_manual", "*.gif"))):
        neu_gt_path = os.path.join(
            neu_gt_dir, Path(os.path.basename(gt_path)).with_suffix(".tif")
        )
        # Skip labels that were already converted by a previous call.
        # NOTE: a file truncated by an interrupted write is not detected here.
        if os.path.exists(neu_gt_path):
            continue

        gt = imageio.imread(gt_path).squeeze()
        imageio.imwrite(neu_gt_path, (gt > 0).astype("uint8"))

    return sorted(glob(os.path.join(neu_gt_dir, "*.tif")))


def _get_drive_paths(path, download):
    """Return the filepaths to the DRIVE training images and their labels.

    Args:
        path: The folder where the data is stored, see `get_drive_data`.
        download: Forwarded to `get_drive_data`.

    Returns:
        The sorted list of tif images in '<path>/training/images' and
        the list of label images returned by `_get_drive_ground_truth`.
    """
    data_dir = get_drive_data(path=path, download=download)

    image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif")))
    gt_paths = _get_drive_ground_truth(data_dir)

    return image_paths, gt_paths
def get_drive_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataset for segmentation of retinal blood vessels in fundus images.

    This dataset is from the "DRIVE" challenge:
    - https://drive.grand-challenge.org/
    - https://doi.org/10.1109/TMI.2004.825627

    Please cite it if you use this dataset for a publication.

    Args:
        path: The folder where the data is stored, passed to `_get_drive_paths`.
        patch_shape: The patch shape passed to the dataset. If `resize_inputs` is set,
            it is used as the target shape for resizing instead.
        resize_inputs: Whether to resize images and labels to `patch_shape` with
            `ResizeInputs`. In this case the dataset is created with `patch_shape=None`.
        download: Passed to `_get_drive_paths`.
        kwargs: Additional keyword arguments for `ImageCollectionDataset`.

    Returns:
        The `ImageCollectionDataset` built from the image and label paths.
    """
    image_paths, gt_paths = _get_drive_paths(path=path, download=download)

    # Without resizing the data is passed on untransformed and patch_shape is used as given.
    raw_trafo = label_trafo = None
    if resize_inputs:
        raw_trafo = ResizeInputs(target_shape=patch_shape, is_rgb=True)
        label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True)
        patch_shape = None

    return ImageCollectionDataset(
        raw_image_paths=image_paths,
        label_image_paths=gt_paths,
        patch_shape=patch_shape,
        raw_transform=raw_trafo,
        label_transform=label_trafo,
        **kwargs
    )


def get_drive_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    batch_size: int,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataloader for segmentation of retinal blood vessels in fundus images. See `get_drive_dataset` for details.

    Args:
        path: Passed to `get_drive_dataset`.
        patch_shape: Passed to `get_drive_dataset`.
        batch_size: The batch size passed to `torch_em.get_data_loader`.
        resize_inputs: Passed to `get_drive_dataset`.
        download: Passed to `get_drive_dataset`.
        kwargs: Additional keyword arguments. They are divided by `util.split_kwargs`
            (with respect to `torch_em.default_segmentation_dataset`) into arguments
            for `get_drive_dataset` and arguments for `torch_em.get_data_loader`.

    Returns:
        The loader returned by `torch_em.get_data_loader`.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    drive_dataset = get_drive_dataset(
        path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **dataset_kwargs
    )
    return torch_em.get_data_loader(dataset=drive_dataset, batch_size=batch_size, **dataloader_kwargs)
URL =
{'train': 'https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1', 'test': 'https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1'}
CHECKSUM =
{'train': '7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209', 'test': 'd76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731'}
def
get_drive_data(path, download):
def get_drive_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Make the DRIVE training data available below `path`.

    If the folder '<path>/training' already exists it is returned right away.
    Otherwise the training archive is fetched to '<path>/training.zip' via
    `util.download_source_gdrive` (with the expected checksum) and extracted
    into `path` via `util.unzip`.

    Args:
        path: The folder where the data is stored. It is created if it does not exist.
        download: Forwarded to `util.download_source_gdrive`; presumably controls
            whether missing data may be downloaded (verify against `util`).
            Defaults to False, like in `get_drive_dataset` and `get_drive_loader`.

    Returns:
        The filepath '<path>/training'.
    """
    os.makedirs(path, exist_ok=True)

    # Data that was extracted before is reused without touching the download helpers.
    data_dir = os.path.join(path, "training")
    if os.path.exists(data_dir):
        return data_dir

    zip_path = os.path.join(path, "training.zip")
    util.download_source_gdrive(
        path=zip_path, url=URL["train"], download=download, checksum=CHECKSUM["train"], download_type="zip",
    )
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir
def
get_drive_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], resize_inputs: bool = False, download: bool = False, **kwargs):
def get_drive_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataset for segmentation of retinal blood vessels in fundus images.

    This dataset is from the "DRIVE" challenge:
    - https://drive.grand-challenge.org/
    - https://doi.org/10.1109/TMI.2004.825627

    Please cite it if you use this dataset for a publication.

    Args:
        path: The folder where the data is stored, passed to `_get_drive_paths`.
        patch_shape: The patch shape passed to the dataset. If `resize_inputs` is set,
            it is used as the target shape for resizing instead.
        resize_inputs: Whether to resize images and labels to `patch_shape` with
            `ResizeInputs`. In this case the dataset is created with `patch_shape=None`.
        download: Passed to `_get_drive_paths`.
        kwargs: Additional keyword arguments for `ImageCollectionDataset`.

    Returns:
        The `ImageCollectionDataset` built from the image and label paths.
    """
    image_paths, gt_paths = _get_drive_paths(path=path, download=download)

    # Without resizing the data is passed on untransformed and patch_shape is used as given.
    raw_trafo = label_trafo = None
    if resize_inputs:
        raw_trafo = ResizeInputs(target_shape=patch_shape, is_rgb=True)
        label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True)
        patch_shape = None

    return ImageCollectionDataset(
        raw_image_paths=image_paths,
        label_image_paths=gt_paths,
        patch_shape=patch_shape,
        raw_transform=raw_trafo,
        label_transform=label_trafo,
        **kwargs
    )
Dataset for segmentation of retinal blood vessels in fundus images.
This dataset is from the "DRIVE" challenge:
- https://drive.grand-challenge.org/
- https://doi.org/10.1109/TMI.2004.825627
Please cite it if you use this dataset for a publication.
def
get_drive_loader( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], batch_size: int, resize_inputs: bool = False, download: bool = False, **kwargs):
def get_drive_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    batch_size: int,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataloader for segmentation of retinal blood vessels in fundus images. See `get_drive_dataset` for details.

    Args:
        path: Passed to `get_drive_dataset`.
        patch_shape: Passed to `get_drive_dataset`.
        batch_size: The batch size passed to `torch_em.get_data_loader`.
        resize_inputs: Passed to `get_drive_dataset`.
        download: Passed to `get_drive_dataset`.
        kwargs: Additional keyword arguments. They are divided by `util.split_kwargs`
            (with respect to `torch_em.default_segmentation_dataset`) into arguments
            for `get_drive_dataset` and arguments for `torch_em.get_data_loader`.

    Returns:
        The loader returned by `torch_em.get_data_loader`.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    drive_dataset = get_drive_dataset(
        path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **dataset_kwargs
    )
    return torch_em.get_data_loader(dataset=drive_dataset, batch_size=batch_size, **dataloader_kwargs)
Dataloader for segmentation of retinal blood vessels in fundus images. See `get_drive_dataset` for details.