torch_em.data.datasets.electron_microscopy.axondeepseg

AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM. It contains two different data types: TEM and SEM.

The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. Please cite this publication if you use the dataset in your research.

  1"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
  2It contains two different data types: TEM and SEM.
  3
  4The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
  5Please cite this publication if you use the dataset in your research.
  6"""
  7
  8import os
  9from glob import glob
 10from shutil import rmtree
 11from typing import Optional, Tuple, Union, Literal, List
 12
 13import imageio
 14import numpy as np
 15
 16from torch.utils.data import Dataset, DataLoader
 17
 18import torch_em
 19
 20from .. import util
 21
# Download URLs for the two supported data modalities.
# "sem" is a zipped GitHub repository snapshot; "tem" is an OSF download link.
URLS = {
    "sem": "https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip",
    "tem": "https://osf.io/download/uewd9"
}
# SHA256 checksums used to verify the downloaded archives.
CHECKSUMS = {
    "sem": "12f2f03834c41720badf00131bb7b7a2127e532cf78e01fbea398e1ff800779b",
    "tem": "e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95"
}
 30
 31
 32def _preprocess_sem_data(out_path):
 33    import h5py
 34
 35    # preprocess the data to get it to a better data format
 36    data_root = os.path.join(out_path, "data_axondeepseg_sem-master")
 37    assert os.path.exists(data_root)
 38
 39    # get the raw data paths
 40    raw_folders = glob(os.path.join(data_root, "sub-rat*"))
 41    raw_folders.sort()
 42    raw_paths = []
 43    for folder in raw_folders:
 44        paths = glob(os.path.join(folder, "micr", "*.png"))
 45        paths.sort()
 46        raw_paths.extend(paths)
 47
 48    # get the label paths
 49    label_folders = glob(os.path.join(
 50        data_root, "derivatives", "labels", "sub-rat*"
 51    ))
 52    label_folders.sort()
 53    label_paths = []
 54    for folder in label_folders:
 55        paths = glob(os.path.join(folder, "micr", "*axonmyelin-manual.png"))
 56        paths.sort()
 57        label_paths.extend(paths)
 58    assert len(raw_paths) == len(label_paths), f"{len(raw_paths)}, {len(label_paths)}"
 59
 60    # process raw data and labels
 61    for i, (rp, lp) in enumerate(zip(raw_paths, label_paths)):
 62        outp = os.path.join(out_path, f"sem_data_{i}.h5")
 63        with h5py.File(outp, "w") as f:
 64
 65            # raw data: invert to match tem em intensities
 66            raw = imageio.imread(rp)
 67            assert np.dtype(raw.dtype) == np.dtype("uint8")
 68            if raw.ndim == 3:  # Some images have extra channels (RGBA or grayscale+alpha).
 69                raw = raw[..., 0]
 70            raw = 255 - raw
 71            f.create_dataset("raw", data=raw, compression="gzip")
 72
 73            # labels: map from
 74            # 0 -> 0
 75            # 127, 128 -> 1
 76            # 255 -> 2
 77            labels = imageio.imread(lp)
 78            if labels.ndim == 3:  # Some labels have an extra alpha channel.
 79                labels = labels[..., 0]
 80            assert labels.shape == raw.shape, f"{labels.shape}, {raw.shape}"
 81            label_vals = np.unique(labels)
 82            # 127, 128: both myelin labels, 130, 233: noise
 83            assert len(np.setdiff1d(label_vals, [0, 127, 128, 130, 233, 255])) == 0, f"{label_vals}"
 84            new_labels = np.zeros_like(labels)
 85            new_labels[labels == 127] = 1
 86            new_labels[labels == 128] = 1
 87            new_labels[labels == 255] = 2
 88            f.create_dataset("labels", data=new_labels, compression="gzip")
 89
 90    # clean up
 91    rmtree(data_root)
 92
 93
 94def _preprocess_tem_data(out_path):
 95    import h5py
 96
 97    data_root = os.path.join(out_path, "TEM_dataset")
 98    folder_names = os.listdir(data_root)
 99    folders = [os.path.join(data_root, fname) for fname in folder_names
100               if os.path.isdir(os.path.join(data_root, fname))]
101    for i, folder in enumerate(folders):
102        data_out = os.path.join(out_path, f"tem_{i}.h5")
103        with h5py.File(data_out, "w") as f:
104            im = imageio.imread(os.path.join(folder, "image.png"))
105            f.create_dataset("raw", data=im, compression="gzip")
106
107            # labels: map from
108            # 0 -> 0
109            # 128 -> 1
110            # 255 -> 2
111            # the rest are noise
112            labels = imageio.imread(os.path.join(folder, "mask.png"))
113            new_labels = np.zeros_like(labels)
114            new_labels[labels == 128] = 1
115            new_labels[labels == 255] = 2
116            f.create_dataset("labels", data=new_labels, compression="gzip")
117
118    # clean up
119    rmtree(data_root)
120
121
def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
    """Download the AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """

    # Look up download location and expected checksum for this modality.
    url, checksum = URLS[name], CHECKSUMS[name]
    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    # Already downloaded and preprocessed: nothing to do.
    if os.path.exists(out_path):
        return out_path

    zip_path = os.path.join(path, f"{name}.zip")
    util.download_source(zip_path, url, download, checksum=checksum)
    util.unzip(zip_path, out_path, remove=True)

    # Convert the extracted data into the h5 format used by the datasets.
    preprocess = {"sem": _preprocess_sem_data, "tem": _preprocess_tem_data}.get(name)
    if preprocess is None:
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")
    preprocess(out_path)

    return out_path
153
154
def get_axondeepseg_paths(
    path: Union[str, os.PathLike],
    name: Union[Literal["sem", "tem"], List[str]],
    download: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
) -> List[str]:
    """Get paths to the AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name(s) of the dataset(s) to download. Can be 'sem', 'tem', or a list of these.
        download: Whether to download the data if it is not present.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'. Must be given when 'val_fraction' is passed.

    Returns:
        List of paths for all the data.
    """
    # Normalize a single dataset name to a list. Without this, iterating a bare
    # string like "sem" below would yield its characters ('s', 'e', 'm').
    if isinstance(name, str):
        name = [name]

    all_paths = []
    for nn in name:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = glob(os.path.join(data_root, "*.h5"))
        paths.sort()
        if val_fraction is not None:
            assert split is not None
            # Deterministic split: the first (1 - val_fraction) of the sorted files
            # form the train set, the remainder the val set.
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    return all_paths
186
187
def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: Union[Literal["sem", "tem"], List[str]],
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: Union[bool, int, List[int]] = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal['train', 'val']] = None,
    **kwargs,
) -> Dataset:
    """Get dataset for segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name(s) of the dataset(s) to download. Can be 'sem', 'tem', or a list of these.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
            Pass True to one-hot encode the classes [0, 1, 2], an int n for the
            classes range(n), or an explicit list of class ids.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    if isinstance(name, str):
        name = [name]
    assert isinstance(name, (tuple, list))

    all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split)

    if one_hot_encoding:
        # Note: the bool check must come first, since bool is a subclass of int.
        if isinstance(one_hot_encoding, bool):
            # add transformation to go from [0, 1, 2] to one hot encoding
            class_ids = [0, 1, 2]
        elif isinstance(one_hot_encoding, int):
            class_ids = list(range(one_hot_encoding))
        elif isinstance(one_hot_encoding, (list, tuple)):
            class_ids = list(one_hot_encoding)
        else:
            raise ValueError(
                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
            )
        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    return torch_em.default_segmentation_dataset(
        raw_paths=all_paths,
        raw_key="raw",
        label_paths=all_paths,
        label_key="labels",
        patch_shape=patch_shape,
        **kwargs
    )
243
244
def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal["train", "val"]] = None,
    **kwargs
) -> DataLoader:
    """Get dataloader for the segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the dataset kwargs from the loader kwargs before dispatching.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_axondeepseg_dataset(
        path, name, patch_shape,
        download=download,
        one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction,
        split=split,
        **ds_kwargs,
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
URLS = {'sem': 'https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip', 'tem': 'https://osf.io/download/uewd9'}
CHECKSUMS = {'sem': '12f2f03834c41720badf00131bb7b7a2127e532cf78e01fbea398e1ff800779b', 'tem': 'e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95'}
def get_axondeepseg_data( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False) -> str:
123def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
124    """Download the AxonDeepSeg data.
125
126    Args:
127        path: Filepath to a folder where the downloaded data will be saved.
128        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
129        download: Whether to download the data if it is not present.
130
131    Returns:
132        The filepath for the downloaded data.
133    """
134
135    # download and unzip the data
136    url, checksum = URLS[name], CHECKSUMS[name]
137    os.makedirs(path, exist_ok=True)
138    out_path = os.path.join(path, name)
139    if os.path.exists(out_path):
140        return out_path
141
142    tmp_path = os.path.join(path, f"{name}.zip")
143    util.download_source(tmp_path, url, download, checksum=checksum)
144    util.unzip(tmp_path, out_path, remove=True)
145
146    if name == "sem":
147        _preprocess_sem_data(out_path)
148    elif name == "tem":
149        _preprocess_tem_data(out_path)
150    else:
151        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")
152
153    return out_path

Download the AxonDeepSeg data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • download: Whether to download the data if it is not present.
Returns:

The filepath for the downloaded data.

def get_axondeepseg_paths( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False, val_fraction: Optional[float] = None, split: Optional[str] = None) -> List[str]:
156def get_axondeepseg_paths(
157    path: Union[str, os.PathLike],
158    name: Literal["sem", "tem"],
159    download: bool = False,
160    val_fraction: Optional[float] = None,
161    split: Optional[str] = None,
162) -> List[str]:
163    """Get paths to the AxonDeepSeg data.
164
165    Args:
166        path: Filepath to a folder where the downloaded data will be saved.
167        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
168        download: Whether to download the data if it is not present.
169        val_fraction: The fraction of the data to use for validation.
170        split: The data split. Either 'train' or 'val'.
171
172    Returns:
173        List of paths for all the data.
174    """
175    all_paths = []
176    for nn in name:
177        data_root = get_axondeepseg_data(path, nn, download)
178        paths = glob(os.path.join(data_root, "*.h5"))
179        paths.sort()
180        if val_fraction is not None:
181            assert split is not None
182            n_samples = int(len(paths) * (1 - val_fraction))
183            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
184        all_paths.extend(paths)
185
186    return all_paths

Get paths to the AxonDeepSeg data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • download: Whether to download the data if it is not present.
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
Returns:

List of paths for all the data.

def get_axondeepseg_dataset( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
189def get_axondeepseg_dataset(
190    path: Union[str, os.PathLike],
191    name: Literal["sem", "tem"],
192    patch_shape: Tuple[int, int],
193    download: bool = False,
194    one_hot_encoding: bool = False,
195    val_fraction: Optional[float] = None,
196    split: Optional[Literal['train', 'val']] = None,
197    **kwargs,
198) -> Dataset:
199    """Get dataset for segmentation of myelinated axons.
200
201    Args:
202        path: Filepath to a folder where the downloaded data will be saved.
203        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
204        patch_shape: The patch shape to use for training.
205        download: Whether to download the data if it is not present.
206        one_hot_encoding: Whether to return the labels as one hot encoding.
207        val_fraction: The fraction of the data to use for validation.
208        split: The data split. Either 'train' or 'val'.
209        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
210
211    Returns:
212        The segmentation dataset.
213    """
214    if isinstance(name, str):
215        name = [name]
216    assert isinstance(name, (tuple, list))
217
218    all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split)
219
220    if one_hot_encoding:
221        if isinstance(one_hot_encoding, bool):
222            # add transformation to go from [0, 1, 2] to one hot encoding
223            class_ids = [0, 1, 2]
224        elif isinstance(one_hot_encoding, int):
225            class_ids = list(range(one_hot_encoding))
226        elif isinstance(one_hot_encoding, (list, tuple)):
227            class_ids = list(one_hot_encoding)
228        else:
229            raise ValueError(
230                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
231            )
232        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
233        msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
234        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
235
236    return torch_em.default_segmentation_dataset(
237        raw_paths=all_paths,
238        raw_key="raw",
239        label_paths=all_paths,
240        label_key="labels",
241        patch_shape=patch_shape,
242        **kwargs
243    )

Get dataset for segmentation of myelinated axons.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • patch_shape: The patch shape to use for training.
  • download: Whether to download the data if it is not present.
  • one_hot_encoding: Whether to return the labels as one hot encoding.
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_axondeepseg_loader( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], batch_size: int, download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataloader.DataLoader:
246def get_axondeepseg_loader(
247    path: Union[str, os.PathLike],
248    name: Literal["sem", "tem"],
249    patch_shape: Tuple[int, int],
250    batch_size: int,
251    download: bool = False,
252    one_hot_encoding: bool = False,
253    val_fraction: Optional[float] = None,
254    split: Optional[Literal["train", "val"]] = None,
255    **kwargs
256) -> DataLoader:
257    """Get dataloader for the segmentation of myelinated axons.
258
259    Args:
260        path: Filepath to a folder where the downloaded data will be saved.
261        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
262        patch_shape: The patch shape to use for training.
263        batch_size: The batch size for training.
264        download: Whether to download the data if it is not present.
265        one_hot_encoding: Whether to return the labels as one hot encoding.
266        val_fraction: The fraction of the data to use for validation.
267        split: The data split. Either 'train' or 'val'.
268        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
269
270    Returns:
271        The PyTorch DataLoader.
272    """
273    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
274    dataset = get_axondeepseg_dataset(
275        path, name, patch_shape, download=download, one_hot_encoding=one_hot_encoding,
276        val_fraction=val_fraction, split=split, **ds_kwargs
277    )
278    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get dataloader for the segmentation of myelinated axons.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset to download. Can be either 'sem' or 'tem'.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • download: Whether to download the data if it is not present.
  • one_hot_encoding: Whether to return the labels as one hot encoding.
  • val_fraction: The fraction of the data to use for validation.
  • split: The data split. Either 'train' or 'val'.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.