torch_em.data.datasets.electron_microscopy.axondeepseg
AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM. It contains two different data types: TEM and SEM.
The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. Please cite this publication if you use the dataset in your research.
"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
It contains two different data types: TEM and SEM.

The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
Please cite this publication if you use the dataset in your research.
"""

import os
from glob import glob
from shutil import rmtree
from typing import Optional, Tuple, Union, Literal, List

import imageio
import numpy as np

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util

URLS = {
    "sem": "https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip",
    "tem": "https://osf.io/download/uewd9"
}
CHECKSUMS = {
    "sem": "d334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426",
    "tem": "e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95"
}


def _preprocess_sem_data(out_path):
    """Convert the unpacked SEM data into one hdf5 file per image.

    Each file `sem_data_<i>.h5` in `out_path` gets a "raw" and a "labels" dataset.
    The unpacked source folder is removed afterwards.

    Args:
        out_path: The folder that contains the unpacked archive and that receives the hdf5 files.
    """
    import h5py

    # preprocess the data to get it to a better data format
    data_root = os.path.join(out_path, "data_axondeepseg_sem-master")
    assert os.path.exists(data_root)

    # get the raw data paths
    raw_folders = sorted(glob(os.path.join(data_root, "sub-rat*")))
    raw_paths = []
    for folder in raw_folders:
        raw_paths.extend(sorted(glob(os.path.join(folder, "micr", "*.png"))))

    # get the label paths
    label_folders = sorted(glob(os.path.join(data_root, "derivatives", "labels", "sub-rat*")))
    label_paths = []
    for folder in label_folders:
        label_paths.extend(sorted(glob(os.path.join(folder, "micr", "*axonmyelin-manual.png"))))
    # raw images and labels are paired by their position in the sorted lists
    assert len(raw_paths) == len(label_paths), f"{len(raw_paths)}, {len(label_paths)}"

    # process raw data and labels
    for i, (rp, lp) in enumerate(zip(raw_paths, label_paths)):
        outp = os.path.join(out_path, f"sem_data_{i}.h5")
        with h5py.File(outp, "w") as f:

            # raw data: invert to match tem em intensities
            raw = imageio.imread(rp)
            assert np.dtype(raw.dtype) == np.dtype("uint8")
            if raw.ndim == 3:  # (one of the images is RGBA)
                # Average the color channels and ignore a potential alpha channel.
                # The slice `[..., :-3]` used before kept only the first channel of an RGBA image
                # and was empty for an RGB image, which resulted in an all-NaN mean.
                raw = np.mean(raw[..., :3], axis=-1)
            raw = 255 - raw
            f.create_dataset("raw", data=raw, compression="gzip")

            # labels: map from
            # 0 -> 0
            # 127, 128 -> 1
            # 255 -> 2
            labels = imageio.imread(lp)
            assert labels.shape == raw.shape, f"{labels.shape}, {raw.shape}"
            label_vals = np.unique(labels)
            # 127, 128: both myelin labels, 130, 233: noise
            assert len(np.setdiff1d(label_vals, [0, 127, 128, 130, 233, 255])) == 0, f"{label_vals}"
            # the noise values are not mapped explicitly and thus stay 0
            new_labels = np.zeros_like(labels)
            new_labels[labels == 127] = 1
            new_labels[labels == 128] = 1
            new_labels[labels == 255] = 2
            f.create_dataset("labels", data=new_labels, compression="gzip")

    # clean up
    rmtree(data_root)


def _preprocess_tem_data(out_path):
    """Convert the unpacked TEM data into one hdf5 file per sample folder.

    Each file `tem_<i>.h5` in `out_path` gets a "raw" and a "labels" dataset.
    The unpacked source folder is removed afterwards.

    Args:
        out_path: The folder that contains the unpacked archive and that receives the hdf5 files.
    """
    import h5py

    data_root = os.path.join(out_path, "TEM_dataset")
    # Sort the folder names: the order returned by os.listdir is arbitrary, so the numbering of
    # the output files (and thereby any split derived from it) would otherwise differ between systems.
    folder_names = sorted(os.listdir(data_root))
    folders = [
        os.path.join(data_root, fname) for fname in folder_names
        if os.path.isdir(os.path.join(data_root, fname))
    ]
    for i, folder in enumerate(folders):
        data_out = os.path.join(out_path, f"tem_{i}.h5")
        with h5py.File(data_out, "w") as f:
            im = imageio.imread(os.path.join(folder, "image.png"))
            f.create_dataset("raw", data=im, compression="gzip")

            # labels: map from
            # 0 -> 0
            # 128 -> 1
            # 255 -> 2
            # the rest are noise
            labels = imageio.imread(os.path.join(folder, "mask.png"))
            new_labels = np.zeros_like(labels)
            new_labels[labels == 128] = 1
            new_labels[labels == 255] = 2
            f.create_dataset("labels", data=new_labels, compression="gzip")

    # clean up
    rmtree(data_root)


def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
    """Download the AxonDeepSeg data.

    If the folder `<path>/<name>` already exists it is returned without downloading anything.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        ValueError: If `name` is neither 'sem' nor 'tem'.
    """
    # Validate the name first: the lookups in URLS / CHECKSUMS below would otherwise fail with a
    # KeyError, so that the intended ValueError could never be raised.
    if name not in ("sem", "tem"):
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    url, checksum = URLS[name], CHECKSUMS[name]
    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    # download and unzip the data
    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    util.unzip(tmp_path, out_path, remove=True)

    if name == "sem":
        _preprocess_sem_data(out_path)
    else:
        _preprocess_tem_data(out_path)

    return out_path


def get_axondeepseg_paths(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    download: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
) -> List[str]:
    """Get paths to the AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
            A list or tuple of names is also accepted; the paths of all of them are returned.
        download: Whether to download the data if it is not present.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.

    Returns:
        List of paths for all the data.

    Raises:
        ValueError: If `val_fraction` is given without `split`.
    """
    if val_fraction is not None and split is None:
        raise ValueError("'split' must be given ('train' or 'val') if 'val_fraction' is set.")

    # A single name given as string would otherwise be iterated character by character.
    names = [name] if isinstance(name, str) else list(name)

    all_paths = []
    for nn in names:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            # the first part of the sorted files is used for training, the remainder for validation
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    return all_paths


def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal['train', 'val']] = None,
    **kwargs,
) -> Dataset:
    """Get dataset for segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
            `True` uses the class ids [0, 1, 2], an integer n the ids 0, ..., n - 1
            and a list or tuple is used as the class ids directly.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    if isinstance(name, str):
        name = [name]
    assert isinstance(name, (tuple, list))

    all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split)

    if one_hot_encoding:
        # bool has to be checked before int, because bool is a subclass of int
        if isinstance(one_hot_encoding, bool):
            # add transformation to go from [0, 1, 2] to one hot encoding
            class_ids = [0, 1, 2]
        elif isinstance(one_hot_encoding, int):
            class_ids = list(range(one_hot_encoding))
        elif isinstance(one_hot_encoding, (list, tuple)):
            class_ids = list(one_hot_encoding)
        else:
            raise ValueError(
                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
            )
        label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids)
        msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

    return torch_em.default_segmentation_dataset(
        raw_paths=all_paths,
        raw_key="raw",
        label_paths=all_paths,
        label_key="labels",
        patch_shape=patch_shape,
        **kwargs
    )


def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal["train", "val"]] = None,
    **kwargs
) -> DataLoader:
    """Get dataloader for the segmentation of myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_axondeepseg_dataset(
        path, name, patch_shape, download=download, one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction, split=split, **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS =
{'sem': 'https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip', 'tem': 'https://osf.io/download/uewd9'}
CHECKSUMS =
{'sem': 'd334cbacf548f78ce8dd4a597bf86b884bd15a47a230a0ccc46e1ffa94d58426', 'tem': 'e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95'}
def
get_axondeepseg_data( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False) -> str:
def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
    """Download the AxonDeepSeg data.

    If the folder `<path>/<name>` already exists it is returned without downloading anything.
    Otherwise the archive is fetched and unpacked via the `util` helpers and converted by the
    preprocessing function that belongs to the requested data type.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.

    Raises:
        ValueError: If `name` is neither 'sem' nor 'tem'.
    """
    # Validate the name first: the lookups in URLS / CHECKSUMS below would otherwise fail with a
    # KeyError, so that the intended ValueError could never be raised.
    if name not in ("sem", "tem"):
        raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.")

    url, checksum = URLS[name], CHECKSUMS[name]
    os.makedirs(path, exist_ok=True)
    out_path = os.path.join(path, name)
    if os.path.exists(out_path):
        return out_path

    # download and unzip the data
    tmp_path = os.path.join(path, f"{name}.zip")
    util.download_source(tmp_path, url, download, checksum=checksum)
    util.unzip(tmp_path, out_path, remove=True)

    if name == "sem":
        _preprocess_sem_data(out_path)
    else:
        _preprocess_tem_data(out_path)

    return out_path
Download the AxonDeepSeg data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the downloaded data.
def
get_axondeepseg_paths( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False, val_fraction: Optional[float] = None, split: Optional[str] = None) -> List[str]:
def get_axondeepseg_paths(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    download: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
) -> List[str]:
    """Get paths to the AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
            A list or tuple of names is also accepted; the paths of all of them are returned.
        download: Whether to download the data if it is not present.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.

    Returns:
        List of paths for all the data.

    Raises:
        ValueError: If `val_fraction` is given without `split`.
    """
    # Raise instead of assert, so that the check is also active when running with `python -O`.
    if val_fraction is not None and split is None:
        raise ValueError("'split' must be given ('train' or 'val') if 'val_fraction' is set.")

    # A single name given as string would otherwise be iterated character by character.
    names = [name] if isinstance(name, str) else list(name)

    all_paths = []
    for nn in names:
        data_root = get_axondeepseg_data(path, nn, download)
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            # the first part of the sorted files is used for training, the remainder for validation
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    return all_paths
Get paths to the AxonDeepSeg data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- download: Whether to download the data if it is not present.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
Returns:
List of paths for all the data.
def
get_axondeepseg_dataset( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_axondeepseg_dataset(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal['train', 'val']] = None,
    **kwargs,
) -> Dataset:
    """Create the dataset for segmenting myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
            `True` selects the class ids [0, 1, 2], an integer n the ids 0, ..., n - 1
            and a list or tuple is taken as the class ids directly.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    # a single dataset name is wrapped, so that the path lookup always receives a sequence
    dataset_names = [name] if isinstance(name, str) else name
    assert isinstance(dataset_names, (tuple, list))

    data_paths = get_axondeepseg_paths(path, dataset_names, download, val_fraction, split)

    if one_hot_encoding:
        # derive the class ids; bool must be tested before int, because bool is a subclass of int
        if isinstance(one_hot_encoding, bool):
            class_ids = [0, 1, 2]
        elif isinstance(one_hot_encoding, int):
            class_ids = list(range(one_hot_encoding))
        elif isinstance(one_hot_encoding, (list, tuple)):
            class_ids = list(one_hot_encoding)
        else:
            raise ValueError(
                f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list."
            )

        # register the one-hot transform as label transform of the dataset
        kwargs = util.update_kwargs(
            kwargs,
            "label_transform",
            torch_em.transform.label.OneHotTransform(class_ids=class_ids),
            msg="'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden.",
        )

    # raw data and labels are stored in the same hdf5 files, under different keys
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=data_paths,
        raw_key="raw",
        label_paths=data_paths,
        label_key="labels",
        patch_shape=patch_shape,
        **kwargs
    )
    return dataset
Get dataset for segmentation of myelinated axons.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- one_hot_encoding: Whether to return the labels as one hot encoding.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_axondeepseg_loader( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], batch_size: int, download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_axondeepseg_loader(
    path: Union[str, os.PathLike],
    name: Literal["sem", "tem"],
    patch_shape: Tuple[int, int],
    batch_size: int,
    download: bool = False,
    one_hot_encoding: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[Literal["train", "val"]] = None,
    **kwargs
) -> DataLoader:
    """Create the dataloader for segmenting myelinated axons.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        download: Whether to download the data if it is not present.
        one_hot_encoding: Whether to return the labels as one hot encoding.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # separate the keyword arguments for the dataset from the remaining ones, which go to the loader
    split_result = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset_kwargs, dataloader_kwargs = split_result

    axon_dataset = get_axondeepseg_dataset(
        path,
        name,
        patch_shape,
        download=download,
        one_hot_encoding=one_hot_encoding,
        val_fraction=val_fraction,
        split=split,
        **dataset_kwargs
    )
    loader = torch_em.get_data_loader(axon_dataset, batch_size, **dataloader_kwargs)
    return loader
Get dataloader for the segmentation of myelinated axons.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- download: Whether to download the data if it is not present.
- one_hot_encoding: Whether to return the labels as one hot encoding.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.