torch_em.data.datasets.electron_microscopy.axondeepseg
AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM. It contains two different data types: TEM and SEM.
The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. Please cite this publication if you use the dataset in your research.
1"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM. 2It contains two different data types: TEM and SEM. 3 4The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. 5Please cite this publication if you use the dataset in your research. 6""" 7 8import os 9from glob import glob 10from shutil import rmtree 11from typing import Optional, Tuple, Union, Literal, List 12 13import imageio 14import numpy as np 15 16from torch.utils.data import Dataset, DataLoader 17 18import torch_em 19 20from .. import util 21 22URLS = { 23 "sem": "https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip", 24 "tem": "https://osf.io/download/uewd9" 25} 26CHECKSUMS = { 27 "sem": "12f2f03834c41720badf00131bb7b7a2127e532cf78e01fbea398e1ff800779b", 28 "tem": "e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95" 29} 30 31 32def _preprocess_sem_data(out_path): 33 import h5py 34 35 # preprocess the data to get it to a better data format 36 data_root = os.path.join(out_path, "data_axondeepseg_sem-master") 37 assert os.path.exists(data_root) 38 39 # get the raw data paths 40 raw_folders = glob(os.path.join(data_root, "sub-rat*")) 41 raw_folders.sort() 42 raw_paths = [] 43 for folder in raw_folders: 44 paths = glob(os.path.join(folder, "micr", "*.png")) 45 paths.sort() 46 raw_paths.extend(paths) 47 48 # get the label paths 49 label_folders = glob(os.path.join( 50 data_root, "derivatives", "labels", "sub-rat*" 51 )) 52 label_folders.sort() 53 label_paths = [] 54 for folder in label_folders: 55 paths = glob(os.path.join(folder, "micr", "*axonmyelin-manual.png")) 56 paths.sort() 57 label_paths.extend(paths) 58 assert len(raw_paths) == len(label_paths), f"{len(raw_paths)}, {len(label_paths)}" 59 60 # process raw data and labels 61 for i, (rp, lp) in enumerate(zip(raw_paths, label_paths)): 62 outp = os.path.join(out_path, f"sem_data_{i}.h5") 63 with h5py.File(outp, "w") as f: 64 65 # raw data: invert to match tem em intensities 
66 raw = imageio.imread(rp) 67 assert np.dtype(raw.dtype) == np.dtype("uint8") 68 if raw.ndim == 3: # Some images have extra channels (RGBA or grayscale+alpha). 69 raw = raw[..., 0] 70 raw = 255 - raw 71 f.create_dataset("raw", data=raw, compression="gzip") 72 73 # labels: map from 74 # 0 -> 0 75 # 127, 128 -> 1 76 # 255 -> 2 77 labels = imageio.imread(lp) 78 if labels.ndim == 3: # Some labels have an extra alpha channel. 79 labels = labels[..., 0] 80 assert labels.shape == raw.shape, f"{labels.shape}, {raw.shape}" 81 label_vals = np.unique(labels) 82 # 127, 128: both myelin labels, 130, 233: noise 83 assert len(np.setdiff1d(label_vals, [0, 127, 128, 130, 233, 255])) == 0, f"{label_vals}" 84 new_labels = np.zeros_like(labels) 85 new_labels[labels == 127] = 1 86 new_labels[labels == 128] = 1 87 new_labels[labels == 255] = 2 88 f.create_dataset("labels", data=new_labels, compression="gzip") 89 90 # clean up 91 rmtree(data_root) 92 93 94def _preprocess_tem_data(out_path): 95 import h5py 96 97 data_root = os.path.join(out_path, "TEM_dataset") 98 folder_names = os.listdir(data_root) 99 folders = [os.path.join(data_root, fname) for fname in folder_names 100 if os.path.isdir(os.path.join(data_root, fname))] 101 for i, folder in enumerate(folders): 102 data_out = os.path.join(out_path, f"tem_{i}.h5") 103 with h5py.File(data_out, "w") as f: 104 im = imageio.imread(os.path.join(folder, "image.png")) 105 f.create_dataset("raw", data=im, compression="gzip") 106 107 # labels: map from 108 # 0 -> 0 109 # 128 -> 1 110 # 255 -> 2 111 # the rest are noise 112 labels = imageio.imread(os.path.join(folder, "mask.png")) 113 new_labels = np.zeros_like(labels) 114 new_labels[labels == 128] = 1 115 new_labels[labels == 255] = 2 116 f.create_dataset("labels", data=new_labels, compression="gzip") 117 118 # clean up 119 rmtree(data_root) 120 121 122def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str: 123 """Download the 
AxonDeepSeg data. 124 125 Args: 126 path: Filepath to a folder where the downloaded data will be saved. 127 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 128 download: Whether to download the data if it is not present. 129 130 Returns: 131 The filepath for the downloaded data. 132 """ 133 134 # download and unzip the data 135 url, checksum = URLS[name], CHECKSUMS[name] 136 os.makedirs(path, exist_ok=True) 137 out_path = os.path.join(path, name) 138 if os.path.exists(out_path): 139 return out_path 140 141 tmp_path = os.path.join(path, f"{name}.zip") 142 util.download_source(tmp_path, url, download, checksum=checksum) 143 util.unzip(tmp_path, out_path, remove=True) 144 145 if name == "sem": 146 _preprocess_sem_data(out_path) 147 elif name == "tem": 148 _preprocess_tem_data(out_path) 149 else: 150 raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.") 151 152 return out_path 153 154 155def get_axondeepseg_paths( 156 path: Union[str, os.PathLike], 157 name: Literal["sem", "tem"], 158 download: bool = False, 159 val_fraction: Optional[float] = None, 160 split: Optional[str] = None, 161) -> List[str]: 162 """Get paths to the AxonDeepSeg data. 163 164 Args: 165 path: Filepath to a folder where the downloaded data will be saved. 166 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 167 download: Whether to download the data if it is not present. 168 val_fraction: The fraction of the data to use for validation. 169 split: The data split. Either 'train' or 'val'. 170 171 Returns: 172 List of paths for all the data. 
173 """ 174 all_paths = [] 175 for nn in name: 176 data_root = get_axondeepseg_data(path, nn, download) 177 paths = glob(os.path.join(data_root, "*.h5")) 178 paths.sort() 179 if val_fraction is not None: 180 assert split is not None 181 n_samples = int(len(paths) * (1 - val_fraction)) 182 paths = paths[:n_samples] if split == "train" else paths[n_samples:] 183 all_paths.extend(paths) 184 185 return all_paths 186 187 188def get_axondeepseg_dataset( 189 path: Union[str, os.PathLike], 190 name: Literal["sem", "tem"], 191 patch_shape: Tuple[int, int], 192 download: bool = False, 193 one_hot_encoding: bool = False, 194 val_fraction: Optional[float] = None, 195 split: Optional[Literal['train', 'val']] = None, 196 **kwargs, 197) -> Dataset: 198 """Get dataset for segmentation of myelinated axons. 199 200 Args: 201 path: Filepath to a folder where the downloaded data will be saved. 202 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 203 patch_shape: The patch shape to use for training. 204 download: Whether to download the data if it is not present. 205 one_hot_encoding: Whether to return the labels as one hot encoding. 206 val_fraction: The fraction of the data to use for validation. 207 split: The data split. Either 'train' or 'val'. 208 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 209 210 Returns: 211 The segmentation dataset. 
212 """ 213 if isinstance(name, str): 214 name = [name] 215 assert isinstance(name, (tuple, list)) 216 217 all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split) 218 219 if one_hot_encoding: 220 if isinstance(one_hot_encoding, bool): 221 # add transformation to go from [0, 1, 2] to one hot encoding 222 class_ids = [0, 1, 2] 223 elif isinstance(one_hot_encoding, int): 224 class_ids = list(range(one_hot_encoding)) 225 elif isinstance(one_hot_encoding, (list, tuple)): 226 class_ids = list(one_hot_encoding) 227 else: 228 raise ValueError( 229 f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list." 230 ) 231 label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids) 232 msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden." 233 kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg) 234 235 return torch_em.default_segmentation_dataset( 236 raw_paths=all_paths, 237 raw_key="raw", 238 label_paths=all_paths, 239 label_key="labels", 240 patch_shape=patch_shape, 241 **kwargs 242 ) 243 244 245def get_axondeepseg_loader( 246 path: Union[str, os.PathLike], 247 name: Literal["sem", "tem"], 248 patch_shape: Tuple[int, int], 249 batch_size: int, 250 download: bool = False, 251 one_hot_encoding: bool = False, 252 val_fraction: Optional[float] = None, 253 split: Optional[Literal["train", "val"]] = None, 254 **kwargs 255) -> DataLoader: 256 """Get dataloader for the segmentation of myelinated axons. 257 258 Args: 259 path: Filepath to a folder where the downloaded data will be saved. 260 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 261 patch_shape: The patch shape to use for training. 262 batch_size: The batch size for training. 263 download: Whether to download the data if it is not present. 264 one_hot_encoding: Whether to return the labels as one hot encoding. 
265 val_fraction: The fraction of the data to use for validation. 266 split: The data split. Either 'train' or 'val'. 267 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 268 269 Returns: 270 The PyTorch DataLoader. 271 """ 272 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 273 dataset = get_axondeepseg_dataset( 274 path, name, patch_shape, download=download, one_hot_encoding=one_hot_encoding, 275 val_fraction=val_fraction, split=split, **ds_kwargs 276 ) 277 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS =
{'sem': 'https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip', 'tem': 'https://osf.io/download/uewd9'}
CHECKSUMS =
{'sem': '12f2f03834c41720badf00131bb7b7a2127e532cf78e01fbea398e1ff800779b', 'tem': 'e4657280808f3b80d3bf1fba87d1cbbf2455f519baf1a7b16d2ddf2e54739a95'}
def
get_axondeepseg_data( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False) -> str:
123def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str: 124 """Download the AxonDeepSeg data. 125 126 Args: 127 path: Filepath to a folder where the downloaded data will be saved. 128 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 129 download: Whether to download the data if it is not present. 130 131 Returns: 132 The filepath for the downloaded data. 133 """ 134 135 # download and unzip the data 136 url, checksum = URLS[name], CHECKSUMS[name] 137 os.makedirs(path, exist_ok=True) 138 out_path = os.path.join(path, name) 139 if os.path.exists(out_path): 140 return out_path 141 142 tmp_path = os.path.join(path, f"{name}.zip") 143 util.download_source(tmp_path, url, download, checksum=checksum) 144 util.unzip(tmp_path, out_path, remove=True) 145 146 if name == "sem": 147 _preprocess_sem_data(out_path) 148 elif name == "tem": 149 _preprocess_tem_data(out_path) 150 else: 151 raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.") 152 153 return out_path
Download the AxonDeepSeg data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the downloaded data.
def
get_axondeepseg_paths( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], download: bool = False, val_fraction: Optional[float] = None, split: Optional[str] = None) -> List[str]:
156def get_axondeepseg_paths( 157 path: Union[str, os.PathLike], 158 name: Literal["sem", "tem"], 159 download: bool = False, 160 val_fraction: Optional[float] = None, 161 split: Optional[str] = None, 162) -> List[str]: 163 """Get paths to the AxonDeepSeg data. 164 165 Args: 166 path: Filepath to a folder where the downloaded data will be saved. 167 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 168 download: Whether to download the data if it is not present. 169 val_fraction: The fraction of the data to use for validation. 170 split: The data split. Either 'train' or 'val'. 171 172 Returns: 173 List of paths for all the data. 174 """ 175 all_paths = [] 176 for nn in name: 177 data_root = get_axondeepseg_data(path, nn, download) 178 paths = glob(os.path.join(data_root, "*.h5")) 179 paths.sort() 180 if val_fraction is not None: 181 assert split is not None 182 n_samples = int(len(paths) * (1 - val_fraction)) 183 paths = paths[:n_samples] if split == "train" else paths[n_samples:] 184 all_paths.extend(paths) 185 186 return all_paths
Get paths to the AxonDeepSeg data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- download: Whether to download the data if it is not present.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
Returns:
List of paths for all the data.
def
get_axondeepseg_dataset( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
189def get_axondeepseg_dataset( 190 path: Union[str, os.PathLike], 191 name: Literal["sem", "tem"], 192 patch_shape: Tuple[int, int], 193 download: bool = False, 194 one_hot_encoding: bool = False, 195 val_fraction: Optional[float] = None, 196 split: Optional[Literal['train', 'val']] = None, 197 **kwargs, 198) -> Dataset: 199 """Get dataset for segmentation of myelinated axons. 200 201 Args: 202 path: Filepath to a folder where the downloaded data will be saved. 203 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 204 patch_shape: The patch shape to use for training. 205 download: Whether to download the data if it is not present. 206 one_hot_encoding: Whether to return the labels as one hot encoding. 207 val_fraction: The fraction of the data to use for validation. 208 split: The data split. Either 'train' or 'val'. 209 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 210 211 Returns: 212 The segmentation dataset. 213 """ 214 if isinstance(name, str): 215 name = [name] 216 assert isinstance(name, (tuple, list)) 217 218 all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split) 219 220 if one_hot_encoding: 221 if isinstance(one_hot_encoding, bool): 222 # add transformation to go from [0, 1, 2] to one hot encoding 223 class_ids = [0, 1, 2] 224 elif isinstance(one_hot_encoding, int): 225 class_ids = list(range(one_hot_encoding)) 226 elif isinstance(one_hot_encoding, (list, tuple)): 227 class_ids = list(one_hot_encoding) 228 else: 229 raise ValueError( 230 f"Invalid value {one_hot_encoding} passed for 'one_hot_encoding', expect bool, int or list." 231 ) 232 label_transform = torch_em.transform.label.OneHotTransform(class_ids=class_ids) 233 msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden." 
234 kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg) 235 236 return torch_em.default_segmentation_dataset( 237 raw_paths=all_paths, 238 raw_key="raw", 239 label_paths=all_paths, 240 label_key="labels", 241 patch_shape=patch_shape, 242 **kwargs 243 )
Get dataset for segmentation of myelinated axons.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- one_hot_encoding: Whether to return the labels as one hot encoding.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
- kwargs: Additional keyword arguments for
torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
def
get_axondeepseg_loader( path: Union[str, os.PathLike], name: Literal['sem', 'tem'], patch_shape: Tuple[int, int], batch_size: int, download: bool = False, one_hot_encoding: bool = False, val_fraction: Optional[float] = None, split: Optional[Literal['train', 'val']] = None, **kwargs) -> torch.utils.data.dataloader.DataLoader:
246def get_axondeepseg_loader( 247 path: Union[str, os.PathLike], 248 name: Literal["sem", "tem"], 249 patch_shape: Tuple[int, int], 250 batch_size: int, 251 download: bool = False, 252 one_hot_encoding: bool = False, 253 val_fraction: Optional[float] = None, 254 split: Optional[Literal["train", "val"]] = None, 255 **kwargs 256) -> DataLoader: 257 """Get dataloader for the segmentation of myelinated axons. 258 259 Args: 260 path: Filepath to a folder where the downloaded data will be saved. 261 name: The name of the dataset to download. Can be either 'sem' or 'tem'. 262 patch_shape: The patch shape to use for training. 263 batch_size: The batch size for training. 264 download: Whether to download the data if it is not present. 265 one_hot_encoding: Whether to return the labels as one hot encoding. 266 val_fraction: The fraction of the data to use for validation. 267 split: The data split. Either 'train' or 'val'. 268 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 269 270 Returns: 271 The PyTorch DataLoader. 272 """ 273 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 274 dataset = get_axondeepseg_dataset( 275 path, name, patch_shape, download=download, one_hot_encoding=one_hot_encoding, 276 val_fraction=val_fraction, split=split, **ds_kwargs 277 ) 278 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get dataloader for the segmentation of myelinated axons.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- name: The name of the dataset to download. Can be either 'sem' or 'tem'.
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- download: Whether to download the data if it is not present.
- one_hot_encoding: Whether to return the labels as one hot encoding.
- val_fraction: The fraction of the data to use for validation.
- split: The data split. Either 'train' or 'val'.
- kwargs: Additional keyword arguments for
`torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.