torch_em.data.datasets.electron_microscopy.axondeepseg
AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM. It contains two different data types: TEM and SEM. The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. Please cite this publication if you use the dataset in your research.
"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
It contains two different data types: TEM and SEM.
The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
Please cite this publication if you use the dataset in your research.
"""

import os
from glob import glob
from shutil import rmtree
from typing import Optional, Tuple, Union

import imageio
import h5py
import numpy as np
import torch_em

from torch.utils.data import Dataset, DataLoader
from .. import util

URLS = {
    "sem": "https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip",
    "tem": "https://osf.io/download/uewd9"
}
CHECKSUMS = {
    "sem": "d334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426",
    "tem": "e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95"
}


def _preprocess_sem_data(out_path):
    """Convert the extracted SEM data into HDF5 files and remove the extracted sources.

    Writes one file `sem_data_{i}.h5` per image into `out_path`, containing the inverted
    raw image under "raw" and the mapped labels under "labels".

    Args:
        out_path: The folder that contains the extracted `data_axondeepseg_sem-master` folder.
    """
    # preprocess the data to get it to a better data format
    data_root = os.path.join(out_path, "data_axondeepseg_sem-master")
    assert os.path.exists(data_root)

    # get the raw data paths (sorted per subject folder and within each folder)
    raw_paths = []
    for folder in sorted(glob(os.path.join(data_root, "sub-rat*"))):
        raw_paths.extend(sorted(glob(os.path.join(folder, "micr", "*.png"))))

    # get the label paths, sorted in the same way so that they line up with the raw paths
    label_paths = []
    for folder in sorted(glob(os.path.join(data_root, "derivatives", "labels", "sub-rat*"))):
        label_paths.extend(sorted(glob(os.path.join(folder, "micr", "*axonmyelin-manual.png"))))
    # raw images and labels are matched by position, so both lists must have the same length
    assert len(raw_paths) == len(label_paths), f"{len(raw_paths)}, {len(label_paths)}"

    # process raw data and labels
    for i, (rp, lp) in enumerate(zip(raw_paths, label_paths)):
        outp = os.path.join(out_path, f"sem_data_{i}.h5")
        with h5py.File(outp, "w") as f:

            # raw data: invert to match tem em intensities
            raw = imageio.imread(rp)
            assert np.dtype(raw.dtype) == np.dtype("uint8")
            if raw.ndim == 3:  # (one of the images is RGBA)
                # Average the color channels and drop the alpha channel if present:
                # `:3` selects the RGB channels for both RGB and RGBA input.
                # The mean is rounded back to uint8 so that all raw datasets share one dtype.
                raw = np.round(np.mean(raw[..., :3], axis=-1)).astype("uint8")
            raw = 255 - raw
            f.create_dataset("raw", data=raw, compression="gzip")

            # labels: map from
            # 0 -> 0
            # 127, 128 -> 1
            # 255 -> 2
            labels = imageio.imread(lp)
            assert labels.shape == raw.shape, f"{labels.shape}, {raw.shape}"
            label_vals = np.unique(labels)
            # 127, 128: both myelin labels, 130, 233: noise
            assert len(np.setdiff1d(label_vals, [0, 127, 128, 130, 233, 255])) == 0, f"{label_vals}"
            new_labels = np.zeros_like(labels)
            new_labels[labels == 127] = 1
            new_labels[labels == 128] = 1
            new_labels[labels == 255] = 2
            f.create_dataset("labels", data=new_labels, compression="gzip")

    # clean up
    rmtree(data_root)


def _preprocess_tem_data(out_path):
    """Convert the extracted TEM data into HDF5 files and remove the extracted sources.

    Every sub-folder of `out_path/TEM_dataset` is read as one sample (`image.png` and
    `mask.png`) and written to `tem_{i}.h5` under the keys "raw" and "labels".

    Args:
        out_path: The folder that contains the extracted `TEM_dataset` folder.
    """
    data_root = os.path.join(out_path, "TEM_dataset")
    # os.listdir returns entries in arbitrary order. Sort them so that the file index i,
    # and with it the train / val split in `get_axondeepseg_dataset`, is reproducible.
    folders = sorted(
        os.path.join(data_root, fname) for fname in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, fname))
    )
    for i, folder in enumerate(folders):
        data_out = os.path.join(out_path, f"tem_{i}.h5")
        with h5py.File(data_out, "w") as f:
            im = imageio.imread(os.path.join(folder, "image.png"))
            f.create_dataset("raw", data=im, compression="gzip")

            # labels: map from
            # 0 -> 0
            # 128 -> 1
            # 255 -> 2
            # the rest are noise
            labels = imageio.imread(os.path.join(folder, "mask.png"))
            new_labels = np.zeros_like(labels)
            new_labels[labels == 128] = 1
            new_labels[labels == 255] = 2
            f.create_dataset("labels", data=new_labels, compression="gzip")

    # clean up
    rmtree(data_root)


def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
    """Download the axondeepseg data.

    The archive for `name` is downloaded (with checksum verification), unzipped into
    `path/name` and converted to HDF5 files. If `path/name` already exists, it is returned
    without downloading or processing anything.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        ValueError: If `name` is not 'sem' or 'tem'.
    """
    # Validate the name first. Otherwise the URLS / CHECKSUMS lookup would raise a KeyError
    # before this informative error could be reached.
    if name not in URLS:
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    # download and unzip the data
    url, checksum = URLS[name], CHECKSUMS[name]
    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    try:
        util.unzip(tmp_path, out_path, remove=True)
        if name == "sem":
            _preprocess_sem_data(out_path)
        else:
            _preprocess_tem_data(out_path)
    except BaseException:
        # Also catches interrupts. A partially extracted / preprocessed folder must not survive,
        # because its existence alone makes later calls return it as if it were complete.
        rmtree(out_path, ignore_errors=True)
        raise

    return out_path


def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Dataset:
    """Get dataset for segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem',
            or a list / tuple of these names to combine the data types.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
            `True` encodes the class ids 0, 1, 2. An int `n` encodes the class ids `0, ..., n - 1`,
            and a list / tuple specifies the class ids explicitly.
        val_fraction: The fraction of the data to use for validation.
            If None, all data is used and `split` is ignored.
        split: The data split. Either 'train' or 'val'. Required if `val_fraction` is given.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        TypeError: If `name` is not a string, list or tuple.
        ValueError: If `val_fraction` is given without `split` being 'train' or 'val',
            or if `one_hot_encoding` has an unsupported type.
    """
    if isinstance(name, str):
        name = [name]
    if not isinstance(name, (tuple, list)):
        raise TypeError(f"Invalid type {type(name)} for 'name', expected str, list or tuple.")
    # Check the split before any download happens, so that an unknown split is rejected
    # instead of silently selecting the validation data.
    if val_fraction is not None and split not in ("train", "val"):
        raise ValueError(f"Invalid split {split}, expected 'train' or 'val' when 'val_fraction' is set.")

    all_paths = []
    for nn in name:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            # the first (1 - val_fraction) part of the sorted files is used for training
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    if one_hot_encoding:
        # bool must be checked before int, because bool is a subclass of int
        if isinstance(one_hot_encoding, bool):
            # add transformation to go from [0, 1, 2] to one hot encoding
            class_ids = [0, 1, 2]
        elif isinstance(one_hot_encoding, int):
            class_ids = list(range(one_hot_encoding))
        elif isinstance(one_hot_encoding, (list, tuple)):
            class_ids = list(one_hot_encoding)
        else:
            raise ValueError(
                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
            )
        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    raw_key, label_key = "raw", "labels"
    return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs)


def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs
) -> DataLoader:
    """Get dataloader for the segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # route each keyword argument either to the dataset or to the DataLoader
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_axondeepseg_dataset(
        path, name, patch_shape, download=download, one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction, split=split, **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS =
{'sem': 'https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip', 'tem': 'https://osf.io/download/uewd9'}
CHECKSUMS =
{'sem': 'd334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426', 'tem': 'e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95'}
def
get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
    """Download the axondeepseg data.

    The archive for `name` is downloaded (with checksum verification), unzipped into
    `path/name` and converted to HDF5 files. If `path/name` already exists, it is returned
    without downloading or processing anything.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        ValueError: If `name` is not 'sem' or 'tem'.
    """
    # Validate the name first. Otherwise the URLS / CHECKSUMS lookup would raise a KeyError
    # before this informative error could be reached.
    if name not in URLS:
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    # download and unzip the data
    url, checksum = URLS[name], CHECKSUMS[name]
    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    try:
        util.unzip(tmp_path, out_path, remove=True)
        if name == "sem":
            _preprocess_sem_data(out_path)
        else:
            _preprocess_tem_data(out_path)
    except BaseException:
        # Also catches interrupts. A partially extracted / preprocessed folder must not survive,
        # because its existence alone makes later calls return it as if it were complete.
        rmtree(out_path, ignore_errors=True)
        raise

    return out_path
Download the axondeepseg data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the downloaded data.
def
get_axondeepseg_dataset( path: Union[str, os.PathLike], name: str, patch_shape: Tuple[int, int], download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[str] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Dataset:
    """Get dataset for segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem',
            or a list / tuple of these names to combine the data types.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
            `True` encodes the class ids 0, 1, 2. An int `n` encodes the class ids `0, ..., n - 1`,
            and a list / tuple specifies the class ids explicitly.
        val_fraction: The fraction of the data to use for validation.
            If None, all data is used and `split` is ignored.
        split: The data split. Either 'train' or 'val'. Required if `val_fraction` is given.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        TypeError: If `name` is not a string, list or tuple.
        ValueError: If `val_fraction` is given without `split` being 'train' or 'val',
            or if `one_hot_encoding` has an unsupported type.
    """
    if isinstance(name, str):
        name = [name]
    if not isinstance(name, (tuple, list)):
        raise TypeError(f"Invalid type {type(name)} for 'name', expected str, list or tuple.")
    # Check the split before any download happens, so that an unknown split is rejected
    # instead of silently selecting the validation data.
    if val_fraction is not None and split not in ("train", "val"):
        raise ValueError(f"Invalid split {split}, expected 'train' or 'val' when 'val_fraction' is set.")

    all_paths = []
    for nn in name:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            # the first (1 - val_fraction) part of the sorted files is used for training
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    if one_hot_encoding:
        # bool must be checked before int, because bool is a subclass of int
        if isinstance(one_hot_encoding, bool):
            # add transformation to go from [0, 1, 2] to one hot encoding
            class_ids = [0, 1, 2]
        elif isinstance(one_hot_encoding, int):
            class_ids = list(range(one_hot_encoding))
        elif isinstance(one_hot_encoding, (list, tuple)):
            class_ids = list(one_hot_encoding)
        else:
            raise ValueError(
                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
            )
        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    raw_key, label_key = "raw", "labels"
    return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs)
Get dataset for segmentation of myelinated axons.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- one_hot_encoding: Whether to return the labels as one hot encoding.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_axondeepseg_loader( path: Union[str, os.PathLike], name: str, patch_shape: Tuple[int, int], batch_size: int, download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[str] = None, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
    **kwargs
) -> DataLoader:
    """Build a PyTorch DataLoader over the AxonDeepSeg data for myelinated axon segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments; those accepted by
            `torch_em.default_segmentation_dataset` go to the dataset, the rest to the DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    axon_dataset = get_axondeepseg_dataset(
        path,
        name,
        patch_shape,
        download=download,
        one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction,
        split=split,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(axon_dataset, batch_size, **dataloader_kwargs)
Get dataloader for the segmentation of myelinated axons.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- download: Whether to download the data if it is not present.
- one_hot_encoding: Whether to return the labels as one hot encoding.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.