torch_em.data.datasets.electron_microscopy.waenet
The WAE-NET dataset contains seven biological electron microscopy datasets for multi-class semantic segmentation of cellular structures.
The seven sub-datasets cover different cell types and imaging modalities:
- Dataset 1: Human pancreatic carcinoid cell line (FIB-SEM) — background, cytoplasm, nucleus
- Dataset 2: BON cell during interphase (ssTEM) — background, cytoplasm, chromosomes
- Dataset 3: BON cell during mitosis (ssTEM) — background, cytoplasm, nucleus, mitochondria
- Dataset 4: Human T-cell line Jurkat (TEM) — background, cytoplasm, nucleus
- Dataset 5: Primary human T-cell blood (TEM) — background, cytoplasm, nucleus
- Dataset 6: Murine B-cell line J558L (TEM) — background, cytoplasm, nucleus
- Dataset 7: Phytohemagglutinin/IL-2 expanded human T cells (TEM) — background, cytoplasm, nucleus
The data is available at https://doi.org/10.17632/9rdmnn2x4x.1. The dataset was published in https://doi.org/10.1007/s00418-022-02148-3. Please cite this publication if you use the dataset in your research.
"""The WAE-NET dataset contains seven biological electron microscopy datasets
for multi-class semantic segmentation of cellular structures.

The seven sub-datasets cover different cell types and imaging modalities:
    - Dataset 1: Human pancreatic carcinoid cell line (FIB-SEM) — background, cytoplasm, nucleus
    - Dataset 2: BON cell during interphase (ssTEM) — background, cytoplasm, chromosomes
    - Dataset 3: BON cell during mitosis (ssTEM) — background, cytoplasm, nucleus, mitochondria
    - Dataset 4: Human T-cell line Jurkat (TEM) — background, cytoplasm, nucleus
    - Dataset 5: Primary human T-cell blood (TEM) — background, cytoplasm, nucleus
    - Dataset 6: Murine B-cell line J558L (TEM) — background, cytoplasm, nucleus
    - Dataset 7: Phytohemagglutinin/IL-2 expanded human T cells (TEM) — background, cytoplasm, nucleus

The data is available at https://doi.org/10.17632/9rdmnn2x4x.1.
The dataset was published in https://doi.org/10.1007/s00418-022-02148-3.
Please cite this publication if you use the dataset in your research.
"""

import os
from glob import glob
from shutil import rmtree
from typing import List, Literal, Optional, Tuple, Union

import numpy as np

from torch.utils.data import DataLoader, Dataset

import torch_em

from .. import util


URL = "https://zenodo.org/records/6603083/files/Datasets.zip"
CHECKSUM = None

# Maps dataset_id -> number of segmentation classes (including background)
DATASET_CLASSES = {1: 3, 2: 3, 3: 4, 4: 3, 5: 3, 6: 3, 7: 3}

# Maps dataset_id -> ordered list of class names (index 0 = background, etc.)
DATASET_CLASS_NAMES = {
    1: ["background", "cytoplasm", "nucleus"],
    2: ["background", "cytoplasm", "chromosomes"],
    3: ["background", "cytoplasm", "nucleus", "mitochondria"],
    4: ["background", "cytoplasm", "nucleus"],
    5: ["background", "cytoplasm", "nucleus"],
    6: ["background", "cytoplasm", "nucleus"],
    7: ["background", "cytoplasm", "nucleus"],
}


def _get_dataset_dir(data_root, dataset_id):
    """Find the subdirectory for a given dataset ID inside the extracted archive."""
    for dname in (
        f"Dataset {dataset_id}", f"Dataset_{dataset_id}", f"Dataset{dataset_id}", f"D{dataset_id}", str(dataset_id)
    ):
        d = os.path.join(data_root, dname)
        if os.path.exists(d):
            return d
    raise RuntimeError(
        f"Cannot find a sub-directory for dataset {dataset_id} inside '{data_root}'. "
        f"Contents: {os.listdir(data_root)}"
    )


def _get_image_mask_dirs(dataset_dir):
    """Find the image and mask subdirectories within a per-dataset directory."""
    img_dir = None
    for name in ("Images", "images", "Image", "image", "Raw", "raw"):
        candidate = os.path.join(dataset_dir, name)
        if os.path.exists(candidate):
            img_dir = candidate
            break

    mask_dir = None
    for name in ("Ground truth mask", "Masks", "masks", "Mask", "mask", "Labels", "labels", "Label", "label"):
        candidate = os.path.join(dataset_dir, name)
        if os.path.exists(candidate):
            mask_dir = candidate
            break

    if img_dir is None or mask_dir is None:
        raise RuntimeError(
            f"Cannot find image/mask directories inside '{dataset_dir}'. "
            f"Contents: {os.listdir(dataset_dir)}"
        )
    return img_dir, mask_dir


def _create_h5_files(data_root, dataset_id, out_dir):
    """Convert image/mask pairs for one sub-dataset into individual HDF5 files."""
    import h5py
    import imageio.v3 as imageio

    dataset_dir = _get_dataset_dir(data_root, dataset_id)
    img_dir, mask_dir = _get_image_mask_dirs(dataset_dir)

    image_files = sorted(
        glob(os.path.join(img_dir, "*.tif")) +
        glob(os.path.join(img_dir, "*.tiff")) +
        glob(os.path.join(img_dir, "*.png"))
    )
    mask_files = sorted(
        glob(os.path.join(mask_dir, "*.tif")) +
        glob(os.path.join(mask_dir, "*.tiff")) +
        glob(os.path.join(mask_dir, "*.png"))
    )

    assert len(image_files) > 0, f"No image files found in '{img_dir}'"
    assert len(image_files) == len(mask_files), (
        f"Mismatch: {len(image_files)} images vs {len(mask_files)} masks in '{dataset_dir}'"
    )

    os.makedirs(out_dir, exist_ok=True)

    for img_path, mask_path in zip(image_files, mask_files):
        fname = os.path.splitext(os.path.basename(img_path))[0]
        out_path = os.path.join(out_dir, f"{fname}.h5")

        raw = imageio.imread(img_path)
        if raw.ndim == 3:  # multi-channel input (e.g. RGB/RGBA): keep only the first channel
            raw = raw[..., 0]

        labels = imageio.imread(mask_path)
        if labels.ndim == 3:
            labels = labels[..., 0]

        # Remap arbitrary grayscale values to consecutive class indices (0, 1, 2, …).
        unique_vals = np.sort(np.unique(labels))
        if not np.array_equal(unique_vals, np.arange(len(unique_vals))):
            new_labels = np.zeros_like(labels)
            for cls_idx, val in enumerate(unique_vals):
                new_labels[labels == val] = cls_idx
            labels = new_labels

        class_names = DATASET_CLASS_NAMES[dataset_id]

        with h5py.File(out_path, "w") as f:
            f.create_dataset("raw", data=raw, compression="gzip")
            label_group = f.create_group("labels")
            for cls_idx, cls_name in enumerate(class_names):
                binary_mask = (labels == cls_idx).astype("uint8")
                label_group.create_dataset(cls_name, data=binary_mask, compression="gzip")


def get_waenet_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str:
    """Download and preprocess the WAE-NET dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        download: Whether to download the data if it is not present.

    Returns:
        The path to the directory containing the preprocessed HDF5 files.
    """
    if dataset_id not in DATASET_CLASSES:
        raise ValueError(f"Invalid dataset_id {dataset_id!r}. Choose from {sorted(DATASET_CLASSES)}.")

    out_dir = os.path.join(path, f"dataset_{dataset_id}")
    if os.path.exists(out_dir):
        return out_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "Datasets.zip")
    util.download_source(zip_path, URL, download, checksum=CHECKSUM)

    # Extract to a temporary sub-directory and process all seven datasets in one pass.
    extract_dir = os.path.join(path, "_extracted")
    util.unzip(zip_path, extract_dir, remove=True)

    # The archive likely contains a single root folder (e.g. "Datasets/").
    subdirs = [
        d for d in os.listdir(extract_dir) if os.path.isdir(os.path.join(extract_dir, d))
    ]
    data_root = os.path.join(extract_dir, subdirs[0]) if subdirs else extract_dir

    for did in DATASET_CLASSES:
        _create_h5_files(data_root, did, os.path.join(path, f"dataset_{did}"))

    rmtree(extract_dir)

    return out_dir


def get_waenet_paths(
    path: Union[os.PathLike, str],
    dataset_id: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    download: bool = False,
) -> List[str]:
    """Get paths to the WAE-NET HDF5 files.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2, matching the paper's 8:2 ratio).
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the HDF5 files.
    """
    data_dir = get_waenet_data(path, dataset_id, download)
    all_paths = sorted(glob(os.path.join(data_dir, "*.h5")))
    assert len(all_paths) > 0, f"No HDF5 files found in '{data_dir}'"

    if split is None:
        return all_paths

    assert split in ("train", "test"), f"split must be 'train', 'test', or None, got {split!r}"
    n_train = int(len(all_paths) * (1 - val_fraction))
    return all_paths[:n_train] if split == "train" else all_paths[n_train:]


def get_waenet_dataset(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the WAE-NET dataset for multi-class semantic segmentation in electron microscopy.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        patch_shape: The patch shape to use for training.
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2).
        label_type: The class to use as segmentation target (e.g. 'cytoplasm', 'nucleus', 'mitochondria').
            If None, defaults to the first non-background class for the given dataset.
            Available classes per dataset are listed in `DATASET_CLASS_NAMES`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    all_paths = get_waenet_paths(path, dataset_id, split, val_fraction, download)

    if label_type is None:
        label_type = DATASET_CLASS_NAMES[dataset_id][1]

    valid_types = DATASET_CLASS_NAMES[dataset_id]
    if label_type not in valid_types:
        raise ValueError(f"Invalid label_type '{label_type}' for dataset {dataset_id}. Choose from {valid_types}.")

    return torch_em.default_segmentation_dataset(
        raw_paths=all_paths,
        raw_key="raw",
        label_paths=all_paths,
        label_key=f"labels/{label_type}",
        patch_shape=patch_shape,
        **kwargs,
    )


def get_waenet_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    batch_size: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the WAE-NET dataloader for multi-class semantic segmentation in electron microscopy.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: Which of the seven sub-datasets to use (1–7).
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        split: The data split. Either 'train', 'test', or None for all data.
        val_fraction: Fraction of images reserved for the test split (default 0.2).
        label_type: The class to use as segmentation target (e.g. 'cytoplasm', 'nucleus', 'mitochondria').
            If None, defaults to the first non-background class for the given dataset.
            Available classes per dataset are listed in `DATASET_CLASS_NAMES`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_waenet_dataset(path, dataset_id, patch_shape, split, val_fraction, label_type, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
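The label normalization performed during preprocessing can be illustrated with a small, self-contained NumPy sketch: grayscale mask values are remapped to consecutive class indices, and one binary mask per class is derived, mirroring what `_create_h5_files` writes under `labels/<class_name>`. The toy values (0, 127, 255) are illustrative, not taken from the actual data.

```python
import numpy as np

# Toy mask with arbitrary grayscale values (illustrative):
# 0 = background, 127 = cytoplasm, 255 = nucleus.
labels = np.array([
    [0, 0, 127],
    [127, 255, 255],
], dtype=np.uint8)

# Remap arbitrary grayscale values to consecutive class indices (0, 1, 2, ...),
# the same normalization step applied before writing the HDF5 files.
unique_vals = np.sort(np.unique(labels))
if not np.array_equal(unique_vals, np.arange(len(unique_vals))):
    new_labels = np.zeros_like(labels)
    for cls_idx, val in enumerate(unique_vals):
        new_labels[labels == val] = cls_idx
    labels = new_labels

# One binary mask per class, as stored under "labels/<class_name>".
class_names = ["background", "cytoplasm", "nucleus"]
masks = {name: (labels == idx).astype("uint8") for idx, name in enumerate(class_names)}

print(labels.tolist())            # [[0, 0, 1], [1, 2, 2]]
print(masks["nucleus"].tolist())  # [[0, 0, 0], [0, 1, 1]]
```

Because each class is stored as its own binary dataset, a single `label_key` such as `labels/nucleus` selects a binary segmentation target directly.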
def get_waenet_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str:
Download and preprocess the WAE-NET dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: Which of the seven sub-datasets to use (1–7).
- download: Whether to download the data if it is not present.
Returns:
The path to the directory containing the preprocessed HDF5 files.
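After extraction, `get_waenet_data` descends into the archive's single top-level folder if one exists. That root-detection step can be sketched self-contained with a temporary directory; `find_data_root` is a helper name invented here for illustration.

```python
import os
import tempfile

def find_data_root(extract_dir):
    # If the archive unpacked into a single top-level folder (e.g. "Datasets/"),
    # descend into it; otherwise use the extraction directory itself.
    subdirs = [d for d in os.listdir(extract_dir) if os.path.isdir(os.path.join(extract_dir, d))]
    return os.path.join(extract_dir, subdirs[0]) if subdirs else extract_dir

with tempfile.TemporaryDirectory() as tmp:
    # Simulate an archive that unpacked into a "Datasets/" root folder.
    os.makedirs(os.path.join(tmp, "Datasets", "Dataset 1"))
    root = find_data_root(tmp)
    print(os.path.basename(root))  # Datasets
```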
def get_waenet_paths(
    path: Union[os.PathLike, str],
    dataset_id: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    download: bool = False,
) -> List[str]:
Get paths to the WAE-NET HDF5 files.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: Which of the seven sub-datasets to use (1–7).
- split: The data split. Either 'train', 'test', or None for all data.
- val_fraction: Fraction of images reserved for the test split (default 0.2, matching the paper's 8:2 ratio).
- download: Whether to download the data if it is not present.
Returns:
List of filepaths to the HDF5 files.
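The train/test split is a deterministic slice of the sorted file list, not a random shuffle. A minimal sketch of that slicing logic, with placeholder filenames (`split_paths` is a local stand-in for the logic inside `get_waenet_paths`):

```python
def split_paths(all_paths, split, val_fraction=0.2):
    # Deterministic split: the first (1 - val_fraction) of the sorted
    # files go to 'train', the remainder to 'test'.
    if split is None:
        return all_paths
    assert split in ("train", "test")
    n_train = int(len(all_paths) * (1 - val_fraction))
    return all_paths[:n_train] if split == "train" else all_paths[n_train:]

# Ten placeholder files, split at the paper's 8:2 train/test ratio.
paths = [f"img_{i:02d}.h5" for i in range(10)]
print(len(split_paths(paths, "train")))  # 8
print(split_paths(paths, "test"))        # ['img_08.h5', 'img_09.h5']
```

Since the slice depends only on the sorted order, train and test sets are reproducible across runs as long as the file names do not change.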
def get_waenet_dataset(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
Get the WAE-NET dataset for multi-class semantic segmentation in electron microscopy.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: Which of the seven sub-datasets to use (1–7).
- patch_shape: The patch shape to use for training.
- split: The data split. Either 'train', 'test', or None for all data.
- val_fraction: Fraction of images reserved for the test split (default 0.2).
- label_type: The class to use as segmentation target (e.g. 'cytoplasm', 'nucleus', 'mitochondria').
If None, defaults to the first non-background class for the given dataset.
Available classes per dataset are listed in `DATASET_CLASS_NAMES`.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
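The `label_type` default and validation can be shown self-contained. The sketch below replicates an excerpt of `DATASET_CLASS_NAMES` (datasets 1–3 only) and resolves the HDF5 label key the same way `get_waenet_dataset` does; `resolve_label_type` is a helper name invented here.

```python
# Excerpt of the module-level class-name table (datasets 1-3).
DATASET_CLASS_NAMES = {
    1: ["background", "cytoplasm", "nucleus"],
    2: ["background", "cytoplasm", "chromosomes"],
    3: ["background", "cytoplasm", "nucleus", "mitochondria"],
}

def resolve_label_type(dataset_id, label_type=None):
    # Default: the first non-background class of the chosen sub-dataset.
    valid_types = DATASET_CLASS_NAMES[dataset_id]
    if label_type is None:
        label_type = valid_types[1]
    if label_type not in valid_types:
        raise ValueError(f"Invalid label_type '{label_type}' for dataset {dataset_id}.")
    # The binary mask for this class is stored under this HDF5 key.
    return f"labels/{label_type}"

print(resolve_label_type(1))                  # labels/cytoplasm
print(resolve_label_type(3, "mitochondria"))  # labels/mitochondria
```

Note that passing a class that exists in another sub-dataset (e.g. 'mitochondria' for dataset 1) raises a `ValueError` rather than silently producing an empty target.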
def get_waenet_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, int],
    batch_size: int,
    split: Optional[Literal["train", "test"]] = None,
    val_fraction: float = 0.2,
    label_type: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
Get the WAE-NET dataloader for multi-class semantic segmentation in electron microscopy.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: Which of the seven sub-datasets to use (1–7).
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- split: The data split. Either 'train', 'test', or None for all data.
- val_fraction: Fraction of images reserved for the test split (default 0.2).
- label_type: The class to use as segmentation target (e.g. 'cytoplasm', 'nucleus', 'mitochondria').
If None, defaults to the first non-background class for the given dataset.
Available classes per dataset are listed in `DATASET_CLASS_NAMES`.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.
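Internally, one `**kwargs` dict serves two consumers: dataset construction and the PyTorch DataLoader. `util.split_kwargs` partitions the keywords by the dataset function's signature. The idea can be sketched self-contained with `inspect.signature`; `split_kwargs` and `fake_dataset` below are local stand-ins, not the torch_em implementations.

```python
import inspect

def split_kwargs(func, **kwargs):
    # Keyword arguments named in func's signature go to the dataset;
    # everything else is forwarded to the DataLoader.
    names = set(inspect.signature(func).parameters)
    ds_kwargs = {k: v for k, v in kwargs.items() if k in names}
    loader_kwargs = {k: v for k, v in kwargs.items() if k not in names}
    return ds_kwargs, loader_kwargs

def fake_dataset(raw_paths, patch_shape, ndim=2):
    ...

ds_kwargs, loader_kwargs = split_kwargs(fake_dataset, ndim=2, shuffle=True, num_workers=4)
print(ds_kwargs)      # {'ndim': 2}
print(loader_kwargs)  # {'shuffle': True, 'num_workers': 4}
```

This is why DataLoader options such as `shuffle` or `num_workers` can be passed straight to `get_waenet_loader` alongside dataset options.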