torch_em.data.datasets.electron_microscopy.axondeepseg

AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM. It contains two different data types: TEM and SEM. The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. Please cite this publication if you use the dataset in your research.

  1"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
  2It contains two different data types: TEM and SEM.
  3The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
  4Please cite this publication if you use the dataset in your research.
  5"""
  6
  7import os
  8from glob import glob
  9from shutil import rmtree
 10from typing import Optional, Tuple, Union
 11
 12import imageio
 13import h5py
 14import numpy as np
 15import torch_em
 16
 17from torch.utils.data import Dataset, DataLoader
 18from .. import util
 19
 20URLS = {
 21    "sem": "https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip",
 22    "tem": "https://osf.io/download/uewd9"
 23}
 24CHECKSUMS = {
 25    "sem": "d334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426",
 26    "tem": "e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95"
 27}
 28
 29
 30def _preprocess_sem_data(out_path):
 31    # preprocess the data to get it to a better data format
 32    data_root = os.path.join(out_path, "data_axondeepseg_sem-master")
 33    assert os.path.exists(data_root)
 34
 35    # get the raw data paths
 36    raw_folders = glob(os.path.join(data_root, "sub-rat*"))
 37    raw_folders.sort()
 38    raw_paths = []
 39    for folder in raw_folders:
 40        paths = glob(os.path.join(folder, "micr", "*.png"))
 41        paths.sort()
 42        raw_paths.extend(paths)
 43
 44    # get the label paths
 45    label_folders = glob(os.path.join(
 46        data_root, "derivatives", "labels", "sub-rat*"
 47    ))
 48    label_folders.sort()
 49    label_paths = []
 50    for folder in label_folders:
 51        paths = glob(os.path.join(folder, "micr", "*axonmyelin-manual.png"))
 52        paths.sort()
 53        label_paths.extend(paths)
 54    assert len(raw_paths) == len(label_paths), f"{len(raw_paths)}, {len(label_paths)}"
 55
 56    # process raw data and labels
 57    for i, (rp, lp) in enumerate(zip(raw_paths, label_paths)):
 58        outp = os.path.join(out_path, f"sem_data_{i}.h5")
 59        with h5py.File(outp, "w") as f:
 60
 61            # raw data: invert to match tem em intensities
 62            raw = imageio.imread(rp)
 63            assert np.dtype(raw.dtype) == np.dtype("uint8")
 64            if raw.ndim == 3:  # (one of the images is RGBA)
 65                raw = np.mean(raw[..., :-3], axis=-1)
 66            raw = 255 - raw
 67            f.create_dataset("raw", data=raw, compression="gzip")
 68
 69            # labels: map from
 70            # 0 -> 0
 71            # 127, 128 -> 1
 72            # 255 -> 2
 73            labels = imageio.imread(lp)
 74            assert labels.shape == raw.shape, f"{labels.shape}, {raw.shape}"
 75            label_vals = np.unique(labels)
 76            # 127, 128: both myelin labels, 130, 233: noise
 77            assert len(np.setdiff1d(label_vals, [0, 127, 128, 130, 233, 255])) == 0, f"{label_vals}"
 78            new_labels = np.zeros_like(labels)
 79            new_labels[labels == 127] = 1
 80            new_labels[labels == 128] = 1
 81            new_labels[labels == 255] = 2
 82            f.create_dataset("labels", data=new_labels, compression="gzip")
 83
 84    # clean up
 85    rmtree(data_root)
 86
 87
 88def _preprocess_tem_data(out_path):
 89    data_root = os.path.join(out_path, "TEM_dataset")
 90    folder_names = os.listdir(data_root)
 91    folders = [os.path.join(data_root, fname) for fname in folder_names
 92               if os.path.isdir(os.path.join(data_root, fname))]
 93    for i, folder in enumerate(folders):
 94        data_out = os.path.join(out_path, f"tem_{i}.h5")
 95        with h5py.File(data_out, "w") as f:
 96            im = imageio.imread(os.path.join(folder, "image.png"))
 97            f.create_dataset("raw", data=im, compression="gzip")
 98
 99            # labels: map from
100            # 0 -> 0
101            # 128 -> 1
102            # 255 -> 2
103            # the rest are noise
104            labels = imageio.imread(os.path.join(folder, "mask.png"))
105            new_labels = np.zeros_like(labels)
106            new_labels[labels == 128] = 1
107            new_labels[labels == 255] = 2
108            f.create_dataset("labels", data=new_labels, compression="gzip")
109
110    # clean up
111    rmtree(data_root)
112
113
def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
    """Download the axondeepseg data.

    The data is downloaded and preprocessed into ``<path>/<name>``. If that folder already
    exists, it is returned as-is without any further checks.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        ValueError: If `name` is neither 'sem' nor 'tem'.
    """
    # Validate before any side effect. Previously an invalid name failed with a KeyError on the
    # URL lookup, so the ValueError at the end of the function was unreachable.
    if name not in ("sem", "tem"):
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    url, checksum = URLS[name], CHECKSUMS[name]
    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    # Download and unzip the archive; the zip file is removed after extraction.
    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    util.unzip(tmp_path, out_path, remove=True)

    if name == "sem":
        _preprocess_sem_data(out_path)
    else:
        _preprocess_tem_data(out_path)

    return out_path
145
146
def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Dataset:
    """Get dataset for segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem',
            or a list/tuple of these names to combine both datasets.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding. If True, the
            class ids [0, 1, 2] are used. An int n selects the class ids range(n), and a
            list/tuple gives the class ids explicitly.
        val_fraction: The fraction of the data to use for validation. The split is applied
            per dataset, on the sorted file list.
        split: The data split. Either 'train' or 'val'. Required if `val_fraction` is given.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `val_fraction` is given and `split` is not 'train' or 'val', or if
            `one_hot_encoding` has an unsupported type.
    """
    if isinstance(name, str):
        name = [name]
    assert isinstance(name, (tuple, list))

    # Validate before downloading anything. Previously any split other than "train"
    # silently returned the validation subset.
    if val_fraction is not None and split not in ("train", "val"):
        raise ValueError(
            f"Invalid split {split!r}: expected 'train' or 'val' when 'val_fraction' is given."
        )

    all_paths = []
    for nn in name:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    if one_hot_encoding:
        # bool must be checked before int, since bool is a subclass of int.
        if isinstance(one_hot_encoding, bool):
            # Map the label values [0, 1, 2] to a one hot encoding.
            class_ids = [0, 1, 2]
        elif isinstance(one_hot_encoding, int):
            class_ids = list(range(one_hot_encoding))
        elif isinstance(one_hot_encoding, (list, tuple)):
            class_ids = list(one_hot_encoding)
        else:
            raise ValueError(
                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
            )
        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    raw_key, label_key = "raw", "labels"
    return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs)
205
206
def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader for segmenting myelinated axons in AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Divide kwargs into dataset kwargs and loader kwargs, presumably based on the
    # signature of torch_em.default_segmentation_dataset.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    axon_dataset = get_axondeepseg_dataset(
        path,
        name,
        patch_shape,
        download=download,
        one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction,
        split=split,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(axon_dataset, batch_size, **dataloader_kwargs)
# Download locations of the two AxonDeepSeg modalities:
# "sem" is a GitHub archive of the SEM data repository, "tem" an OSF download.
URLS = dict(
    sem="https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip",
    tem="https://osf.io/download/uewd9",
)
# Checksums passed to `util.download_source` to verify each download
# (64 hex digits, presumably SHA-256 — confirm against util.download_source).
CHECKSUMS = dict(
    sem="d334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426",
    tem="e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95",
)
def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
    """Download the axondeepseg data.

    The data is downloaded and preprocessed into ``<path>/<name>``. If that folder already
    exists, it is returned as-is without any further checks.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        ValueError: If `name` is neither 'sem' nor 'tem'.
    """
    # Validate before any side effect. Previously an invalid name failed with a KeyError on the
    # URL lookup, so the ValueError at the end of the function was unreachable.
    if name not in ("sem", "tem"):
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    url, checksum = URLS[name], CHECKSUMS[name]
    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    # Download and unzip the archive; the zip file is removed after extraction.
    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    util.unzip(tmp_path, out_path, remove=True)

    if name == "sem":
        _preprocess_sem_data(out_path)
    else:
        _preprocess_tem_data(out_path)

    return out_path

Download the axondeepseg data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • download: Whether to download the data if it is not present.
Returns:

The filepath for the downloaded data.

def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Dataset:
    """Get dataset for segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem',
            or a list/tuple of these names to combine both datasets.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding. If True, the
            class ids [0, 1, 2] are used. An int n selects the class ids range(n), and a
            list/tuple gives the class ids explicitly.
        val_fraction: The fraction of the data to use for validation. The split is applied
            per dataset, on the sorted file list.
        split: The data split. Either 'train' or 'val'. Required if `val_fraction` is given.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `val_fraction` is given and `split` is not 'train' or 'val', or if
            `one_hot_encoding` has an unsupported type.
    """
    if isinstance(name, str):
        name = [name]
    assert isinstance(name, (tuple, list))

    # Validate before downloading anything. Previously any split other than "train"
    # silently returned the validation subset.
    if val_fraction is not None and split not in ("train", "val"):
        raise ValueError(
            f"Invalid split {split!r}: expected 'train' or 'val' when 'val_fraction' is given."
        )

    all_paths = []
    for nn in name:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    if one_hot_encoding:
        # bool must be checked before int, since bool is a subclass of int.
        if isinstance(one_hot_encoding, bool):
            # Map the label values [0, 1, 2] to a one hot encoding.
            class_ids = [0, 1, 2]
        elif isinstance(one_hot_encoding, int):
            class_ids = list(range(one_hot_encoding))
        elif isinstance(one_hot_encoding, (list, tuple)):
            class_ids = list(one_hot_encoding)
        else:
            raise ValueError(
                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
            )
        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    raw_key, label_key = "raw", "labels"
    return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs)

Get dataset for segmentation of myelinated axons.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • patch_shape: The patch shape to use for training.
  • download: Whether to download the data if it is not present.
  • one_hot_encoding: Whether to return the labels as one hot encoding.
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader for segmenting myelinated axons in AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Divide kwargs into dataset kwargs and loader kwargs, presumably based on the
    # signature of torch_em.default_segmentation_dataset.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    axon_dataset = get_axondeepseg_dataset(
        path,
        name,
        patch_shape,
        download=download,
        one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction,
        split=split,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(axon_dataset, batch_size, **dataloader_kwargs)

Get dataloader for the segmentation of myelinated axons.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • download: Whether to download the data if it is not present.
  • one_hot_encoding: Whether to return the labels as one hot encoding.
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.