torch_em.data.datasets.electron_microscopy.waenet

The WAE-NET dataset contains seven biological electron microscopy datasets for multi-class semantic segmentation of cellular structures.

The seven sub-datasets cover different cell types and imaging modalities:

- Dataset 1: Human pancreatic carcinoid cell line (FIB-SEM) — background, cytoplasm, nucleus
- Dataset 2: BON cell during interphase (ssTEM) — background, cytoplasm, chromosomes
- Dataset 3: BON cell during mitosis (ssTEM) — background, cytoplasm, nucleus, mitochondria
- Dataset 4: Human T-cell line Jurkat (TEM) — background, cytoplasm, nucleus
- Dataset 5: Primary human T-cell blood (TEM) — background, cytoplasm, nucleus
- Dataset 6: Murine B-cell line J558L (TEM) — background, cytoplasm, nucleus
- Dataset 7: Phytohemagglutinin/IL-2 expanded human T cells (TEM) — background, cytoplasm, nucleus

The data is available at https://doi.org/10.17632/9rdmnn2x4x.1. The dataset was published in https://doi.org/10.1007/s00418-022-02148-3. Please cite this publication if you use the dataset in your research.

  1"""The WAE-NET dataset contains seven biological electron microscopy datasets
  2for multi-class semantic segmentation of cellular structures.
  3
  4The seven sub-datasets cover different cell types and imaging modalities:
  5    - Dataset 1: Human pancreatic carcinoid cell line (FIB-SEM) — background, cytoplasm, nucleus
  6    - Dataset 2: BON cell during interphase (ssTEM) — background, cytoplasm, chromosomes
  7    - Dataset 3: BON cell during mitosis (ssTEM) — background, cytoplasm, nucleus, mitochondria
  8    - Dataset 4: Human T-cell line Jurkat (TEM) — background, cytoplasm, nucleus
  9    - Dataset 5: Primary human T-cell blood (TEM) — background, cytoplasm, nucleus
 10    - Dataset 6: Murine B-cell line J558L (TEM) — background, cytoplasm, nucleus
 11    - Dataset 7: Phytohemagglutinin/IL-2 expanded human T cells (TEM) — background, cytoplasm, nucleus
 12
 13The data is available at https://doi.org/10.17632/9rdmnn2x4x.1.
 14The dataset was published in https://doi.org/10.1007/s00418-022-02148-3.
 15Please cite this publication if you use the dataset in your research.
 16"""
 17
 18import os
 19from glob import glob
 20from shutil import rmtree
 21from typing import List, Literal, Optional, Tuple, Union
 22
 23import numpy as np
 24
 25from torch.utils.data import DataLoader, Dataset
 26
 27import torch_em
 28
 29from .. import util
 30
 31
 32URL = "https://zenodo.org/records/6603083/files/Datasets.zip"
 33CHECKSUM = None
 34
 35# Maps dataset_id -> number of segmentation classes (including background)
 36DATASET_CLASSES = {1: 3, 2: 3, 3: 4, 4: 3, 5: 3, 6: 3, 7: 3}
 37
 38# Maps dataset_id -> ordered list of class names (index 0 = background, etc.)
 39DATASET_CLASS_NAMES = {
 40    1: ["background", "cytoplasm", "nucleus"],
 41    2: ["background", "cytoplasm", "chromosomes"],
 42    3: ["background", "cytoplasm", "nucleus", "mitochondria"],
 43    4: ["background", "cytoplasm", "nucleus"],
 44    5: ["background", "cytoplasm", "nucleus"],
 45    6: ["background", "cytoplasm", "nucleus"],
 46    7: ["background", "cytoplasm", "nucleus"],
 47}
 48
 49
 50def _get_dataset_dir(data_root, dataset_id):
 51    """Find the subdirectory for a given dataset ID inside the extracted archive."""
 52    for dname in (
 53        f"Dataset {dataset_id}", f"Dataset_{dataset_id}", f"Dataset{dataset_id}", f"D{dataset_id}", str(dataset_id)
 54    ):
 55        d = os.path.join(data_root, dname)
 56        if os.path.exists(d):
 57            return d
 58    raise RuntimeError(
 59        f"Cannot find a sub-directory for dataset {dataset_id} inside '{data_root}'. "
 60        f"Contents: {os.listdir(data_root)}"
 61    )
 62
 63
 64def _get_image_mask_dirs(dataset_dir):
 65    """Find the image and mask subdirectories within a per-dataset directory."""
 66    img_dir = None
 67    for name in ("Images", "images", "Image", "image", "Raw", "raw"):
 68        candidate = os.path.join(dataset_dir, name)
 69        if os.path.exists(candidate):
 70            img_dir = candidate
 71            break
 72
 73    mask_dir = None
 74    for name in ("Ground truth mask", "Masks", "masks", "Mask", "mask", "Labels", "labels", "Label", "label"):
 75        candidate = os.path.join(dataset_dir, name)
 76        if os.path.exists(candidate):
 77            mask_dir = candidate
 78            break
 79
 80    if img_dir is None or mask_dir is None:
 81        raise RuntimeError(
 82            f"Cannot find image/mask directories inside '{dataset_dir}'. "
 83            f"Contents: {os.listdir(dataset_dir)}"
 84        )
 85    return img_dir, mask_dir
 86
 87
 88def _create_h5_files(data_root, dataset_id, out_dir):
 89    """Convert TIF image/mask pairs for one sub-dataset into individual HDF5 files."""
 90    import h5py
 91    import imageio.v3 as imageio
 92
 93    dataset_dir = _get_dataset_dir(data_root, dataset_id)
 94    img_dir, mask_dir = _get_image_mask_dirs(dataset_dir)
 95
 96    image_files = sorted(
 97        glob(os.path.join(img_dir, "*.tif")) +
 98        glob(os.path.join(img_dir, "*.tiff")) +
 99        glob(os.path.join(img_dir, "*.png"))
100    )
101    mask_files = sorted(
102        glob(os.path.join(mask_dir, "*.tif")) +
103        glob(os.path.join(mask_dir, "*.tiff")) +
104        glob(os.path.join(mask_dir, "*.png"))
105    )
106
107    assert len(image_files) > 0, f"No TIF files found in '{img_dir}'"
108    assert len(image_files) == len(mask_files), (
109        f"Mismatch: {len(image_files)} images vs {len(mask_files)} masks in '{dataset_dir}'"
110    )
111
112    os.makedirs(out_dir, exist_ok=True)
113
114    for img_path, mask_path in zip(image_files, mask_files):
115        fname = os.path.splitext(os.path.basename(img_path))[0]
116        out_path = os.path.join(out_dir, f"{fname}.h5")
117
118        raw = imageio.imread(img_path)
119        if raw.ndim == 3:  # drop extra channels (e.g. RGBA -> grayscale)
120            raw = raw[..., 0]
121
122        labels = imageio.imread(mask_path)
123        if labels.ndim == 3:
124            labels = labels[..., 0]
125
126        # Remap arbitrary grayscale values to consecutive class indices (0, 1, 2, …).
127        unique_vals = np.sort(np.unique(labels))
128        if not np.array_equal(unique_vals, np.arange(len(unique_vals))):
129            new_labels = np.zeros_like(labels)
130            for cls_idx, val in enumerate(unique_vals):
131                new_labels[labels == val] = cls_idx
132            labels = new_labels
133
134        class_names = DATASET_CLASS_NAMES[dataset_id]
135
136        with h5py.File(out_path, "w") as f:
137            f.create_dataset("raw", data=raw, compression="gzip")
138            label_group = f.create_group("labels")
139            for cls_idx, cls_name in enumerate(class_names):
140                binary_mask = (labels == cls_idx).astype("uint8")
141                label_group.create_dataset(cls_name, data=binary_mask, compression="gzip")
142
143
def get_waenet_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str:
    """Download and preprocess the WAE-NET dataset.

    If the output folder of the requested sub-dataset already exists it is returned
    right away. Otherwise the archive is fetched, extracted to a scratch folder and all
    seven sub-datasets are converted to HDF5 in one pass.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        download: Whether to download the data if it is not present.

    Returns:
        The path to the directory containing the preprocessed HDF5 files.

    Raises:
        ValueError: If `dataset_id` is not one of the known sub-datasets.
        RuntimeError: If the per-dataset folders cannot be located in the extracted archive.
    """
    if dataset_id not in DATASET_CLASSES:
        raise ValueError(f"Invalid dataset_id {dataset_id!r}. Choose from {sorted(DATASET_CLASSES)}.")

    out_dir = os.path.join(path, f"dataset_{dataset_id}")
    if os.path.exists(out_dir):
        return out_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "Datasets.zip")
    util.download_source(zip_path, URL, download, checksum=CHECKSUM)

    # Extract to a temporary sub-directory and process all seven datasets in one pass.
    extract_dir = os.path.join(path, "_extracted")
    util.unzip(zip_path, extract_dir, remove=True)

    target_dirs = {did: os.path.join(path, f"dataset_{did}") for did in DATASET_CLASSES}
    # Output folders that exist before this run are never deleted by the rollback below.
    preexisting = {target for target in target_dirs.values() if os.path.exists(target)}

    finished = False
    try:
        # The archive may wrap everything in a root folder (e.g. "Datasets/") or hold the
        # per-dataset folders directly, possibly next to unrelated folders. Probe the
        # extraction folder first and then each sub-folder (sorted for determinism).
        candidates = [extract_dir] + sorted(
            os.path.join(extract_dir, d)
            for d in os.listdir(extract_dir)
            if os.path.isdir(os.path.join(extract_dir, d))
        )
        data_root = None
        for candidate in candidates:
            try:
                _get_dataset_dir(candidate, dataset_id)
            except RuntimeError:
                continue
            data_root = candidate
            break
        if data_root is None:
            raise RuntimeError(
                f"Cannot locate the per-dataset folders inside '{extract_dir}'. "
                f"Contents: {os.listdir(extract_dir)}"
            )

        for did, target in target_dirs.items():
            _create_h5_files(data_root, did, target)
        finished = True
    finally:
        if not finished:
            # A half-converted output folder would pass the existence check at the top
            # on the next call, so remove everything this run has created.
            for target in target_dirs.values():
                if target not in preexisting:
                    rmtree(target, ignore_errors=True)
        # The scratch folder is removed on success and on failure alike.
        rmtree(extract_dir, ignore_errors=True)

    return out_dir
183
184
def get_waenet_paths(
    path: Union[os.PathLike, str],
    dataset_id: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    download: bool = False,
) -> List[str]:
    """Get paths to the WAE-NET HDF5 files.

    The files are sorted by path. The first `int(n * (1 - val_fraction))` of them form the
    'train' split (rounded down) and the remaining ones form the 'test' split.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2, matching the paper's 8:2 ratio).
            Only used when `split` is given.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the HDF5 files.

    Raises:
        ValueError: If `split` is not 'train', 'test' or None, or if a split is requested
            with a `val_fraction` outside of [0, 1].
        RuntimeError: If no HDF5 file is found for the sub-dataset.
    """
    # Validate the cheap arguments first so that a typo does not trigger a download,
    # and raise explicitly so the checks are not stripped under `python -O`.
    if split not in (None, "train", "test"):
        raise ValueError(f"split must be 'train', 'test', or None, got {split!r}")
    if split is not None and not 0.0 <= val_fraction <= 1.0:
        raise ValueError(f"val_fraction must be within [0, 1], got {val_fraction!r}")

    data_dir = get_waenet_data(path, dataset_id, download)
    all_paths = sorted(glob(os.path.join(data_dir, "*.h5")))
    if not all_paths:
        raise RuntimeError(f"No HDF5 files found in '{data_dir}'")

    if split is None:
        return all_paths

    n_train = int(len(all_paths) * (1 - val_fraction))
    return all_paths[:n_train] if split == "train" else all_paths[n_train:]
214
215
def get_waenet_dataset(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Build the WAE-NET segmentation dataset for one sub-dataset and one target class.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        patch_shape: The patch shape to use for training.
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2).
        label_type: Name of the class whose mask is used as segmentation target
            (e.g. 'cytoplasm', 'nucleus', 'mitochondria'). When None, the entry at index 1
            of `DATASET_CLASS_NAMES[dataset_id]` is used, i.e. the first non-background class.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `label_type` is not a class of the selected sub-dataset.
    """
    # Resolving the files comes first; it also downloads the data if requested.
    h5_files = get_waenet_paths(
        path=path, dataset_id=dataset_id, split=split, val_fraction=val_fraction, download=download
    )

    valid_types = DATASET_CLASS_NAMES[dataset_id]
    label_type = valid_types[1] if label_type is None else label_type
    if label_type not in valid_types:
        raise ValueError(f"Invalid label_type '{label_type}' for dataset {dataset_id}. Choose from {valid_types}.")

    # Image and mask live in the same HDF5 file, so both path lists are identical.
    target_key = f"labels/{label_type}"
    return torch_em.default_segmentation_dataset(
        raw_paths=h5_files,
        raw_key="raw",
        label_paths=h5_files,
        label_key=target_key,
        patch_shape=patch_shape,
        **kwargs,
    )
260
261
def get_waenet_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    batch_size: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Build a PyTorch DataLoader on top of the WAE-NET segmentation dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2).
        label_type: Name of the class whose mask is used as segmentation target
            (e.g. 'cytoplasm', 'nucleus', 'mitochondria'). When None, the first
            non-background class of the sub-dataset is used.
            Available classes per dataset are listed in `DATASET_CLASS_NAMES`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    waenet_dataset = get_waenet_dataset(
        path=path,
        dataset_id=dataset_id,
        patch_shape=patch_shape,
        split=split,
        val_fraction=val_fraction,
        label_type=label_type,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(waenet_dataset, batch_size, **dataloader_kwargs)
# Download location of the zipped archive that is fetched by `get_waenet_data`.
URL = 'https://zenodo.org/records/6603083/files/Datasets.zip'
# No reference checksum is recorded; `None` is forwarded to `util.download_source` as-is.
CHECKSUM = None
# Maps dataset_id -> number of segmentation classes (including background).
DATASET_CLASSES = {1: 3, 2: 3, 3: 4, 4: 3, 5: 3, 6: 3, 7: 3}
# Maps dataset_id -> ordered list of class names (index 0 = background, etc.).
DATASET_CLASS_NAMES = {1: ['background', 'cytoplasm', 'nucleus'], 2: ['background', 'cytoplasm', 'chromosomes'], 3: ['background', 'cytoplasm', 'nucleus', 'mitochondria'], 4: ['background', 'cytoplasm', 'nucleus'], 5: ['background', 'cytoplasm', 'nucleus'], 6: ['background', 'cytoplasm', 'nucleus'], 7: ['background', 'cytoplasm', 'nucleus']}
def get_waenet_data( path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str:
def get_waenet_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str:
    """Download and preprocess the WAE-NET dataset.

    If the output folder of the requested sub-dataset already exists it is returned
    right away. Otherwise the archive is fetched, extracted to a scratch folder and all
    seven sub-datasets are converted to HDF5 in one pass.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        download: Whether to download the data if it is not present.

    Returns:
        The path to the directory containing the preprocessed HDF5 files.

    Raises:
        ValueError: If `dataset_id` is not one of the known sub-datasets.
        RuntimeError: If the per-dataset folders cannot be located in the extracted archive.
    """
    if dataset_id not in DATASET_CLASSES:
        raise ValueError(f"Invalid dataset_id {dataset_id!r}. Choose from {sorted(DATASET_CLASSES)}.")

    out_dir = os.path.join(path, f"dataset_{dataset_id}")
    if os.path.exists(out_dir):
        return out_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "Datasets.zip")
    util.download_source(zip_path, URL, download, checksum=CHECKSUM)

    # Extract to a temporary sub-directory and process all seven datasets in one pass.
    extract_dir = os.path.join(path, "_extracted")
    util.unzip(zip_path, extract_dir, remove=True)

    target_dirs = {did: os.path.join(path, f"dataset_{did}") for did in DATASET_CLASSES}
    # Output folders that exist before this run are never deleted by the rollback below.
    preexisting = {target for target in target_dirs.values() if os.path.exists(target)}

    finished = False
    try:
        # The archive may wrap everything in a root folder (e.g. "Datasets/") or hold the
        # per-dataset folders directly, possibly next to unrelated folders. Probe the
        # extraction folder first and then each sub-folder (sorted for determinism).
        candidates = [extract_dir] + sorted(
            os.path.join(extract_dir, d)
            for d in os.listdir(extract_dir)
            if os.path.isdir(os.path.join(extract_dir, d))
        )
        data_root = None
        for candidate in candidates:
            try:
                _get_dataset_dir(candidate, dataset_id)
            except RuntimeError:
                continue
            data_root = candidate
            break
        if data_root is None:
            raise RuntimeError(
                f"Cannot locate the per-dataset folders inside '{extract_dir}'. "
                f"Contents: {os.listdir(extract_dir)}"
            )

        for did, target in target_dirs.items():
            _create_h5_files(data_root, did, target)
        finished = True
    finally:
        if not finished:
            # A half-converted output folder would pass the existence check at the top
            # on the next call, so remove everything this run has created.
            for target in target_dirs.values():
                if target not in preexisting:
                    rmtree(target, ignore_errors=True)
        # The scratch folder is removed on success and on failure alike.
        rmtree(extract_dir, ignore_errors=True)

    return out_dir

Download and preprocess the WAE-NET dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • dataset_id: Which of the seven sub-datasets to use (1–7).
  • download: Whether to download the data if it is not present.
Returns:

The path to the directory containing the preprocessed HDF5 files.

def get_waenet_paths( path: Union[os.PathLike, str], dataset_id: int, split: Optional[Literal['train', 'test']] = None, val_fraction: float = 0.2, download: bool = False) -> List[str]:
def get_waenet_paths(
    path: Union[os.PathLike, str],
    dataset_id: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    download: bool = False,
) -> List[str]:
    """Get paths to the WAE-NET HDF5 files.

    The files are sorted by path. The first `int(n * (1 - val_fraction))` of them form the
    'train' split (rounded down) and the remaining ones form the 'test' split.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2, matching the paper's 8:2 ratio).
            Only used when `split` is given.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the HDF5 files.

    Raises:
        ValueError: If `split` is not 'train', 'test' or None, or if a split is requested
            with a `val_fraction` outside of [0, 1].
        RuntimeError: If no HDF5 file is found for the sub-dataset.
    """
    # Validate the cheap arguments first so that a typo does not trigger a download,
    # and raise explicitly so the checks are not stripped under `python -O`.
    if split not in (None, "train", "test"):
        raise ValueError(f"split must be 'train', 'test', or None, got {split!r}")
    if split is not None and not 0.0 <= val_fraction <= 1.0:
        raise ValueError(f"val_fraction must be within [0, 1], got {val_fraction!r}")

    data_dir = get_waenet_data(path, dataset_id, download)
    all_paths = sorted(glob(os.path.join(data_dir, "*.h5")))
    if not all_paths:
        raise RuntimeError(f"No HDF5 files found in '{data_dir}'")

    if split is None:
        return all_paths

    n_train = int(len(all_paths) * (1 - val_fraction))
    return all_paths[:n_train] if split == "train" else all_paths[n_train:]

Get paths to the WAE-NET HDF5 files.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • dataset_id: Which of the seven sub-datasets to use (1–7).
  • split: The data split. Either 'train', 'test', or None for all data.
  • val_fraction: Fraction of images reserved for the test split (default 0.2, matching the paper's 8:2 ratio).
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths to the HDF5 files.

def get_waenet_dataset( path: Union[os.PathLike, str], dataset_id: int, patch_shape: Tuple[int, int], split: Optional[Literal['train', 'test']] = None, val_fraction: float = 0.2, label_type: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_waenet_dataset(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Build the WAE-NET segmentation dataset for one sub-dataset and one target class.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        patch_shape: The patch shape to use for training.
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2).
        label_type: Name of the class whose mask is used as segmentation target
            (e.g. 'cytoplasm', 'nucleus', 'mitochondria'). When None, the entry at index 1
            of `DATASET_CLASS_NAMES[dataset_id]` is used, i.e. the first non-background class.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `label_type` is not a class of the selected sub-dataset.
    """
    # Resolving the files comes first; it also downloads the data if requested.
    h5_files = get_waenet_paths(
        path=path, dataset_id=dataset_id, split=split, val_fraction=val_fraction, download=download
    )

    valid_types = DATASET_CLASS_NAMES[dataset_id]
    label_type = valid_types[1] if label_type is None else label_type
    if label_type not in valid_types:
        raise ValueError(f"Invalid label_type '{label_type}' for dataset {dataset_id}. Choose from {valid_types}.")

    # Image and mask live in the same HDF5 file, so both path lists are identical.
    target_key = f"labels/{label_type}"
    return torch_em.default_segmentation_dataset(
        raw_paths=h5_files,
        raw_key="raw",
        label_paths=h5_files,
        label_key=target_key,
        patch_shape=patch_shape,
        **kwargs,
    )

Get the WAE-NET dataset for multi-class semantic segmentation in electron microscopy.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • dataset_id: Which of the seven sub-datasets to use (1–7).
  • patch_shape: The patch shape to use for training.
  • split: The data split. Either 'train', 'test', or None for all data.
  • val_fraction: Fraction of images reserved for the test split (default 0.2).
  • label_type: The class to use as segmentation target (e.g. 'cytoplasm', 'nucleus', 'mitochondria'). If None, defaults to the first non-background class for the given dataset. Available classes per dataset are listed in DATASET_CLASS_NAMES.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_waenet_loader( path: Union[os.PathLike, str], dataset_id: int, patch_shape: Tuple[int, int], batch_size: int, split: Optional[Literal['train', 'test']] = None, val_fraction: float = 0.2, label_type: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_waenet_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    batch_size: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Build a PyTorch DataLoader on top of the WAE-NET segmentation dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2).
        label_type: Name of the class whose mask is used as segmentation target
            (e.g. 'cytoplasm', 'nucleus', 'mitochondria'). When None, the first
            non-background class of the sub-dataset is used.
            Available classes per dataset are listed in `DATASET_CLASS_NAMES`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    waenet_dataset = get_waenet_dataset(
        path=path,
        dataset_id=dataset_id,
        patch_shape=patch_shape,
        split=split,
        val_fraction=val_fraction,
        label_type=label_type,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(waenet_dataset, batch_size, **dataloader_kwargs)

Get the WAE-NET dataloader for multi-class semantic segmentation in electron microscopy.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • dataset_id: Which of the seven sub-datasets to use (1–7).
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • split: The data split. Either 'train', 'test', or None for all data.
  • val_fraction: Fraction of images reserved for the test split (default 0.2).
  • label_type: The class to use as segmentation target (e.g. 'cytoplasm', 'nucleus', 'mitochondria'). If None, defaults to the first non-background class for the given dataset. Available classes per dataset are listed in DATASET_CLASS_NAMES.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.