torch_em.data.datasets.electron_microscopy.axondeepseg

AxonDeepSeg is a dataset for the segmentation of myelinated axons in electron microscopy (EM) images. It contains two imaging modalities: transmission EM (TEM) and scanning EM (SEM).

The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. Please cite this publication if you use the dataset in your research.

  1"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
  2It contains two different data types: TEM and SEM.
  3
  4The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
  5Please cite this publication if you use the dataset in your research.
  6"""
  7
  8import os
  9from glob import glob
 10from shutil import rmtree
 11from typing import Optional, Tuple, Union, Literal, List
 12
 13import imageio
 14import numpy as np
 15
 16from torch.utils.data import Dataset, DataLoader
 17
 18import torch_em
 19
 20from .. import util
 21
 22URLS = {
 23    "sem": "https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip",
 24    "tem": "https://osf.io/download/uewd9"
 25}
 26CHECKSUMS = {
 27    "sem": "d334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426",
 28    "tem": "e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95"
 29}
 30
 31
 32def _preprocess_sem_data(out_path):
 33    import h5py
 34
 35    # preprocess the data to get it to a better data format
 36    data_root = os.path.join(out_path, "data_axondeepseg_sem-master")
 37    assert os.path.exists(data_root)
 38
 39    # get the raw data paths
 40    raw_folders = glob(os.path.join(data_root, "sub-rat*"))
 41    raw_folders.sort()
 42    raw_paths = []
 43    for folder in raw_folders:
 44        paths = glob(os.path.join(folder, "micr", "*.png"))
 45        paths.sort()
 46        raw_paths.extend(paths)
 47
 48    # get the label paths
 49    label_folders = glob(os.path.join(
 50        data_root, "derivatives", "labels", "sub-rat*"
 51    ))
 52    label_folders.sort()
 53    label_paths = []
 54    for folder in label_folders:
 55        paths = glob(os.path.join(folder, "micr", "*axonmyelin-manual.png"))
 56        paths.sort()
 57        label_paths.extend(paths)
 58    assert len(raw_paths) == len(label_paths), f"{len(raw_paths)}, {len(label_paths)}"
 59
 60    # process raw data and labels
 61    for i, (rp, lp) in enumerate(zip(raw_paths, label_paths)):
 62        outp = os.path.join(out_path, f"sem_data_{i}.h5")
 63        with h5py.File(outp, "w") as f:
 64
 65            # raw data: invert to match tem em intensities
 66            raw = imageio.imread(rp)
 67            assert np.dtype(raw.dtype) == np.dtype("uint8")
 68            if raw.ndim == 3:  # (one of the images is RGBA)
 69                raw = np.mean(raw[..., :-3], axis=-1)
 70            raw = 255 - raw
 71            f.create_dataset("raw", data=raw, compression="gzip")
 72
 73            # labels: map from
 74            # 0 -> 0
 75            # 127, 128 -> 1
 76            # 255 -> 2
 77            labels = imageio.imread(lp)
 78            assert labels.shape == raw.shape, f"{labels.shape}, {raw.shape}"
 79            label_vals = np.unique(labels)
 80            # 127, 128: both myelin labels, 130, 233: noise
 81            assert len(np.setdiff1d(label_vals, [0, 127, 128, 130, 233, 255])) == 0, f"{label_vals}"
 82            new_labels = np.zeros_like(labels)
 83            new_labels[labels == 127] = 1
 84            new_labels[labels == 128] = 1
 85            new_labels[labels == 255] = 2
 86            f.create_dataset("labels", data=new_labels, compression="gzip")
 87
 88    # clean up
 89    rmtree(data_root)
 90
 91
 92def _preprocess_tem_data(out_path):
 93    import h5py
 94
 95    data_root = os.path.join(out_path, "TEM_dataset")
 96    folder_names = os.listdir(data_root)
 97    folders = [os.path.join(data_root, fname) for fname in folder_names
 98               if os.path.isdir(os.path.join(data_root, fname))]
 99    for i, folder in enumerate(folders):
100        data_out = os.path.join(out_path, f"tem_{i}.h5")
101        with h5py.File(data_out, "w") as f:
102            im = imageio.imread(os.path.join(folder, "image.png"))
103            f.create_dataset("raw", data=im, compression="gzip")
104
105            # labels: map from
106            # 0 -> 0
107            # 128 -> 1
108            # 255 -> 2
109            # the rest are noise
110            labels = imageio.imread(os.path.join(folder, "mask.png"))
111            new_labels = np.zeros_like(labels)
112            new_labels[labels == 128] = 1
113            new_labels[labels == 255] = 2
114            f.create_dataset("labels", data=new_labels, compression="gzip")
115
116    # clean up
117    rmtree(data_root)
118
119
def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
    """Download the AxonDeepSeg data.

    The archive is downloaded to `path`, unzipped into `path/name` and converted to
    HDF5 files there. If `path/name` already exists it is returned without further checks.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the preprocessed data.

    Raises:
        ValueError: If `name` is neither 'sem' nor 'tem'.
    """
    # Validate before indexing URLS / CHECKSUMS, so that an unknown name raises the
    # intended ValueError instead of a KeyError.
    if name not in ("sem", "tem"):
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    # download and unzip the data
    url, checksum = URLS[name], CHECKSUMS[name]
    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    util.unzip(tmp_path, out_path, remove=True)

    if name == "sem":
        _preprocess_sem_data(out_path)
    else:
        _preprocess_tem_data(out_path)

    return out_path
151
152
def get_axondeepseg_paths(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    download: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
) -> List[str]:
    """Get paths to the AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem',
            or a list / tuple containing one or both of them.
        download: Whether to download the data if it is not present.
        val_fraction: The fraction of the data to use for validation. For each dataset
            the first (1 - val_fraction) part of the sorted files is used for 'train',
            the rest for 'val'.
        split: The data split. Either 'train' or 'val'. Required if `val_fraction` is given.

    Returns:
        List of paths to the preprocessed HDF5 files.

    Raises:
        ValueError: If `val_fraction` is given but `split` is not 'train' or 'val'.
    """
    # A single name must be wrapped: iterating over a plain string yields its characters.
    names = [name] if isinstance(name, str) else list(name)
    # Any value other than 'train' would otherwise silently select the validation split.
    if val_fraction is not None and split not in ("train", "val"):
        raise ValueError(f"Expected split to be 'train' or 'val' when val_fraction is given, got {split}.")

    all_paths = []
    for nn in names:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    return all_paths
184
185
def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal['train', 'val']] = None,
    **kwargs,
) -> Dataset:
    """Build the segmentation dataset for myelinated axons.

    Args:
        path: Folder where the data is downloaded to / read from.
        name: Which data to use: 'sem', 'tem', or a list / tuple of them.
        patch_shape: Shape of the patches sampled for training.
        download: Download the data if it is not yet present.
        one_hot_encoding: Encode the labels one-hot. True encodes the classes 0, 1 and 2,
            an int n encodes the classes 0 .. n - 1, a list / tuple encodes the given class ids.
        val_fraction: Fraction of the data reserved for validation.
        split: Which split to load, 'train' or 'val'.
        kwargs: Extra keyword arguments forwarded to `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    def _class_ids(spec):
        # bool has to be tested before int, since bool is a subclass of int.
        if isinstance(spec, bool):
            return [0, 1, 2]
        if isinstance(spec, int):
            return list(range(spec))
        if isinstance(spec, (list, tuple)):
            return list(spec)
        raise ValueError(
            f"Invalid value {spec} passed for 'one_hot_encoding', expect bool, int or list."
        )

    names = [name] if isinstance(name, str) else name
    assert isinstance(names, (tuple, list))

    all_paths = get_axondeepseg_paths(path, names, download, val_fraction, split)

    if one_hot_encoding:
        one_hot = torch_em.transform.label.OneHotTransform(class_ids=_class_ids(one_hot_encoding))
        kwargs = util.update_kwargs(
            kwargs,
            "label_transform",
            one_hot,
            msg="'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )

    # Raw data and labels are stored in the same HDF5 files, under different keys.
    return torch_em.default_segmentation_dataset(
        raw_paths=all_paths, raw_key="raw",
        label_paths=all_paths, label_key="labels",
        patch_shape=patch_shape, **kwargs,
    )
241
242
def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal["train", "val"]] = None,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader for the segmentation of myelinated axons.

    Args:
        path: Folder where the data is downloaded to / read from.
        name: Which data to use: 'sem' or 'tem' (forwarded to `get_axondeepseg_dataset`).
        patch_shape: Shape of the patches sampled for training.
        batch_size: Number of patches per batch.
        download: Download the data if it is not yet present.
        one_hot_encoding: Return the labels one-hot encoded.
        val_fraction: Fraction of the data reserved for validation.
        split: Which split to load, 'train' or 'val'.
        kwargs: Extra keyword arguments, routed either to
            `torch_em.default_segmentation_dataset` or to the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Keyword arguments accepted by the dataset factory go to the dataset, the rest to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    segmentation_dataset = get_axondeepseg_dataset(
        path,
        name,
        patch_shape,
        download=download,
        one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction,
        split=split,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(segmentation_dataset, batch_size, **dataloader_kwargs)
# Download locations of the two AxonDeepSeg data variants, keyed by dataset name.
URLS = dict(
    sem="https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip",
    tem="https://osf.io/download/uewd9",
)
# Checksums passed to `util.download_source` to verify the files downloaded from `URLS`.
CHECKSUMS = dict(
    sem="d334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426",
    tem="e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95",
)
def get_axondeepseg_data( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False) -> str:
def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
    """Download the AxonDeepSeg data.

    The archive is downloaded to `path`, unzipped into `path/name` and converted to
    HDF5 files there. If `path/name` already exists it is returned without further checks.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the preprocessed data.

    Raises:
        ValueError: If `name` is neither 'sem' nor 'tem'.
    """
    # Validate before indexing URLS / CHECKSUMS, so that an unknown name raises the
    # intended ValueError instead of a KeyError.
    if name not in ("sem", "tem"):
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    # download and unzip the data
    url, checksum = URLS[name], CHECKSUMS[name]
    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    util.unzip(tmp_path, out_path, remove=True)

    if name == "sem":
        _preprocess_sem_data(out_path)
    else:
        _preprocess_tem_data(out_path)

    return out_path

Download the AxonDeepSeg data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • download: Whether to download the data if it is not present.
Returns:

The path to the folder containing the downloaded and preprocessed data.

def get_axondeepseg_paths( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False, val_fraction: Optional[float] = None, split: Optional[str] = None) -> List[str]:
def get_axondeepseg_paths(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    download: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
) -> List[str]:
    """Get paths to the AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem',
            or a list / tuple containing one or both of them.
        download: Whether to download the data if it is not present.
        val_fraction: The fraction of the data to use for validation. For each dataset
            the first (1 - val_fraction) part of the sorted files is used for 'train',
            the rest for 'val'.
        split: The data split. Either 'train' or 'val'. Required if `val_fraction` is given.

    Returns:
        List of paths to the preprocessed HDF5 files.

    Raises:
        ValueError: If `val_fraction` is given but `split` is not 'train' or 'val'.
    """
    # A single name must be wrapped: iterating over a plain string yields its characters.
    names = [name] if isinstance(name, str) else list(name)
    # Any value other than 'train' would otherwise silently select the validation split.
    if val_fraction is not None and split not in ("train", "val"):
        raise ValueError(f"Expected split to be 'train' or 'val' when val_fraction is given, got {split}.")

    all_paths = []
    for nn in names:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    return all_paths

Get paths to the AxonDeepSeg data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • download: Whether to download the data if it is not present.
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
Returns:

List of paths to the preprocessed HDF5 files.

def get_axondeepseg_dataset( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal['train', 'val']] = None,
    **kwargs,
) -> Dataset:
    """Build the segmentation dataset for myelinated axons.

    Args:
        path: Folder where the data is downloaded to / read from.
        name: Which data to use: 'sem', 'tem', or a list / tuple of them.
        patch_shape: Shape of the patches sampled for training.
        download: Download the data if it is not yet present.
        one_hot_encoding: Encode the labels one-hot. True encodes the classes 0, 1 and 2,
            an int n encodes the classes 0 .. n - 1, a list / tuple encodes the given class ids.
        val_fraction: Fraction of the data reserved for validation.
        split: Which split to load, 'train' or 'val'.
        kwargs: Extra keyword arguments forwarded to `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    def _class_ids(spec):
        # bool has to be tested before int, since bool is a subclass of int.
        if isinstance(spec, bool):
            return [0, 1, 2]
        if isinstance(spec, int):
            return list(range(spec))
        if isinstance(spec, (list, tuple)):
            return list(spec)
        raise ValueError(
            f"Invalid value {spec} passed for 'one_hot_encoding', expect bool, int or list."
        )

    names = [name] if isinstance(name, str) else name
    assert isinstance(names, (tuple, list))

    all_paths = get_axondeepseg_paths(path, names, download, val_fraction, split)

    if one_hot_encoding:
        one_hot = torch_em.transform.label.OneHotTransform(class_ids=_class_ids(one_hot_encoding))
        kwargs = util.update_kwargs(
            kwargs,
            "label_transform",
            one_hot,
            msg="'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )

    # Raw data and labels are stored in the same HDF5 files, under different keys.
    return torch_em.default_segmentation_dataset(
        raw_paths=all_paths, raw_key="raw",
        label_paths=all_paths, label_key="labels",
        patch_shape=patch_shape, **kwargs,
    )

Get dataset for segmentation of myelinated axons.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem', or a list containing one or both of them.
  • patch_shape: The patch shape to use for training.
  • download: Whether to download the data if it is not present.
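  • one_hot_encoding: Whether to return the labels as one-hot encoding. If True, the classes 0, 1 and 2 are encoded; an int n encodes the classes 0 to n - 1, and a list encodes the given class ids.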
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_axondeepseg_loader( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], batch_size: int, download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal["train", "val"]] = None,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader for the segmentation of myelinated axons.

    Args:
        path: Folder where the data is downloaded to / read from.
        name: Which data to use: 'sem' or 'tem' (forwarded to `get_axondeepseg_dataset`).
        patch_shape: Shape of the patches sampled for training.
        batch_size: Number of patches per batch.
        download: Download the data if it is not yet present.
        one_hot_encoding: Return the labels one-hot encoded.
        val_fraction: Fraction of the data reserved for validation.
        split: Which split to load, 'train' or 'val'.
        kwargs: Extra keyword arguments, routed either to
            `torch_em.default_segmentation_dataset` or to the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Keyword arguments accepted by the dataset factory go to the dataset, the rest to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    segmentation_dataset = get_axondeepseg_dataset(
        path,
        name,
        patch_shape,
        download=download,
        one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction,
        split=split,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(segmentation_dataset, batch_size, **dataloader_kwargs)

Get dataloader for the segmentation of myelinated axons.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem', or a list containing one or both of them.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • download: Whether to download the data if it is not present.
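  • one_hot_encoding: Whether to return the labels as one-hot encoding. If True, the classes 0, 1 and 2 are encoded; an int n encodes the classes 0 to n - 1, and a list encodes the given class ids.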
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.