torch_em.data.datasets.electron_microscopy.kasthuri

The Kasthuri dataset is a dataset for mitochondrion segmentation in electron microscopy.

The dataset was published in https://doi.org/10.48550/arXiv.1812.06024. Please cite this publication if you use the dataset in your research. We use the version of the dataset from https://sites.google.com/view/connectomics/.

  1"""The Kasthuri dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy.
  2
  3The dataset was published in https://doi.org/10.48550/arXiv.1812.06024.
  4Please cite this publication if you use the dataset in your research.
  5We use the version of the dataset from https://sites.google.com/view/connectomics/.
  6"""
  7
  8import os
  9from glob import glob
 10from tqdm import tqdm
 11from shutil import rmtree
 12from concurrent import futures
 13from typing import Tuple, Union
 14
 15import imageio
 16import numpy as np
 17
 18import torch_em
 19
 20from torch.utils.data import Dataset, DataLoader
 21
 22from .. import util
 23
 24
 25URL = "http://www.casser.io/files/kasthuri_pp.zip "
 26CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792"
 27
 28# TODO: add sampler for foreground (-1 is empty area)
 29# TODO: and masking for the empty space
 30
 31
 32def _load_volume(path):
 33    files = glob(os.path.join(path, "*.png"))
 34    files.sort()
 35    nz = len(files)
 36
 37    im0 = imageio.imread(files[0])
 38    out = np.zeros((nz,) + im0.shape, dtype=im0.dtype)
 39    out[0] = im0
 40
 41    def _loadz(z):
 42        im = imageio.imread(files[z])
 43        out[z] = im
 44
 45    n_threads = 8
 46    with futures.ThreadPoolExecutor(n_threads) as tp:
 47        list(tqdm(
 48            tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1
 49        ))
 50
 51    return out
 52
 53
 54def _create_data(root, inputs, out_path):
 55    import h5py
 56
 57    raw = _load_volume(os.path.join(root, inputs[0]))
 58    labels_argb = _load_volume(os.path.join(root, inputs[1]))
 59    assert labels_argb.ndim == 4
 60    labels = np.zeros(raw.shape, dtype="int8")
 61
 62    fg_mask = (labels_argb == np.array([255, 255, 255])[None, None, None]).all(axis=-1)
 63    labels[fg_mask] = 1
 64    bg_mask = (labels_argb == np.array([2, 2, 2])[None, None, None]).all(axis=-1)
 65    labels[bg_mask] = -1
 66    assert (np.unique(labels) == np.array([-1, 0, 1])).all()
 67    assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}"
 68    with h5py.File(out_path, "w") as f:
 69        f.create_dataset("raw", data=raw, compression="gzip")
 70        f.create_dataset("labels", data=labels, compression="gzip")
 71
 72
 73def get_kasthuri_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 74    """Download the kasthuri dataset.
 75
 76    Args:
 77        path: Filepath to a folder where the downloaded data will be saved.
 78        download: Whether to download the data if it is not present.
 79
 80    Returns:
 81        The filepath for the downloaded data.
 82    """
 83    if os.path.exists(path):
 84        return path
 85
 86    os.makedirs(path)
 87    tmp_path = os.path.join(path, "kasthuri.zip")
 88    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
 89    util.unzip(tmp_path, path, remove=True)
 90
 91    root = os.path.join(path, "Kasthuri++")
 92    assert os.path.exists(root), root
 93
 94    inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]]
 95    outputs = ["kasthuri_train.h5", "kasthuri_test.h5"]
 96    for inp, out in zip(inputs, outputs):
 97        out_path = os.path.join(path, out)
 98        _create_data(root, inp, out_path)
 99
100    rmtree(root)
101    return path
102
103
def get_kasthuri_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> str:
    """Return the path of the HDF5 file that holds one split of the Kasthuri data.

    The data is fetched and converted via `get_kasthuri_data` first, if needed.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the stored data.
    """
    get_kasthuri_data(path, download)
    split_file = os.path.join(path, f"kasthuri_{split}.h5")
    # Guard against unknown split names or an incomplete conversion.
    assert os.path.exists(split_file), split_file
    return split_file
119
120
def get_kasthuri_dataset(
    path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs
) -> Dataset:
    """Build a segmentation dataset for EM mitochondrion segmentation on the Kasthuri data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")

    # Image data and labels live in the same HDF5 file, under different keys.
    h5_path = get_kasthuri_paths(path, split, download)
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=h5_path,
        raw_key="raw",
        label_paths=h5_path,
        label_key="labels",
        patch_shape=patch_shape,
        **kwargs,
    )
    return dataset
148
149
def get_kasthuri_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader for EM mitochondrion segmentation on the Kasthuri data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Route each extra keyword argument either to the dataset or to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_kasthuri_dataset(path, split, patch_shape, download=download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )
URL = 'http://www.casser.io/files/kasthuri_pp.zip'
CHECKSUM = 'bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792'
def get_kasthuri_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 74def get_kasthuri_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 75    """Download the kasthuri dataset.
 76
 77    Args:
 78        path: Filepath to a folder where the downloaded data will be saved.
 79        download: Whether to download the data if it is not present.
 80
 81    Returns:
 82        The filepath for the downloaded data.
 83    """
 84    if os.path.exists(path):
 85        return path
 86
 87    os.makedirs(path)
 88    tmp_path = os.path.join(path, "kasthuri.zip")
 89    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
 90    util.unzip(tmp_path, path, remove=True)
 91
 92    root = os.path.join(path, "Kasthuri++")
 93    assert os.path.exists(root), root
 94
 95    inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]]
 96    outputs = ["kasthuri_train.h5", "kasthuri_test.h5"]
 97    for inp, out in zip(inputs, outputs):
 98        out_path = os.path.join(path, out)
 99        _create_data(root, inp, out_path)
100
101    rmtree(root)
102    return path

Download the Kasthuri dataset and convert it to HDF5 files.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • download: Whether to download the data if it is not present.
Returns:

The filepath for the downloaded data.

def get_kasthuri_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> str:
def get_kasthuri_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> str:
    """Return the path of the HDF5 file that holds one split of the Kasthuri data.

    The data is fetched and converted via `get_kasthuri_data` first, if needed.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the stored data.
    """
    get_kasthuri_data(path, download)
    split_file = os.path.join(path, f"kasthuri_{split}.h5")
    # Guard against unknown split names or an incomplete conversion.
    assert os.path.exists(split_file), split_file
    return split_file

Get paths to the Kasthuri data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:

The filepath to the stored data.

def get_kasthuri_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_kasthuri_dataset(
    path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs
) -> Dataset:
    """Build a segmentation dataset for EM mitochondrion segmentation on the Kasthuri data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")

    # Image data and labels live in the same HDF5 file, under different keys.
    h5_path = get_kasthuri_paths(path, split, download)
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=h5_path,
        raw_key="raw",
        label_paths=h5_path,
        label_key="labels",
        patch_shape=patch_shape,
        **kwargs,
    )
    return dataset

Get the dataset for EM mitochondrion segmentation in the Kasthuri dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'test'.
  • patch_shape: The patch shape to use for training.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_kasthuri_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], batch_size: int, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_kasthuri_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader for EM mitochondrion segmentation on the Kasthuri data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Route each extra keyword argument either to the dataset or to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_kasthuri_dataset(path, split, patch_shape, download=download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )

Get the dataloader for EM mitochondrion segmentation in the Kasthuri dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'test'.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.