torch_em.data.datasets.electron_microscopy.axonem

AxonEM is a dataset for segmenting axons in electron microscopy. It contains two large annotated volumes, one from mouse cortex and the other from human cortex. This dataset was used for the AxonEM Challenge: https://axonem.grand-challenge.org/.

Please cite the publication https://arxiv.org/abs/2107.05451 if you use this dataset for your research.

  1"""AxonEM is a dataset for segmenting axons in electron microscopy.
  2It contains two large annotated volumes, one from mouse cortex, the other from human cortex.
  3This dataset was used for the AxonEM Challenge: https://axonem.grand-challenge.org/.
  4
  5Please cite the publication https://arxiv.org/abs/2107.05451 if you use this dataset for your research.
  6"""
  7
  8import os
  9from glob import glob
 10from typing import Union, Sequence, List, Tuple
 11
 12from torch.utils.data import Dataset, DataLoader
 13
 14import torch_em
 15
 16from .. import util
 17
 18
 19URLS = {
 20    "human": "https://huggingface.co/datasets/pytc/AxonEM/resolve/main/EM30-H-train-9vol-pad-20-512-512.zip",
 21    "mouse": "https://huggingface.co/datasets/pytc/AxonEM/resolve/main/EM30-M-train-9vol-pad-20-512-512.zip",
 22}
 23
 24CHECKSUMS = {
 25    "human": "0b53d155ff62f5e24c552bf90adce329fcf9a8fefd5c697f8bcd0312a61fda60",
 26    "mouse": "dae06b5dabe388ab7a0ff4e51548174f041a338d0d06bd665586aa7fdd43bac2",
 27}
 28
 29
 30def get_axonem_data(path: Union[os.PathLike, str], samples: Sequence[str], download: bool = False):
 31    """Download the AxonEM training data.
 32
 33    Args:
 34        path: Filepath to a folder where the downloaded data will be saved.
 35        samples: The samples to download. The available samples are 'human' and 'mouse'.
 36        download: Whether to download the data if it is not present.
 37    """
 38    if isinstance(samples, str):
 39        samples = [samples]
 40
 41    assert len(set(samples) - {"human", "mouse"}) == 0, f"{samples}"
 42    os.makedirs(path, exist_ok=True)
 43
 44    for sample in samples:
 45        dst = os.path.join(path, sample)
 46        if os.path.exists(dst):
 47            continue
 48
 49        os.makedirs(dst, exist_ok=True)
 50
 51        # Download the zipfile.
 52        zip_path = os.path.join(path, f"{sample}.zip")
 53        util.download_source(path=zip_path, url=URLS[sample], download=download, checksum=CHECKSUMS[sample])
 54
 55        # Extract the h5 crops from the zipfile.
 56        util.unzip(zip_path=zip_path, dst=dst, remove=True)
 57
 58        if sample == "mouse":
 59            # NOTE: We need to make a hotfix by removing a crop which does not have masks.
 60            label_path = os.path.join(path, "mouse", "valid_mask.h5")
 61            os.remove(label_path)
 62
 63            # And the additional volume with no corresponding mask.
 64            image_path = os.path.join(path, "mouse", "im_675-800-800_pad.h5")
 65            os.remove(image_path)
 66
 67
 68def get_axonem_paths(
 69    path: Union[os.PathLike, str], samples: Sequence[str], download: bool = False,
 70) -> Tuple[List[str], List[str]]:
 71    """Get paths for the AxonEM training data.
 72
 73    Args:
 74        path: Filepath to a folder where the downloaded data will be saved.
 75        samples: The samples to download. The available samples are 'human' and 'mouse'.
 76        download: Whether to download the data if it is not present.
 77
 78    Returns:
 79        List of filepaths for the image volumes.
 80        List of filepaths for the label volumes.
 81    """
 82    get_axonem_data(path, samples, download)
 83
 84    if isinstance(samples, str):
 85        samples = [samples]
 86
 87    image_paths, label_paths = [], []
 88    for sample in samples:
 89        curr_image_paths = glob(os.path.join(path, sample, "im_*.h5"))
 90        image_paths.extend(curr_image_paths)
 91        label_paths.extend([p.replace("im_", "seg_") for p in curr_image_paths])
 92
 93    return image_paths, label_paths
 94
 95
 96def get_axonem_dataset(
 97    path: Union[os.PathLike, str],
 98    patch_shape: Tuple[int, ...],
 99    samples: Sequence[str] = ("human", "mouse"),
100    download: bool = False,
101    **kwargs
102) -> Dataset:
103    """Get the AxonEM dataset for the segmentation of axons in EM.
104
105    Args:
106        path: Filepath to a folder where the downloaded data will be saved.
107        patch_shape: The patch shape to use for training.
108        samples: The samples to download. The available samples are 'human' and 'mouse'.
109        download: Whether to download the data if it is not present.
110        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
111
112    Returns:
113       The segmentation dataset.
114    """
115    image_paths, label_paths = get_axonem_paths(path, samples, download)
116
117    return torch_em.default_segmentation_dataset(
118        raw_paths=image_paths,
119        raw_key="main",
120        label_paths=label_paths,
121        label_key="main",
122        patch_shape=patch_shape,
123        **kwargs
124    )
125
126
def get_axonem_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    samples: Sequence[str] = ("human", "mouse"),
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the AxonEM dataloader for the segmentation of axons in EM.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        samples: The samples to download. The available samples are 'human' and 'mouse'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the kwargs meant for `default_segmentation_dataset` from those forwarded to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    axonem_dataset = get_axonem_dataset(
        path=path, patch_shape=patch_shape, samples=samples, download=download, **dataset_kwargs
    )
    loader = torch_em.get_data_loader(axonem_dataset, batch_size, **loader_kwargs)
    return loader
# Download locations of the zipped AxonEM training volumes (hosted on Hugging Face), one per sample.
URLS = dict(
    human="https://huggingface.co/datasets/pytc/AxonEM/resolve/main/EM30-H-train-9vol-pad-20-512-512.zip",
    mouse="https://huggingface.co/datasets/pytc/AxonEM/resolve/main/EM30-M-train-9vol-pad-20-512-512.zip",
)
# Checksums used to verify the downloaded zipfiles, keyed by the same sample names as `URLS`.
CHECKSUMS = dict(
    human="0b53d155ff62f5e24c552bf90adce329fcf9a8fefd5c697f8bcd0312a61fda60",
    mouse="dae06b5dabe388ab7a0ff4e51548174f041a338d0d06bd665586aa7fdd43bac2",
)
def get_axonem_data(path: Union[os.PathLike, str], samples: Sequence[str], download: bool = False):
    """Download the AxonEM training data.

    Each sample is extracted into its own subfolder `<path>/<sample>`. A sample whose
    subfolder already exists is treated as complete and is skipped.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        samples: The samples to download. The available samples are 'human' and 'mouse'.
            A single sample name may also be passed as a string.
        download: Whether to download the data if it is not present.
    """
    import shutil

    if isinstance(samples, str):
        samples = [samples]

    invalid_samples = set(samples) - {"human", "mouse"}
    assert len(invalid_samples) == 0, \
        f"Invalid samples {sorted(invalid_samples)}, the available samples are 'human' and 'mouse'."
    os.makedirs(path, exist_ok=True)

    for sample in samples:
        dst = os.path.join(path, sample)
        # The existence of the sample folder is the "already downloaded" marker.
        if os.path.exists(dst):
            continue

        # Download the zipfile. This happens BEFORE `dst` is created, so that a disabled or failed
        # download does not leave an empty folder behind that would make later calls skip this sample.
        zip_path = os.path.join(path, f"{sample}.zip")
        util.download_source(path=zip_path, url=URLS[sample], download=download, checksum=CHECKSUMS[sample])

        os.makedirs(dst, exist_ok=True)
        try:
            # Extract the h5 crops from the zipfile.
            util.unzip(zip_path=zip_path, dst=dst, remove=True)

            if sample == "mouse":
                # NOTE: Hotfix for the mouse data: remove 'valid_mask.h5' and the additional image
                # volume 'im_675-800-800_pad.h5', which has no corresponding mask.
                for fname in ("valid_mask.h5", "im_675-800-800_pad.h5"):
                    extra_path = os.path.join(dst, fname)
                    if os.path.exists(extra_path):
                        os.remove(extra_path)
        except BaseException:
            # Drop the partially populated folder so that the next call retries instead of skipping.
            shutil.rmtree(dst, ignore_errors=True)
            raise

Download the AxonEM training data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • samples: The samples to download. The available samples are 'human' and 'mouse'.
  • download: Whether to download the data if it is not present.
def get_axonem_paths(
    path: Union[os.PathLike, str], samples: Sequence[str], download: bool = False,
) -> Tuple[List[str], List[str]]:
    """Get paths for the AxonEM training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        samples: The samples to download. The available samples are 'human' and 'mouse'.
            A single sample name may also be passed as a string.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image volumes.
        List of filepaths for the label volumes, in the same order as the image volumes.
    """
    get_axonem_data(path, samples, download)

    if isinstance(samples, str):
        samples = [samples]

    image_paths, label_paths = [], []
    for sample in samples:
        # glob returns files in arbitrary order; sort for a deterministic order of the volumes.
        curr_image_paths = sorted(glob(os.path.join(path, sample, "im_*.h5")))
        image_paths.extend(curr_image_paths)
        # Swap the 'im_' prefix of the filename only. Replacing in the full path would also
        # rewrite parent folders whose names contain 'im_' (e.g. '.../sim_data/...').
        label_paths.extend(
            os.path.join(os.path.dirname(p), "seg_" + os.path.basename(p)[len("im_"):])
            for p in curr_image_paths
        )

    return image_paths, label_paths

Get paths for the AxonEM training data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • samples: The samples to download. The available samples are 'human' and 'mouse'.
  • download: Whether to download the data if it is not present.
Returns:

A list of filepaths for the image volumes and a list of filepaths for the corresponding label volumes.

def get_axonem_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    samples: Sequence[str] = ("human", "mouse"),
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the AxonEM segmentation dataset for axons in EM.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        samples: The samples to download. The available samples are 'human' and 'mouse'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_volumes, label_volumes = get_axonem_paths(path, samples, download)

    # Raw data and labels are both read from the "main" key of their h5 files.
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_volumes,
        raw_key="main",
        label_paths=label_volumes,
        label_key="main",
        patch_shape=patch_shape,
        **kwargs,
    )
    return dataset

Get the AxonEM dataset for the segmentation of axons in EM.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: The patch shape to use for training.
  • samples: The samples to download. The available samples are 'human' and 'mouse'.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_axonem_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    samples: Sequence[str] = ("human", "mouse"),
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the AxonEM dataloader for the segmentation of axons in EM.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        samples: The samples to download. The available samples are 'human' and 'mouse'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the kwargs meant for `default_segmentation_dataset` from those forwarded to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    axonem_dataset = get_axonem_dataset(
        path=path, patch_shape=patch_shape, samples=samples, download=download, **dataset_kwargs
    )
    loader = torch_em.get_data_loader(axonem_dataset, batch_size, **loader_kwargs)
    return loader

Get the AxonEM dataloader for the segmentation of axons in EM.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • samples: The samples to download. The available samples are 'human' and 'mouse'.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.