torch_em.data.datasets.electron_microscopy.lucchi
The Lucchi dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy.
The dataset was published in https://doi.org/10.48550/arXiv.1812.06024. Please cite this publication if you use the dataset in your research. We use the version of the dataset from https://sites.google.com/view/connectomics/.
"""The Lucchi dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy.

The dataset was published in https://doi.org/10.48550/arXiv.1812.06024.
Please cite this publication if you use the dataset in your research.
We use the version of the dataset from https://sites.google.com/view/connectomics/.
"""

import os
from concurrent import futures
from glob import glob
from shutil import rmtree
from typing import Tuple, Union

import imageio
import h5py
import numpy as np
import torch_em

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from .. import util

URL = "http://www.casser.io/files/lucchi_pp.zip"
CHECKSUM = "770ce9e98fc6f29c1b1a250c637e6c5125f2b5f1260e5a7687b55a79e2e8844d"


def _load_volume(path, pattern) -> np.ndarray:
    """Load a folder of PNG slices into a single stacked array.

    Args:
        path: Folder that contains the PNG slices.
        pattern: %-style filename pattern that is formatted with the slice index,
            e.g. "mask%04i.png".

    Returns:
        Array with one entry per PNG file along the first axis.
        Shape and dtype of the remaining axes are taken from slice 0.
    """
    # The number of slices is the number of PNG files in the folder.
    nz = len(glob(os.path.join(path, "*.png")))
    # Slice 0 is read up front because it defines the shape and dtype of the output.
    im0 = imageio.imread(os.path.join(path, pattern % 0))
    out = np.zeros((nz,) + im0.shape, dtype=im0.dtype)
    out[0] = im0

    def _loadz(z):
        # Each call writes to its own slice of `out`, so the threads do not overlap.
        im = imageio.imread(os.path.join(path, pattern % z))
        out[z] = im

    n_threads = 8
    with futures.ThreadPoolExecutor(n_threads) as tp:
        # `list` drains the lazy iterators so that the progress bar advances.
        list(tqdm(
            tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1
        ))

    return out


def _create_data(root, inputs, out_path) -> None:
    """Convert one split of the extracted PNG data into an HDF5 file.

    Args:
        root: Folder of the extracted archive.
        inputs: Pair of sub-folder names: first the raw images, then the label images.
        out_path: Filepath of the HDF5 file to write.
            It will contain the datasets "raw" and "labels".
    """
    raw = _load_volume(os.path.join(root, inputs[0]), pattern="mask%04i.png")
    labels_argb = _load_volume(os.path.join(root, inputs[1]), pattern="%i.png")
    if labels_argb.ndim == 4:
        # Four-channel label images (presumably ARGB/RGBA, going by the variable name):
        # a pixel is foreground only if all four channels are 255.
        labels = np.zeros(raw.shape, dtype="uint8")
        fg_mask = (labels_argb == np.array([255, 255, 255, 255])[None, None, None]).all(axis=-1)
        labels[fg_mask] = 1
    else:
        # Single-channel label images: map the foreground value 255 to 1.
        assert labels_argb.ndim == 3
        labels = labels_argb
        labels[labels == 255] = 1
    # The labels must be strictly binary and aligned with the raw data.
    assert (np.unique(labels) == np.array([0, 1])).all()
    assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}"
    with h5py.File(out_path, "w") as f:
        f.create_dataset("raw", data=raw, compression="gzip")
        f.create_dataset("labels", data=labels.astype("uint8"), compression="gzip")


def get_lucchi_data(path: Union[os.PathLike, str], split: str, download: bool) -> str:
    """Download the lucchi dataset.

    If the converted file for the requested split is already present it is returned directly.
    Otherwise the archive is fetched and unzipped, both splits are converted to HDF5
    and the extracted folder is removed again.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to download, either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        AssertionError: If the extracted archive does not contain the folder "Lucchi++"
            or if the file for the requested split is missing after the conversion.
    """
    data_path = os.path.join(path, f"lucchi_{split}.h5")
    if os.path.exists(data_path):
        return data_path

    # The folder may already exist although the converted file is missing
    # (e.g. it was created by the caller or left behind by an interrupted run),
    # so an existing folder must not be treated as an error.
    os.makedirs(path, exist_ok=True)
    tmp_path = os.path.join(path, "lucchi.zip")
    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
    util.unzip(tmp_path, path, remove=True)

    root = os.path.join(path, "Lucchi++")
    assert os.path.exists(root), root

    # NOTE(review): this pairing writes the "Test_*" folders to lucchi_train.h5 and the
    # "Train_*" folders to lucchi_test.h5. This looks inverted; confirm whether it is
    # intentional before changing it, because existing converted files depend on it.
    inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]]
    outputs = ["lucchi_train.h5", "lucchi_test.h5"]
    for inp, out in zip(inputs, outputs):
        out_path = os.path.join(path, out)
        _create_data(root, inp, out_path)
    rmtree(root)

    assert os.path.exists(data_path), data_path
    return data_path


def get_lucchi_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get dataset for EM mitochondrion segmentation in the lucchi dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")
    data_path = get_lucchi_data(path, split, download)
    # Raw data and labels are stored in the same HDF5 file under different keys.
    raw_key, label_key = "raw", "labels"
    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)


def get_lucchi_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get dataloader for EM mitochondrion segmentation in the lucchi dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`
            or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments for the dataset from the ones for the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(
        torch_em.default_segmentation_dataset, **kwargs
    )
    dataset = get_lucchi_dataset(path, split, patch_shape, download=download, **ds_kwargs)
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
URL =
'http://www.casser.io/files/lucchi_pp.zip'
CHECKSUM =
'770ce9e98fc6f29c1b1a250c637e6c5125f2b5f1260e5a7687b55a79e2e8844d'
def
get_lucchi_data(path: Union[os.PathLike, str], split: str, download: bool) -> str:
def get_lucchi_data(path: Union[os.PathLike, str], split: str, download: bool) -> str:
    """Download the lucchi dataset.

    If the converted file for the requested split is already present it is returned directly.
    Otherwise the archive is fetched and unzipped, both splits are converted to HDF5
    and the extracted folder is removed again.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to download, either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        AssertionError: If the extracted archive does not contain the folder "Lucchi++"
            or if the file for the requested split is missing after the conversion.
    """
    data_path = os.path.join(path, f"lucchi_{split}.h5")
    if os.path.exists(data_path):
        return data_path

    # The folder may already exist although the converted file is missing
    # (e.g. it was created by the caller or left behind by an interrupted run),
    # so an existing folder must not be treated as an error.
    os.makedirs(path, exist_ok=True)
    tmp_path = os.path.join(path, "lucchi.zip")
    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
    util.unzip(tmp_path, path, remove=True)

    root = os.path.join(path, "Lucchi++")
    assert os.path.exists(root), root

    # NOTE(review): this pairing writes the "Test_*" folders to lucchi_train.h5 and the
    # "Train_*" folders to lucchi_test.h5. This looks inverted; confirm whether it is
    # intentional before changing it, because existing converted files depend on it.
    inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]]
    outputs = ["lucchi_train.h5", "lucchi_test.h5"]
    for inp, out in zip(inputs, outputs):
        out_path = os.path.join(path, out)
        _create_data(root, inp, out_path)
    rmtree(root)

    assert os.path.exists(data_path), data_path
    return data_path
Download the Lucchi dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The split to download, either 'train' or 'test'.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the downloaded data.
def
get_lucchi_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_lucchi_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the segmentation dataset for one split of the lucchi data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")
    # A single HDF5 file holds both the raw data ("raw") and the labels ("labels"),
    # so the same filepath is passed for the raw and the label source.
    volume_path = get_lucchi_data(path, split, download)
    return torch_em.default_segmentation_dataset(
        volume_path, "raw", volume_path, "labels", patch_shape, **kwargs
    )
Get dataset for EM mitochondrion segmentation in the Lucchi dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'test'.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset
.
Returns:
The segmentation dataset.
def
get_lucchi_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], batch_size: int, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_lucchi_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the PyTorch DataLoader for one split of the lucchi data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`
            or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments for the dataset from the ones for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_lucchi_dataset(path, split, patch_shape, download=download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs
    )
Get dataloader for EM mitochondrion segmentation in the Lucchi dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'test'.
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset
or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.