torch_em.data.datasets.electron_microscopy.kasthuri

The Kasthuri dataset is a dataset for mitochondrion segmentation in electron microscopy.

The dataset was published in https://doi.org/10.48550/arXiv.1812.06024. Please cite this publication if you use the dataset in your research. We use the version of the dataset from https://sites.google.com/view/connectomics/.

  1"""The Kasthuri dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy.
  2
  3The dataset was published in https://doi.org/10.48550/arXiv.1812.06024.
  4Please cite this publication if you use the dataset in your research.
  5We use the version of the dataset from https://sites.google.com/view/connectomics/.
  6"""
  7
  8import os
  9from concurrent import futures
 10from glob import glob
 11from shutil import rmtree
 12from typing import Tuple, Union
 13
 14import imageio
 15import h5py
 16import numpy as np
 17import torch_em
 18
 19from torch.utils.data import Dataset, DataLoader
 20from tqdm import tqdm
 21from .. import util
 22
 23URL = "http://www.casser.io/files/kasthuri_pp.zip "
 24CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792"
 25
 26# TODO: add sampler for foreground (-1 is empty area)
 27# TODO: and masking for the empty space
 28
 29
 30def _load_volume(path):
 31    files = glob(os.path.join(path, "*.png"))
 32    files.sort()
 33    nz = len(files)
 34
 35    im0 = imageio.imread(files[0])
 36    out = np.zeros((nz,) + im0.shape, dtype=im0.dtype)
 37    out[0] = im0
 38
 39    def _loadz(z):
 40        im = imageio.imread(files[z])
 41        out[z] = im
 42
 43    n_threads = 8
 44    with futures.ThreadPoolExecutor(n_threads) as tp:
 45        list(tqdm(
 46            tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1
 47        ))
 48
 49    return out
 50
 51
 52def _create_data(root, inputs, out_path):
 53    raw = _load_volume(os.path.join(root, inputs[0]))
 54    labels_argb = _load_volume(os.path.join(root, inputs[1]))
 55    assert labels_argb.ndim == 4
 56    labels = np.zeros(raw.shape, dtype="int8")
 57
 58    fg_mask = (labels_argb == np.array([255, 255, 255])[None, None, None]).all(axis=-1)
 59    labels[fg_mask] = 1
 60    bg_mask = (labels_argb == np.array([2, 2, 2])[None, None, None]).all(axis=-1)
 61    labels[bg_mask] = -1
 62    assert (np.unique(labels) == np.array([-1, 0, 1])).all()
 63    assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}"
 64    with h5py.File(out_path, "w") as f:
 65        f.create_dataset("raw", data=raw, compression="gzip")
 66        f.create_dataset("labels", data=labels, compression="gzip")
 67
 68
 69def get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str:
 70    """Download the kasthuri dataset.
 71
 72    Args:
 73        path: Filepath to a folder where the downloaded data will be saved.
 74        download: Whether to download the data if it is not present.
 75
 76    Returns:
 77        The filepath for the downloaded data.
 78    """
 79    if os.path.exists(path):
 80        return path
 81
 82    os.makedirs(path)
 83    tmp_path = os.path.join(path, "kasthuri.zip")
 84    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
 85    util.unzip(tmp_path, path, remove=True)
 86
 87    root = os.path.join(path, "Kasthuri++")
 88    assert os.path.exists(root), root
 89
 90    inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]]
 91    outputs = ["kasthuri_train.h5", "kasthuri_test.h5"]
 92    for inp, out in zip(inputs, outputs):
 93        out_path = os.path.join(path, out)
 94        _create_data(root, inp, out_path)
 95
 96    rmtree(root)
 97    return path
 98
 99
def get_kasthuri_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the dataset for mitochondrion segmentation in the Kasthuri EM data.

    Calls `get_kasthuri_data` first so that the HDF5 file of the requested split is available.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")
    get_kasthuri_data(path, download)

    # Raw data and labels live in the same HDF5 file, under different keys.
    h5_path = os.path.join(path, f"kasthuri_{split}.h5")
    assert os.path.exists(h5_path), h5_path
    return torch_em.default_segmentation_dataset(
        h5_path, "raw", h5_path, "labels", patch_shape, **kwargs
    )
125
126
def get_kasthuri_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for mitochondrion segmentation in the Kasthuri EM data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_kasthuri_dataset(path, split, patch_shape, download=download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )
URL = 'http://www.casser.io/files/kasthuri_pp.zip'
CHECKSUM = 'bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792'
def get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str:
def get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str:
    """Download the Kasthuri dataset and convert it to HDF5 files.

    The converted data is stored as "kasthuri_train.h5" and "kasthuri_test.h5" in `path`.
    Download and conversion are skipped if both files already exist.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    # Each output file is built from the (raw, label) image folders of the matching split.
    # NOTE: earlier versions wrote the Test folders to "kasthuri_train.h5" and vice versa;
    # delete existing files to regenerate them with the correct assignment.
    splits = {
        "kasthuri_train.h5": ["Train_In", "Train_Out"],
        "kasthuri_test.h5": ["Test_In", "Test_Out"],
    }

    # Check for the converted files rather than the folder itself: the folder may exist
    # without the data, e.g. if an earlier call failed after creating it.
    if all(os.path.exists(os.path.join(path, out)) for out in splits):
        return path

    os.makedirs(path, exist_ok=True)
    tmp_path = os.path.join(path, "kasthuri.zip")
    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
    util.unzip(tmp_path, path, remove=True)

    root = os.path.join(path, "Kasthuri++")
    assert os.path.exists(root), root

    for out, inputs in splits.items():
        _create_data(root, inputs, os.path.join(path, out))

    # The extracted png stacks are no longer needed once the HDF5 files are written.
    rmtree(root)
    return path

Download the Kasthuri dataset and convert it to HDF5 files.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • download: Whether to download the data if it is not present.
Returns:

The filepath for the downloaded data.

def get_kasthuri_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_kasthuri_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the dataset for mitochondrion segmentation in the Kasthuri EM data.

    Calls `get_kasthuri_data` first so that the HDF5 file of the requested split is available.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "test")
    get_kasthuri_data(path, download)

    # Raw data and labels live in the same HDF5 file, under different keys.
    h5_path = os.path.join(path, f"kasthuri_{split}.h5")
    assert os.path.exists(h5_path), h5_path
    return torch_em.default_segmentation_dataset(
        h5_path, "raw", h5_path, "labels", patch_shape, **kwargs
    )

Get the dataset for EM mitochondrion segmentation in the Kasthuri dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'test'.
  • patch_shape: The patch shape to use for training.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_kasthuri_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], batch_size: int, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_kasthuri_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for mitochondrion segmentation in the Kasthuri EM data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_kasthuri_dataset(path, split, patch_shape, download=download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )

Get the dataloader for EM mitochondrion segmentation in the Kasthuri dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'test'.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.