torch_em.data.datasets.electron_microscopy.kasthuri
The Kasthuri dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy.
The dataset was published in https://doi.org/10.48550/arXiv.1812.06024. Please cite this publication if you use the dataset in your research. We use the version of the dataset from https://sites.google.com/view/connectomics/.
"""The Kasthuri dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy.

The dataset was published in https://doi.org/10.48550/arXiv.1812.06024.
Please cite this publication if you use the dataset in your research.
We use the version of the dataset from https://sites.google.com/view/connectomics/.
"""

import os
from concurrent import futures
from glob import glob
from shutil import rmtree
from typing import Tuple, Union

import imageio
import h5py
import numpy as np
import torch_em

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from .. import util

# Location of the zipped dataset. The literal must not carry surrounding whitespace,
# a blank after ".zip" would become part of the requested URL.
URL = "http://www.casser.io/files/kasthuri_pp.zip"
# Expected checksum of the downloaded archive, passed on to `util.download_source`.
CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792"

# TODO: add sampler for foreground (-1 is empty area)
# TODO: and masking for the empty space


def _load_volume(path):
    """Stack all PNG images of a folder into a single array.

    The images are ordered by sorting their file paths, the first axis of the
    result runs over the images.

    Args:
        path: Folder containing the `*.png` images.

    Returns:
        Array of shape `(number of images,) + shape of the first image`
        with the dtype of the first image.
    """
    files = sorted(glob(os.path.join(path, "*.png")))
    nz = len(files)

    # The first image determines shape and dtype of the output volume.
    im0 = imageio.imread(files[0])
    out = np.zeros((nz,) + im0.shape, dtype=im0.dtype)
    out[0] = im0

    def _loadz(z):
        out[z] = imageio.imread(files[z])

    # The remaining images are read by a thread pool, each worker writes its own slice.
    n_threads = 8
    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz - 1))

    return out


def _create_data(root, inputs, out_path):
    """Convert the raw and label image folders of one split into one HDF5 file.

    Args:
        root: Folder that contains the image folders.
        inputs: Two folder names below `root`: first the raw images, then the label images.
        out_path: Path of the HDF5 file to create. It receives the datasets "raw" and "labels".
    """
    raw = _load_volume(os.path.join(root, inputs[0]))
    labels_argb = _load_volume(os.path.join(root, inputs[1]))
    # The label images are expected to be color images, i.e. to have a trailing channel axis.
    assert labels_argb.ndim == 4
    labels = np.zeros(raw.shape, dtype="int8")

    # Pixels with color (255, 255, 255) are mapped to label 1.
    fg_mask = (labels_argb == np.array([255, 255, 255])[None, None, None]).all(axis=-1)
    labels[fg_mask] = 1
    # Pixels with color (2, 2, 2) are mapped to label -1 (the empty area, see the TODOs above).
    empty_mask = (labels_argb == np.array([2, 2, 2])[None, None, None]).all(axis=-1)
    labels[empty_mask] = -1

    # `array_equal` also fails cleanly when one of the three values is missing,
    # where an element-wise comparison of arrays with different lengths would not.
    assert np.array_equal(np.unique(labels), np.array([-1, 0, 1]))
    assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}"

    with h5py.File(out_path, "w") as f:
        f.create_dataset("raw", data=raw, compression="gzip")
        f.create_dataset("labels", data=labels, compression="gzip")


def get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str:
    """Download the kasthuri dataset.

    The archive is unpacked and each split is converted into one HDF5 file
    (`kasthuri_train.h5` and `kasthuri_test.h5`) inside `path`.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    # Raw and label image folders of each split inside the unpacked archive.
    # NOTE: files written while the train folders were paired with the test file
    # (and vice versa) hold swapped splits; delete them to have them regenerated.
    split_folders = {
        "train": ["Train_In", "Train_Out"],
        "test": ["Test_In", "Test_Out"],
    }
    out_paths = {split: os.path.join(path, f"kasthuri_{split}.h5") for split in split_folders}

    # Check for the converted files rather than for the folder: a folder left behind
    # by a failed attempt must not be mistaken for a complete dataset.
    if all(os.path.exists(out_path) for out_path in out_paths.values()):
        return path

    os.makedirs(path, exist_ok=True)
    tmp_path = os.path.join(path, "kasthuri.zip")
    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
    util.unzip(tmp_path, path, remove=True)

    root = os.path.join(path, "Kasthuri++")
    assert os.path.exists(root), root

    for split, folders in split_folders.items():
        _create_data(root, folders, out_paths[split])

    # The unpacked images are no longer needed once the HDF5 files are written.
    rmtree(root)
    return path


def get_kasthuri_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get dataset for EM mitochondrion segmentation in the kasthuri dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")
    get_kasthuri_data(path, download)
    data_path = os.path.join(path, f"kasthuri_{split}.h5")
    assert os.path.exists(data_path), data_path
    # Raw data and labels are stored in the same file under different keys.
    raw_key, label_key = "raw", "labels"
    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)


def get_kasthuri_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get dataloader for EM mitochondrion segmentation in the kasthuri dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments of the dataset from those of the data loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(
        torch_em.default_segmentation_dataset, **kwargs
    )
    dataset = get_kasthuri_dataset(path, split, patch_shape, download=download, **ds_kwargs)
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
# Location of the zipped dataset. The literal must not carry surrounding whitespace,
# a blank after ".zip" would become part of the requested URL.
URL = "http://www.casser.io/files/kasthuri_pp.zip"
# Expected checksum of the downloaded archive (64 hex digits, presumably SHA-256 -
# confirm against the download helper that consumes it).
CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792"
def
get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str:
def get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str:
    """Download the kasthuri dataset.

    The archive is unpacked and each split is converted into one HDF5 file
    (`kasthuri_train.h5` and `kasthuri_test.h5`) inside `path`.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    # Raw and label image folders of each split inside the unpacked archive.
    # NOTE: files written while the train folders were paired with the test file
    # (and vice versa) hold swapped splits; delete them to have them regenerated.
    split_folders = {
        "train": ["Train_In", "Train_Out"],
        "test": ["Test_In", "Test_Out"],
    }
    out_paths = {split: os.path.join(path, f"kasthuri_{split}.h5") for split in split_folders}

    # Check for the converted files rather than for the folder: a folder left behind
    # by a failed attempt must not be mistaken for a complete dataset.
    if all(os.path.exists(out_path) for out_path in out_paths.values()):
        return path

    os.makedirs(path, exist_ok=True)
    tmp_path = os.path.join(path, "kasthuri.zip")
    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
    util.unzip(tmp_path, path, remove=True)

    root = os.path.join(path, "Kasthuri++")
    assert os.path.exists(root), root

    for split, folders in split_folders.items():
        _create_data(root, folders, out_paths[split])

    # The unpacked images are no longer needed once the HDF5 files are written.
    rmtree(root)
    return path
Download the kasthuri dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the downloaded data.
def
get_kasthuri_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_kasthuri_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    download: bool = False,
    **kwargs
) -> Dataset:
    """Create the segmentation dataset for one split of the kasthuri data.

    Args:
        path: Folder in which the data is stored; it is passed to `get_kasthuri_data`.
        split: Name of the split, must be 'train' or 'test'.
        patch_shape: The patch shape handed to `torch_em.default_segmentation_dataset`.
        download: Whether the data may be downloaded if it is not present.
        kwargs: Further keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")

    # Make sure the data is available before looking up the file of the split.
    get_kasthuri_data(path, download)

    volume_path = os.path.join(path, f"kasthuri_{split}.h5")
    assert os.path.exists(volume_path), volume_path

    # Raw data and labels are read from the same file, under the keys "raw" and "labels".
    return torch_em.default_segmentation_dataset(
        volume_path, "raw", volume_path, "labels", patch_shape, **kwargs
    )
Get dataset for EM mitochondrion segmentation in the kasthuri dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'test'.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_kasthuri_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], batch_size: int, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_kasthuri_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Create the data loader for one split of the kasthuri data.

    Args:
        path: Folder in which the data is stored.
        split: Name of the split, either 'train' or 'test'.
        patch_shape: The patch shape used by the dataset.
        batch_size: The batch size of the data loader.
        download: Whether the data may be downloaded if it is not present.
        kwargs: Keyword arguments for `torch_em.default_segmentation_dataset`;
            all remaining ones are forwarded to `torch_em.get_data_loader`.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the dataset arguments from the arguments meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    kasthuri_dataset = get_kasthuri_dataset(path, split, patch_shape, download=download, **dataset_kwargs)
    return torch_em.get_data_loader(kasthuri_dataset, batch_size, **dataloader_kwargs)
Get dataloader for EM mitochondrion segmentation in the kasthuri dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'test'.
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.