torch_em.data.datasets.electron_microscopy.astih

ASTIH is a dataset for axon and myelin segmentation in microscopy images.

It contains diverse microscopy datasets (TEM, SEM, BF) designed to benchmark and train axon and myelin segmentation models. It provides over 60,000 manually segmented fibers across three microscopy modalities.

The dataset is described at https://axondeepseg.github.io/ASTIH/. The dataset is from the publication https://openreview.net/forum?id=ExBq9A8Ypk. Please cite the corresponding publication if you use the dataset in your research.

  1"""ASTIH is a dataset for axon and myelin segmentation in microscopy images.
  2
  3It contains diverse microscopy datasets (TEM, SEM, BF) designed to benchmark
  4and train axon and myelin segmentation models. It provides over 60,000 manually
  5segmented fibers across three microscopy modalities.
  6
  7The dataset is described at https://axondeepseg.github.io/ASTIH/.
  8The dataset is from the publication https://openreview.net/forum?id=ExBq9A8Ypk.
  9Please cite the corresponding publication if you use the dataset in your research.
 10"""
 11
 12import os
 13import io
 14from glob import glob
 15from typing import List, Literal, Optional, Sequence, Tuple, Union
 16
 17import imageio
 18import numpy as np
 19import requests
 20from tqdm import tqdm
 21
 22from torch.utils.data import Dataset, DataLoader
 23
 24import torch_em
 25
 26from .. import util
 27
 28
# Base URL of the DANDI archive REST API, used for asset listing and download.
DANDI_API = "https://api.dandiarchive.org/api"

# Metadata for each ASTIH sub-dataset hosted on the DANDI archive.
# - 'dandi_id' / 'version': pin the exact published DANDI release to download.
# - 'test_subjects': subject folders held out for the test split; None means the
#   test set is hosted externally (see 'test_url' for TEM2).
# - 'file_ext': file extension of the raw microscopy images on DANDI.
DATASETS = {
    "TEM1": {
        "dandi_id": "001436",
        "version": "0.250512.1625",
        "description": "TEM Images of Corpus Callosum in Control and Cuprizone-Intoxicated Mice",
        "test_subjects": ["sub-nyuMouse26"],
        "file_ext": "png",
    },
    "TEM2": {
        "dandi_id": "001350",
        "version": "0.250511.1527",
        "description": "TEM Images of Corpus Callosum in Flox/SRF-cKO Mice",
        "test_subjects": None,  # External test set.
        "test_url": "https://github.com/axondeepseg/data_axondeepseg_srf_testing/archive/refs/tags/r20250513-neurips2025.zip",  # noqa
        "file_ext": "png",
    },
    "SEM1": {
        "dandi_id": "001442",
        "version": "0.250512.1626",
        "description": "SEM Images of Rat Spinal Cord",
        "test_subjects": ["sub-rat6"],
        "file_ext": "png",
    },
    "BF1": {
        "dandi_id": "001440",
        "version": "0.250509.1913",
        "description": "BF Images of Rat Nerves at Different Regeneration Stages",
        "test_subjects": ["sub-uoftRat02", "sub-uoftRat07"],
        "file_ext": "png",
    },
    "BF2": {
        "dandi_id": "001630",
        "version": "0.251127.1424",
        "description": "Bright-Field Images of Rabbit Nerves",
        "test_subjects": ["sub-22G132040x3"],
        "file_ext": "tif",
    },
}

# Convenience list of all available dataset names.
DATASET_NAMES = list(DATASETS.keys())

# Semantic classes in the preprocessed label maps (see _preprocess_label).
LABEL_CLASSES = {"background": 0, "myelin": 1, "axon": 2}
 73
 74
 75def _list_dandi_assets(dandi_id, version):
 76    """List all assets in a DANDI dataset via the REST API."""
 77    all_assets = []
 78    url = f"{DANDI_API}/dandisets/{dandi_id}/versions/{version}/assets/?page_size=200"
 79    while url:
 80        r = requests.get(url)
 81        r.raise_for_status()
 82        data = r.json()
 83        all_assets.extend(data["results"])
 84        url = data.get("next")
 85    return all_assets
 86
 87
 88def _download_dandi_asset(asset_id, out_path):
 89    """Download a single DANDI asset by its ID."""
 90    url = f"{DANDI_API}/assets/{asset_id}/download/"
 91    with requests.get(url, stream=True, allow_redirects=True) as r:
 92        r.raise_for_status()
 93        file_size = int(r.headers.get("Content-Length", 0))
 94        desc = f"Download {os.path.basename(out_path)}"
 95        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(out_path, "wb") as f:
 96            from shutil import copyfileobj
 97            copyfileobj(r_raw, f)
 98
 99
100def _find_image_label_pairs(assets, file_ext):
101    """Find matching image and axonmyelin label pairs from the DANDI asset list."""
102    # Index label assets by their stem.
103    label_map = {}
104    for a in assets:
105        p = a["path"]
106        if "axonmyelin-manual.png" in p:
107            # Extract the image stem: remove the _seg-axonmyelin-manual.png suffix
108            stem = os.path.basename(p).replace("_seg-axonmyelin-manual.png", "")
109            label_map[stem] = a
110
111    # Find images that have a matching label.
112    pairs = []
113    for a in assets:
114        p = a["path"]
115        if "/micr/" in p and not p.startswith("derivatives") and p.endswith(f".{file_ext}"):
116            stem = os.path.basename(p).rsplit(".", 1)[0]
117            if stem in label_map:
118                subject = p.split("/")[0]
119                pairs.append({
120                    "subject": subject,
121                    "image_asset": a,
122                    "label_asset": label_map[stem],
123                    "stem": stem,
124                })
125    return pairs
126
127
128def _preprocess_label(label):
129    """Map label values to: 0=background, 1=myelin, 2=axon."""
130    if label.ndim == 3:
131        label = label[..., 0]
132    new_label = np.zeros_like(label)
133    new_label[(label == 127) | (label == 128)] = 1
134    new_label[label == 255] = 2
135    return new_label
136
137
def _download_and_preprocess(out_path, dataset_info, split, download):
    """Download data from DANDI, pair images with labels, and save as h5 files.

    Args:
        out_path: Folder where the preprocessed h5 files are stored.
        dataset_info: The metadata dictionary for this dataset (see DATASETS).
        split: The data split. Either 'train' or 'test'.
        download: Whether downloading the data is allowed.

    Raises:
        RuntimeError: If 'download' is False, or no image-label pairs are found.
        NotImplementedError: For the test split of datasets whose test set is hosted externally.
    """
    import h5py

    if not download:
        raise RuntimeError(f"Cannot find the data at {out_path}, but download was set to False")

    os.makedirs(out_path, exist_ok=True)

    dandi_id = dataset_info["dandi_id"]
    version = dataset_info["version"]
    file_ext = dataset_info["file_ext"]
    test_subjects = dataset_info["test_subjects"]

    # List and pair assets.
    assets = _list_dandi_assets(dandi_id, version)
    pairs = _find_image_label_pairs(assets, file_ext)

    if len(pairs) == 0:
        raise RuntimeError(f"No image-label pairs found for DANDI:{dandi_id}")

    # Filter by split: held-out subjects form the test split.
    if test_subjects is not None:
        if split == "train":
            pairs = [p for p in pairs if p["subject"] not in test_subjects]
        else:
            pairs = [p for p in pairs if p["subject"] in test_subjects]
    else:
        # For datasets with external test sets (TEM2), all DANDI data is training.
        if split == "test":
            raise NotImplementedError(
                "The test set for this dataset is hosted externally. "
                "Please use the ASTIH repository's get_data.py script for the test split."
            )

    # Download and preprocess each pair.
    for pair in tqdm(pairs, desc=f"Processing {split} data"):
        h5_path = os.path.join(out_path, f"{pair['stem']}.h5")
        # Skip pairs already processed in a previous (possibly interrupted) run.
        if os.path.exists(h5_path):
            continue

        # Download the image; the timeout avoids hanging indefinitely on network stalls.
        img_data = requests.get(
            f"{DANDI_API}/assets/{pair['image_asset']['asset_id']}/download/", timeout=120
        ).content
        raw = imageio.imread(io.BytesIO(img_data))
        if raw.ndim == 3:
            # Reduce multi-channel images to their first channel.
            raw = raw[..., 0]

        # Download and remap the corresponding label image.
        lbl_data = requests.get(
            f"{DANDI_API}/assets/{pair['label_asset']['asset_id']}/download/", timeout=120
        ).content
        label = imageio.imread(io.BytesIO(lbl_data))
        label = _preprocess_label(label)

        assert raw.shape == label.shape, f"Shape mismatch: {raw.shape} vs {label.shape}"

        with h5py.File(h5_path, "w") as f:
            f.create_dataset("raw", data=raw, compression="gzip")
            f.create_dataset("labels", data=label, compression="gzip")
195
196
def get_astih_data(
    path: Union[os.PathLike, str],
    name: str,
    split: Literal["train", "test"],
    download: bool = False,
) -> str:
    """Download the ASTIH data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    assert name in DATASETS, f"Invalid dataset name: {name}. Choose from {DATASET_NAMES}."

    data_dir = os.path.join(path, name, split)
    # The data counts as present once at least one preprocessed h5 file exists.
    have_data = os.path.exists(data_dir) and len(glob(os.path.join(data_dir, "*.h5"))) > 0
    if not have_data:
        _download_and_preprocess(data_dir, DATASETS[name], split, download)
    return data_dir
222
223
def get_astih_paths(
    path: Union[os.PathLike, str],
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
) -> List[str]:
    """Get paths to the ASTIH data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
            Can be a single name, a list of names, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepaths for the stored data.
    """
    # Normalize 'name' into a list of dataset names.
    if name is None:
        names = DATASET_NAMES
    elif isinstance(name, str):
        names = [name]
    else:
        names = name

    all_paths = []
    for dataset_name in names:
        data_root = get_astih_data(path, dataset_name, split, download)
        # Sort per dataset so the overall order is deterministic.
        all_paths.extend(sorted(glob(os.path.join(data_root, "*.h5"))))

    return all_paths
255
256
def get_astih_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
    label_classes: Optional[Sequence[str]] = None,
    **kwargs,
) -> Dataset:
    """Get the ASTIH dataset for axon and myelin segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
            a list of these names to combine datasets, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.
        label_classes: The label classes to use for one-hot encoding. Available classes are
            'background', 'myelin', and 'axon'. By default set to None, which returns
            the label map with all classes (0=background, 1=myelin, 2=axon).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    all_paths = get_astih_paths(path, name, split, download)

    if label_classes is not None:
        # Validate all requested class names before mapping them to ids.
        for cls_name in label_classes:
            if cls_name not in LABEL_CLASSES:
                raise ValueError(
                    f"Invalid class name: '{cls_name}'. Choose from {list(LABEL_CLASSES.keys())}."
                )
        class_ids = [LABEL_CLASSES[cls_name] for cls_name in label_classes]
        one_hot = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'label_classes' is set, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", one_hot, msg=msg)

    return torch_em.default_segmentation_dataset(
        raw_paths=all_paths, raw_key="raw",
        label_paths=all_paths, label_key="labels",
        patch_shape=patch_shape, **kwargs,
    )
305
306
def get_astih_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    batch_size: int,
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
    label_classes: Optional[Sequence[str]] = None,
    **kwargs,
) -> DataLoader:
    """Get the DataLoader for axon and myelin segmentation in the ASTIH dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
            a list of these names to combine datasets, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.
        label_classes: The label classes to use for one-hot encoding. Available classes are
            'background', 'myelin', and 'axon'. By default set to None, which returns
            the label map with all classes (0=background, 1=myelin, 2=axon).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Split kwargs between the dataset constructor and the DataLoader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds = get_astih_dataset(
        path, patch_shape, name=name, split=split, download=download,
        label_classes=label_classes, **dataset_kwargs,
    )
    return torch_em.get_data_loader(ds, batch_size, **loader_kwargs)
DANDI_API = 'https://api.dandiarchive.org/api'
DATASETS = {'TEM1': {'dandi_id': '001436', 'version': '0.250512.1625', 'description': 'TEM Images of Corpus Callosum in Control and Cuprizone-Intoxicated Mice', 'test_subjects': ['sub-nyuMouse26'], 'file_ext': 'png'}, 'TEM2': {'dandi_id': '001350', 'version': '0.250511.1527', 'description': 'TEM Images of Corpus Callosum in Flox/SRF-cKO Mice', 'test_subjects': None, 'test_url': 'https://github.com/axondeepseg/data_axondeepseg_srf_testing/archive/refs/tags/r20250513-neurips2025.zip', 'file_ext': 'png'}, 'SEM1': {'dandi_id': '001442', 'version': '0.250512.1626', 'description': 'SEM Images of Rat Spinal Cord', 'test_subjects': ['sub-rat6'], 'file_ext': 'png'}, 'BF1': {'dandi_id': '001440', 'version': '0.250509.1913', 'description': 'BF Images of Rat Nerves at Different Regeneration Stages', 'test_subjects': ['sub-uoftRat02', 'sub-uoftRat07'], 'file_ext': 'png'}, 'BF2': {'dandi_id': '001630', 'version': '0.251127.1424', 'description': 'Bright-Field Images of Rabbit Nerves', 'test_subjects': ['sub-22G132040x3'], 'file_ext': 'tif'}}
DATASET_NAMES = ['TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2']
LABEL_CLASSES = {'background': 0, 'myelin': 1, 'axon': 2}
def get_astih_data( path: Union[os.PathLike, str], name: str, split: Literal['train', 'test'], download: bool = False) -> str:
198def get_astih_data(
199    path: Union[os.PathLike, str],
200    name: str,
201    split: Literal["train", "test"],
202    download: bool = False,
203) -> str:
204    """Download the ASTIH data.
205
206    Args:
207        path: Filepath to a folder where the downloaded data will be saved.
208        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
209        split: The data split. Either 'train' or 'test'.
210        download: Whether to download the data if it is not present.
211
212    Returns:
213        The filepath for the downloaded data.
214    """
215    assert name in DATASETS, f"Invalid dataset name: {name}. Choose from {DATASET_NAMES}."
216
217    out_path = os.path.join(path, name, split)
218    if os.path.exists(out_path) and len(glob(os.path.join(out_path, "*.h5"))) > 0:
219        return out_path
220
221    _download_and_preprocess(out_path, DATASETS[name], split, download)
222    return out_path

Download the ASTIH data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
  • split: The data split. Either 'train' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:

The filepath for the downloaded data.

def get_astih_paths( path: Union[os.PathLike, str], name: Union[str, Sequence[str], NoneType] = None, split: Literal['train', 'test'] = 'train', download: bool = False) -> List[str]:
225def get_astih_paths(
226    path: Union[os.PathLike, str],
227    name: Optional[Union[str, Sequence[str]]] = None,
228    split: Literal["train", "test"] = "train",
229    download: bool = False,
230) -> List[str]:
231    """Get paths to the ASTIH data.
232
233    Args:
234        path: Filepath to a folder where the downloaded data will be saved.
235        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
236            Can be a single name, a list of names, or None to load all datasets.
237        split: The data split. Either 'train' or 'test'.
238        download: Whether to download the data if it is not present.
239
240    Returns:
241        The filepaths for the stored data.
242    """
243    if name is None:
244        name = DATASET_NAMES
245    elif isinstance(name, str):
246        name = [name]
247
248    all_paths = []
249    for nn in name:
250        data_root = get_astih_data(path, nn, split, download)
251        paths = glob(os.path.join(data_root, "*.h5"))
252        paths.sort()
253        all_paths.extend(paths)
254
255    return all_paths

Get paths to the ASTIH data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'. Can be a single name, a list of names, or None to load all datasets.
  • split: The data split. Either 'train' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:

The filepaths for the stored data.

def get_astih_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], name: Union[str, Sequence[str], NoneType] = None, split: Literal['train', 'test'] = 'train', download: bool = False, label_classes: Optional[Sequence[str]] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
258def get_astih_dataset(
259    path: Union[os.PathLike, str],
260    patch_shape: Tuple[int, int],
261    name: Optional[Union[str, Sequence[str]]] = None,
262    split: Literal["train", "test"] = "train",
263    download: bool = False,
264    label_classes: Optional[Sequence[str]] = None,
265    **kwargs,
266) -> Dataset:
267    """Get the ASTIH dataset for axon and myelin segmentation.
268
269    Args:
270        path: Filepath to a folder where the downloaded data will be saved.
271        patch_shape: The patch shape to use for training.
272        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
273            a list of these names to combine datasets, or None to load all datasets.
274        split: The data split. Either 'train' or 'test'.
275        download: Whether to download the data if it is not present.
276        label_classes: The label classes to use for one-hot encoding. Available classes are
277            'background', 'myelin', and 'axon'. By default set to None, which returns
278            the label map with all classes (0=background, 1=myelin, 2=axon).
279        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
280
281    Returns:
282        The segmentation dataset.
283    """
284    all_paths = get_astih_paths(path, name, split, download)
285
286    if label_classes is not None:
287        class_ids = []
288        for cls_name in label_classes:
289            if cls_name not in LABEL_CLASSES:
290                raise ValueError(
291                    f"Invalid class name: '{cls_name}'. Choose from {list(LABEL_CLASSES.keys())}."
292                )
293            class_ids.append(LABEL_CLASSES[cls_name])
294        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
295        msg = "'label_classes' is set, but 'label_transform' is in the kwargs. It will be over-ridden."
296        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
297
298    return torch_em.default_segmentation_dataset(
299        raw_paths=all_paths,
300        raw_key="raw",
301        label_paths=all_paths,
302        label_key="labels",
303        patch_shape=patch_shape,
304        **kwargs,
305    )

Get the ASTIH dataset for axon and myelin segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: The patch shape to use for training.
  • name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2', a list of these names to combine datasets, or None to load all datasets.
  • split: The data split. Either 'train' or 'test'.
  • download: Whether to download the data if it is not present.
  • label_classes: The label classes to use for one-hot encoding. Available classes are 'background', 'myelin', and 'axon'. By default set to None, which returns the label map with all classes (0=background, 1=myelin, 2=axon).
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_astih_loader( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], batch_size: int, name: Union[str, Sequence[str], NoneType] = None, split: Literal['train', 'test'] = 'train', download: bool = False, label_classes: Optional[Sequence[str]] = None, **kwargs) -> torch.utils.data.dataloader.DataLoader:
308def get_astih_loader(
309    path: Union[os.PathLike, str],
310    patch_shape: Tuple[int, int],
311    batch_size: int,
312    name: Optional[Union[str, Sequence[str]]] = None,
313    split: Literal["train", "test"] = "train",
314    download: bool = False,
315    label_classes: Optional[Sequence[str]] = None,
316    **kwargs,
317) -> DataLoader:
318    """Get the DataLoader for axon and myelin segmentation in the ASTIH dataset.
319
320    Args:
321        path: Filepath to a folder where the downloaded data will be saved.
322        patch_shape: The patch shape to use for training.
323        batch_size: The batch size for training.
324        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
325            a list of these names to combine datasets, or None to load all datasets.
326        split: The data split. Either 'train' or 'test'.
327        download: Whether to download the data if it is not present.
328        label_classes: The label classes to use for one-hot encoding. Available classes are
329            'background', 'myelin', and 'axon'. By default set to None, which returns
330            the label map with all classes (0=background, 1=myelin, 2=axon).
331        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
332
333    Returns:
334        The PyTorch DataLoader.
335    """
336    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
337    dataset = get_astih_dataset(
338        path, patch_shape, name=name, split=split, download=download,
339        label_classes=label_classes, **ds_kwargs,
340    )
341    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get the DataLoader for axon and myelin segmentation in the ASTIH dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2', a list of these names to combine datasets, or None to load all datasets.
  • split: The data split. Either 'train' or 'test'.
  • download: Whether to download the data if it is not present.
  • label_classes: The label classes to use for one-hot encoding. Available classes are 'background', 'myelin', and 'axon'. By default set to None, which returns the label map with all classes (0=background, 1=myelin, 2=axon).
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.