torch_em.data.datasets.electron_microscopy.astih
ASTIH is a dataset for axon and myelin segmentation in microscopy images.
It contains diverse microscopy datasets (TEM, SEM, BF) designed to benchmark and train axon and myelin segmentation models. It provides over 60,000 manually segmented fibers across three microscopy modalities.
The dataset is described at https://axondeepseg.github.io/ASTIH/. The dataset is from the publication https://openreview.net/forum?id=ExBq9A8Ypk. Please cite the corresponding publication if you use the dataset in your research.
"""ASTIH is a dataset for axon and myelin segmentation in microscopy images.

It contains diverse microscopy datasets (TEM, SEM, BF) designed to benchmark
and train axon and myelin segmentation models. It provides over 60,000 manually
segmented fibers across three microscopy modalities.

The dataset is described at https://axondeepseg.github.io/ASTIH/.
The dataset is from the publication https://openreview.net/forum?id=ExBq9A8Ypk.
Please cite the corresponding publication if you use the dataset in your research.
"""

import os
import io
from glob import glob
from typing import List, Literal, Optional, Sequence, Tuple, Union

import imageio
import numpy as np
import requests
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


DANDI_API = "https://api.dandiarchive.org/api"

# Per-dataset download configuration:
#   dandi_id / version: identify the DANDI dandiset version that is listed and downloaded.
#   test_subjects: top-level subject folders that form the test split; None means the
#       test data is not part of the dandiset (see "test_url").
#   file_ext: extension of the raw microscopy images inside the dandiset.
DATASETS = {
    "TEM1": {
        "dandi_id": "001436",
        "version": "0.250512.1625",
        "description": "TEM Images of Corpus Callosum in Control and Cuprizone-Intoxicated Mice",
        "test_subjects": ["sub-nyuMouse26"],
        "file_ext": "png",
    },
    "TEM2": {
        "dandi_id": "001350",
        "version": "0.250511.1527",
        "description": "TEM Images of Corpus Callosum in Flox/SRF-cKO Mice",
        "test_subjects": None,  # External test set.
        "test_url": "https://github.com/axondeepseg/data_axondeepseg_srf_testing/archive/refs/tags/r20250513-neurips2025.zip",  # noqa
        "file_ext": "png",
    },
    "SEM1": {
        "dandi_id": "001442",
        "version": "0.250512.1626",
        "description": "SEM Images of Rat Spinal Cord",
        "test_subjects": ["sub-rat6"],
        "file_ext": "png",
    },
    "BF1": {
        "dandi_id": "001440",
        "version": "0.250509.1913",
        "description": "BF Images of Rat Nerves at Different Regeneration Stages",
        "test_subjects": ["sub-uoftRat02", "sub-uoftRat07"],
        "file_ext": "png",
    },
    "BF2": {
        "dandi_id": "001630",
        "version": "0.251127.1424",
        "description": "Bright-Field Images of Rabbit Nerves",
        "test_subjects": ["sub-22G132040x3"],
        "file_ext": "tif",
    },
}

DATASET_NAMES = list(DATASETS.keys())

LABEL_CLASSES = {"background": 0, "myelin": 1, "axon": 2}


def _list_dandi_assets(dandi_id, version):
    """List all assets in a DANDI dataset via the REST API.

    Follows the "next" link of each response page until it is empty and
    returns the concatenated "results" entries.
    """
    all_assets = []
    url = f"{DANDI_API}/dandisets/{dandi_id}/versions/{version}/assets/?page_size=200"
    while url:
        r = requests.get(url)
        r.raise_for_status()
        data = r.json()
        all_assets.extend(data["results"])
        url = data.get("next")
    return all_assets


def _download_dandi_asset(asset_id, out_path):
    """Download a single DANDI asset by its ID and stream it to `out_path`."""
    from shutil import copyfileobj

    url = f"{DANDI_API}/assets/{asset_id}/download/"
    with requests.get(url, stream=True, allow_redirects=True) as r:
        r.raise_for_status()
        file_size = int(r.headers.get("Content-Length", 0))
        desc = f"Download {os.path.basename(out_path)}"
        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(out_path, "wb") as f:
            copyfileobj(r_raw, f)


def _fetch_dandi_image(asset_id):
    """Download a DANDI asset into memory and decode it as an image array.

    Raises:
        requests.HTTPError: If the download request returns an error status.
    """
    response = requests.get(f"{DANDI_API}/assets/{asset_id}/download/")
    # Fail on HTTP errors here; otherwise the error body would be handed to
    # the image decoder and surface as a misleading decoding failure.
    response.raise_for_status()
    return imageio.imread(io.BytesIO(response.content))


def _find_image_label_pairs(assets, file_ext):
    """Find matching image and axonmyelin label pairs from the DANDI asset list.

    Args:
        assets: The asset entries returned by `_list_dandi_assets`; each needs a "path" entry.
        file_ext: The extension (without dot) of the raw images.

    Returns:
        A list of dicts with the keys "subject", "image_asset", "label_asset" and "stem".
    """
    # Index label assets by their stem.
    label_map = {}
    for a in assets:
        p = a["path"]
        if "axonmyelin-manual.png" in p:
            # Extract the image stem: remove the _seg-axonmyelin-manual.png suffix
            stem = os.path.basename(p).replace("_seg-axonmyelin-manual.png", "")
            label_map[stem] = a

    # Find images that have a matching label.
    pairs = []
    for a in assets:
        p = a["path"]
        if "/micr/" in p and not p.startswith("derivatives") and p.endswith(f".{file_ext}"):
            stem = os.path.basename(p).rsplit(".", 1)[0]
            if stem in label_map:
                # The first path component is used as the subject identifier for the split.
                subject = p.split("/")[0]
                pairs.append({
                    "subject": subject,
                    "image_asset": a,
                    "label_asset": label_map[stem],
                    "stem": stem,
                })
    return pairs


def _preprocess_label(label):
    """Map label values to: 0=background, 1=myelin, 2=axon.

    Input values 127 and 128 become myelin, 255 becomes axon and everything
    else becomes background. For a 3D input only the first channel is used.
    """
    if label.ndim == 3:
        label = label[..., 0]
    new_label = np.zeros_like(label)
    new_label[(label == 127) | (label == 128)] = 1
    new_label[label == 255] = 2
    return new_label


def _download_and_preprocess(out_path, dataset_info, split, download):
    """Download data from DANDI, pair images with labels, and save as h5 files.

    Each image-label pair is stored as `<stem>.h5` in `out_path` with the
    datasets "raw" and "labels". Pairs whose h5 file already exists are skipped.

    Args:
        out_path: The folder where the h5 files are written.
        dataset_info: One entry of `DATASETS`.
        split: The data split; "train" selects the non-test subjects, any other value the test subjects.
        download: Whether downloading is allowed.

    Raises:
        RuntimeError: If `download` is False, if no image-label pairs are found,
            or if an image and its label have different shapes.
        NotImplementedError: If the test split is requested for a dataset with an external test set.
        requests.HTTPError: If a download request returns an error status.
    """
    import h5py

    if not download:
        raise RuntimeError(f"Cannot find the data at {out_path}, but download was set to False")

    dandi_id = dataset_info["dandi_id"]
    version = dataset_info["version"]
    file_ext = dataset_info["file_ext"]
    test_subjects = dataset_info["test_subjects"]

    # For datasets with external test sets (TEM2), all DANDI data is training.
    # Reject the request before creating folders or contacting the server.
    if test_subjects is None and split == "test":
        raise NotImplementedError(
            "The test set for this dataset is hosted externally. "
            "Please use the ASTIH repository's get_data.py script for the test split."
        )

    os.makedirs(out_path, exist_ok=True)

    # List and pair assets.
    assets = _list_dandi_assets(dandi_id, version)
    pairs = _find_image_label_pairs(assets, file_ext)

    if len(pairs) == 0:
        raise RuntimeError(f"No image-label pairs found for DANDI:{dandi_id}")

    # Filter by split.
    if test_subjects is not None:
        if split == "train":
            pairs = [p for p in pairs if p["subject"] not in test_subjects]
        else:
            pairs = [p for p in pairs if p["subject"] in test_subjects]

    # Download and preprocess each pair.
    for pair in tqdm(pairs, desc=f"Processing {split} data"):
        h5_path = os.path.join(out_path, f"{pair['stem']}.h5")
        if os.path.exists(h5_path):
            continue

        # Download image; only the first channel of multi-channel images is kept.
        raw = _fetch_dandi_image(pair["image_asset"]["asset_id"])
        if raw.ndim == 3:
            raw = raw[..., 0]

        # Download label.
        label = _preprocess_label(_fetch_dandi_image(pair["label_asset"]["asset_id"]))

        # Explicit check instead of `assert`, so it is not stripped when running with -O.
        if raw.shape != label.shape:
            raise RuntimeError(f"Shape mismatch: {raw.shape} vs {label.shape}")

        # Write to a temporary name and rename afterwards: an interrupted run
        # must not leave a truncated '*.h5' file that is later treated as complete.
        tmp_path = f"{h5_path}.tmp"
        try:
            with h5py.File(tmp_path, "w") as f:
                f.create_dataset("raw", data=raw, compression="gzip")
                f.create_dataset("labels", data=label, compression="gzip")
            os.replace(tmp_path, h5_path)
        finally:
            # Only present if writing or renaming failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


def get_astih_data(
    path: Union[os.PathLike, str],
    name: str,
    split: Literal["train", "test"],
    download: bool = False,
) -> str:
    """Download the ASTIH data.

    The data is considered available if the folder `<path>/<name>/<split>`
    contains at least one h5 file; otherwise it is downloaded and converted.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    assert name in DATASETS, f"Invalid dataset name: {name}. Choose from {DATASET_NAMES}."

    split_dir = os.path.join(path, name, split)
    # NOTE(review): a single h5 file counts as "already downloaded", so a
    # partially completed download is not resumed automatically.
    cached_files = glob(os.path.join(split_dir, "*.h5")) if os.path.exists(split_dir) else []
    if not cached_files:
        _download_and_preprocess(split_dir, DATASETS[name], split, download)
    return split_dir


def get_astih_paths(
    path: Union[os.PathLike, str],
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
) -> List[str]:
    """Get paths to the ASTIH data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
            Can be a single name, a list of names, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepaths for the stored data, sorted within each dataset.
    """
    if name is None:
        dataset_names = DATASET_NAMES
    elif isinstance(name, str):
        dataset_names = [name]
    else:
        dataset_names = name

    h5_files = []
    for dataset_name in dataset_names:
        split_dir = get_astih_data(path, dataset_name, split, download)
        h5_files.extend(sorted(glob(os.path.join(split_dir, "*.h5"))))
    return h5_files


def get_astih_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
    label_classes: Optional[Sequence[str]] = None,
    **kwargs,
) -> Dataset:
    """Get the ASTIH dataset for axon and myelin segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
            a list of these names to combine datasets, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.
        label_classes: The label classes to use for one-hot encoding. Available classes are
            'background', 'myelin', and 'axon'. By default set to None, which returns
            the label map with all classes (0=background, 1=myelin, 2=axon).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `label_classes` contains a name that is not in `LABEL_CLASSES`.
    """
    volume_paths = get_astih_paths(path, name, split, download)

    if label_classes is not None:
        class_ids = []
        for class_name in label_classes:
            if class_name in LABEL_CLASSES:
                class_ids.append(LABEL_CLASSES[class_name])
                continue
            raise ValueError(
                f"Invalid class name: '{class_name}'. Choose from {list(LABEL_CLASSES.keys())}."
            )
        one_hot = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'label_classes' is set, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", one_hot, msg=msg)

    # Raw data and labels live in the same h5 files under different keys.
    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key="labels",
        patch_shape=patch_shape,
        **kwargs,
    )


def get_astih_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    batch_size: int,
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
    label_classes: Optional[Sequence[str]] = None,
    **kwargs,
) -> DataLoader:
    """Get the DataLoader for axon and myelin segmentation in the ASTIH dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
            a list of these names to combine datasets, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.
        label_classes: The label classes to use for one-hot encoding. Available classes are
            'background', 'myelin', and 'axon'. By default set to None, which returns
            the label map with all classes (0=background, 1=myelin, 2=axon).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    astih_dataset = get_astih_dataset(
        path,
        patch_shape,
        name=name,
        split=split,
        download=download,
        label_classes=label_classes,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(astih_dataset, batch_size, **dataloader_kwargs)
def get_astih_data(
    path: Union[os.PathLike, str],
    name: str,
    split: Literal["train", "test"],
    download: bool = False,
) -> str:
    """Download the ASTIH data.

    The data is considered available if the folder `<path>/<name>/<split>`
    contains at least one h5 file; otherwise it is downloaded and converted.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    assert name in DATASETS, f"Invalid dataset name: {name}. Choose from {DATASET_NAMES}."

    split_dir = os.path.join(path, name, split)
    # NOTE(review): a single h5 file counts as "already downloaded", so a
    # partially completed download is not resumed automatically.
    cached_files = glob(os.path.join(split_dir, "*.h5")) if os.path.exists(split_dir) else []
    if not cached_files:
        _download_and_preprocess(split_dir, DATASETS[name], split, download)
    return split_dir
Download the ASTIH data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
- split: The data split. Either 'train' or 'test'.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the downloaded data.
def get_astih_paths(
    path: Union[os.PathLike, str],
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
) -> List[str]:
    """Get paths to the ASTIH data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'.
            Can be a single name, a list of names, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepaths for the stored data, sorted within each dataset.
    """
    if name is None:
        dataset_names = DATASET_NAMES
    elif isinstance(name, str):
        dataset_names = [name]
    else:
        dataset_names = name

    h5_files = []
    for dataset_name in dataset_names:
        split_dir = get_astih_data(path, dataset_name, split, download)
        h5_files.extend(sorted(glob(os.path.join(split_dir, "*.h5"))))
    return h5_files
Get paths to the ASTIH data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset. Available names are 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2'. Can be a single name, a list of names, or None to load all datasets.
- split: The data split. Either 'train' or 'test'.
- download: Whether to download the data if it is not present.
Returns:
The filepaths for the stored data.
def get_astih_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
    label_classes: Optional[Sequence[str]] = None,
    **kwargs,
) -> Dataset:
    """Get the ASTIH dataset for axon and myelin segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
            a list of these names to combine datasets, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.
        label_classes: The label classes to use for one-hot encoding. Available classes are
            'background', 'myelin', and 'axon'. By default set to None, which returns
            the label map with all classes (0=background, 1=myelin, 2=axon).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `label_classes` contains a name that is not in `LABEL_CLASSES`.
    """
    volume_paths = get_astih_paths(path, name, split, download)

    if label_classes is not None:
        class_ids = []
        for class_name in label_classes:
            if class_name in LABEL_CLASSES:
                class_ids.append(LABEL_CLASSES[class_name])
                continue
            raise ValueError(
                f"Invalid class name: '{class_name}'. Choose from {list(LABEL_CLASSES.keys())}."
            )
        one_hot = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'label_classes' is set, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", one_hot, msg=msg)

    # Raw data and labels live in the same h5 files under different keys.
    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw",
        label_paths=volume_paths,
        label_key="labels",
        patch_shape=patch_shape,
        **kwargs,
    )
Get the ASTIH dataset for axon and myelin segmentation.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- patch_shape: The patch shape to use for training.
- name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2', a list of these names to combine datasets, or None to load all datasets.
- split: The data split. Either 'train' or 'test'.
- download: Whether to download the data if it is not present.
- label_classes: The label classes to use for one-hot encoding. Available classes are 'background', 'myelin', and 'axon'. By default set to None, which returns the label map with all classes (0=background, 1=myelin, 2=axon).
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def get_astih_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    batch_size: int,
    name: Optional[Union[str, Sequence[str]]] = None,
    split: Literal["train", "test"] = "train",
    download: bool = False,
    label_classes: Optional[Sequence[str]] = None,
    **kwargs,
) -> DataLoader:
    """Get the DataLoader for axon and myelin segmentation in the ASTIH dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2',
            a list of these names to combine datasets, or None to load all datasets.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.
        label_classes: The label classes to use for one-hot encoding. Available classes are
            'background', 'myelin', and 'axon'. By default set to None, which returns
            the label map with all classes (0=background, 1=myelin, 2=axon).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    astih_dataset = get_astih_dataset(
        path,
        patch_shape,
        name=name,
        split=split,
        download=download,
        label_classes=label_classes,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(astih_dataset, batch_size, **dataloader_kwargs)
Get the DataLoader for axon and myelin segmentation in the ASTIH dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- name: The name of the dataset. Can be one of 'TEM1', 'TEM2', 'SEM1', 'BF1', 'BF2', a list of these names to combine datasets, or None to load all datasets.
- split: The data split. Either 'train' or 'test'.
- download: Whether to download the data if it is not present.
- label_classes: The label classes to use for one-hot encoding. Available classes are 'background', 'myelin', and 'axon'. By default set to None, which returns the label map with all classes (0=background, 1=myelin, 2=axon).
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.