torch_em.data.datasets.electron_microscopy.cem
The CEM, or MitoLab, dataset is a collection of data for training mitochondria generalist models. It consists of:
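- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models
  - https://www.ebi.ac.uk/empiar/EMPIAR-11037/
- CEM-Mito-Benchmark: 7 benchmark datasets for mitochondria segmentation
  - https://www.ebi.ac.uk/empiar/EMPIAR-10982/
- CEM-1.5M: unlabeled EM images for pretraining (not yet implemented)
  - https://www.ebi.ac.uk/empiar/EMPIAR-11035/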
These datasets are from the publication https://doi.org/10.1016/j.cels.2022.12.006. Please cite this publication if you use this data in your research.
The data itself can be downloaded from EMPIAR via aspera.
- You can install aspera via mamba. We recommend doing this in a separate environment to avoid dependency issues:
  $ mamba create -c conda-forge -c hcc -n aspera aspera-cli
- After this you can run
  $ mamba activate aspera
  to have an environment with aspera installed.
- You can then download the data for one of the three datasets like this:
  - ascp -QT -l 200m -P33001 -i <PREFIX>/etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/<EMPIAR_ID> <PATH>
  - Here <PREFIX> is the path to the mamba environment, <EMPIAR_ID> is the id of one of the three datasets, and <PATH> is where you want to download the data.
- After this you can use the functions in this file if you use <PATH> as the location for the data.
Note that automatic download is implemented, but it leads to dependency issues, so we recommend downloading the data manually and then running the loaders with the correct path.
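The manual download command above can be assembled programmatically. The following is a minimal sketch, not part of torch_em: the helper name `build_ascp_command` and the example paths are hypothetical, and it assumes that <EMPIAR_ID> is the bare accession number from the EMPIAR URLs (11037, 10982, 11035).

```python
import shlex
from pathlib import PurePosixPath

# Accession numbers taken from the EMPIAR URLs of the three datasets.
# Assumption: the remote path uses the bare number as <EMPIAR_ID>.
EMPIAR_IDS = {"mitolab": "11037", "benchmark": "10982", "cem15m": "11035"}


def build_ascp_command(prefix, empiar_id, path):
    """Assemble the ascp call from the instructions as an argument list."""
    key_file = PurePosixPath(prefix) / "etc" / "asperaweb_id_dsa.openssh"
    return [
        "ascp", "-QT", "-l", "200m", "-P33001",
        "-i", str(key_file),
        f"emp_ext2@fasp.ebi.ac.uk:/{empiar_id}",
        str(path),
    ]


# Hypothetical environment prefix and download location.
cmd = build_ascp_command("/opt/mamba/envs/aspera", EMPIAR_IDS["mitolab"], "/data/cem")
print(shlex.join(cmd))
```

The argument list can be passed to `subprocess.run(cmd, check=True)` from inside the activated aspera environment.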
"""The CEM, or MitoLab, dataset is a collection of data for
training mitochondria generalist models. It consists of:
- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models
  - https://www.ebi.ac.uk/empiar/EMPIAR-11037/
- CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation
  - https://www.ebi.ac.uk/empiar/EMPIAR-10982/
- CEM-1.5M: unlabeled EM images for pretraining: (Not yet implemented)
  - https://www.ebi.ac.uk/empiar/EMPIAR-11035/

These datasets are from the publication https://doi.org/10.1016/j.cels.2022.12.006.
Please cite this publication if you use this data in your research.

The data itself can be downloaded from EMPIAR via aspera.
- You can install aspera via mamba. We recommend to do this in a separate environment
  to avoid dependency issues:
    - `$ mamba create -c conda-forge -c hcc -n aspera aspera-cli`
- After this you can run `$ mamba activate aspera` to have an environment with aspera installed.
- You can then download the data for one of the three datasets like this:
    - ascp -QT -l 200m -P33001 -i <PREFIX>/etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/<EMPIAR_ID> <PATH>
    - Where <PREFIX> is the path to the mamba environment, <EMPIAR_ID> the id of one of the three datasets
      and <PATH> where you want to download the data.
- After this you can use the functions in this file if you use <PATH> as location for the data.

Note that we have implemented automatic download, but this leads to dependency
issues, so we recommend to download the data manually and then run the loaders with the correct path.
"""

import json
import os
from glob import glob
from typing import List, Tuple, Union

import imageio.v3 as imageio
import numpy as np
import torch_em
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

from .. import util

BENCHMARK_DATASETS = {
    1: "mito_benchmarks/c_elegans",
    2: "mito_benchmarks/fly_brain",
    3: "mito_benchmarks/glycolytic_muscle",
    4: "mito_benchmarks/hela_cell",
    5: "mito_benchmarks/lucchi_pp",
    6: "mito_benchmarks/salivary_gland",
    7: "tem_benchmark",
}
BENCHMARK_SHAPES = {
    1: (256, 256, 256),
    2: (256, 255, 255),
    3: (302, 383, 765),
    4: (256, 256, 256),
    5: (165, 768, 1024),
    6: (1260, 1081, 1200),
    7: (224, 224),  # NOTE: this is the minimal square shape that fits
}


def _get_mitolab_data(path, download):
    access_id = "11037"
    data_path = util.download_source_empiar(path, access_id, download)

    zip_path = os.path.join(data_path, "data/cem_mitolab.zip")
    if os.path.exists(zip_path):
        util.unzip(zip_path, data_path, remove=True)

    data_root = os.path.join(data_path, "cem_mitolab")
    assert os.path.exists(data_root)

    return data_root


def _get_all_images(path):
    raw_paths, label_paths = [], []
    folders = glob(os.path.join(path, "*"))
    assert all(os.path.isdir(folder) for folder in folders)
    for folder in folders:
        images = sorted(glob(os.path.join(folder, "images", "*.tiff")))
        assert len(images) > 0
        labels = sorted(glob(os.path.join(folder, "masks", "*.tiff")))
        assert len(images) == len(labels)
        raw_paths.extend(images)
        label_paths.extend(labels)
    return raw_paths, label_paths


def _get_non_empty_images(path):
    save_path = os.path.join(path, "non_empty_images.json")

    # Reuse the cached list of non-empty images if it exists.
    if os.path.exists(save_path):
        with open(save_path, "r") as f:
            saved_images = json.load(f)
        raw_paths, label_paths = saved_images["images"], saved_images["labels"]
        raw_paths = [os.path.join(path, rp) for rp in raw_paths]
        label_paths = [os.path.join(path, lp) for lp in label_paths]
        return raw_paths, label_paths

    folders = glob(os.path.join(path, "*"))
    assert all(os.path.isdir(folder) for folder in folders)

    raw_paths, label_paths = [], []
    for folder in folders:
        images = sorted(glob(os.path.join(folder, "images", "*.tiff")))
        labels = sorted(glob(os.path.join(folder, "masks", "*.tiff")))
        assert len(images) > 0
        assert len(images) == len(labels)

        # Keep only the images whose mask contains more than one distinct value.
        for im, lab in zip(images, labels):
            n_labels = len(np.unique(imageio.imread(lab)))
            if n_labels > 1:
                raw_paths.append(im)
                label_paths.append(lab)

    # Cache the result as paths relative to the data root.
    raw_paths_rel = [os.path.relpath(rp, path) for rp in raw_paths]
    label_paths_rel = [os.path.relpath(lp, path) for lp in label_paths]

    with open(save_path, "w") as f:
        json.dump({"images": raw_paths_rel, "labels": label_paths_rel}, f)

    return raw_paths, label_paths


def get_mitolab_data(
    path: Union[os.PathLike, str],
    split: str,
    val_fraction: float,
    download: bool,
    discard_empty_images: bool
) -> Tuple[List[str], List[str]]:
    """Download the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.

    Returns:
        List of the image data paths.
        List of the label data paths.
    """
    data_path = _get_mitolab_data(path, download)
    if discard_empty_images:
        raw_paths, label_paths = _get_non_empty_images(data_path)
    else:
        raw_paths, label_paths = _get_all_images(data_path)

    if split is not None:
        raw_train, raw_val, labels_train, labels_val = train_test_split(
            raw_paths, label_paths, test_size=val_fraction, random_state=42,
        )
        if split == "train":
            raw_paths, label_paths = raw_train, labels_train
        else:
            raw_paths, label_paths = raw_val, labels_val

    assert len(raw_paths) > 0
    assert len(raw_paths) == len(label_paths)
    return raw_paths, label_paths


def get_benchmark_data(
    path: Union[os.PathLike, str],
    dataset_id: int,
    download: bool
) -> Tuple[
    List[str], List[str], str, str, bool
]:
    """Download the mitolab benchmark data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        download: Whether to download the data if it is not present.

    Returns:
        List of the image data paths.
        List of the label data paths.
        The image data key.
        The label data key.
        Whether this is a segmentation dataset.
    """
    access_id = "10982"
    data_path = util.download_source_empiar(path, access_id, download)
    dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[dataset_id])

    # these are the 3d datasets
    if dataset_id in range(1, 7):
        dataset_name = os.path.basename(dataset_path)
        raw_paths = os.path.join(dataset_path, f"{dataset_name}_em.tif")
        label_paths = os.path.join(dataset_path, f"{dataset_name}_mito.tif")
        raw_key, label_key = None, None
        is_seg_dataset = True

    # this is the 2d dataset
    else:
        raw_paths = os.path.join(dataset_path, "images")
        label_paths = os.path.join(dataset_path, "masks")
        raw_key, label_key = "*.tiff", "*.tiff"
        is_seg_dataset = False

    return raw_paths, label_paths, raw_key, label_key, is_seg_dataset


#
# Datasets
#


def get_mitolab_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int] = (224, 224),
    val_fraction: float = 0.05,
    download: bool = False,
    discard_empty_images: bool = True,
    **kwargs
) -> Dataset:
    """Get the dataset for the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        patch_shape: The patch shape to use for training.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "val", None)
    assert os.path.exists(path)
    raw_paths, label_paths = get_mitolab_data(path, split, val_fraction, download, discard_empty_images)
    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths, raw_key=None,
        label_paths=label_paths, label_key=None,
        patch_shape=patch_shape, is_seg_dataset=False, ndim=2, **kwargs
    )


def get_cem15m_dataset(path):
    raise NotImplementedError


def get_benchmark_dataset(
    path,
    dataset_id,
    patch_shape,
    download=False,
    **kwargs,
) -> Dataset:
    """Get the dataset for one of the mitolab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    if dataset_id not in range(1, 8):
        raise ValueError(f"Invalid dataset id {dataset_id}, expected id in range [1, 7].")
    raw_paths, label_paths, raw_key, label_key, is_seg_dataset = get_benchmark_data(path, dataset_id, download)
    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths, raw_key=raw_key,
        label_paths=label_paths, label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=is_seg_dataset, **kwargs,
    )


#
# DataLoaders
#


def get_mitolab_loader(
    path: Union[os.PathLike, str],
    split: str,
    batch_size: int,
    patch_shape: Tuple[int, int] = (224, 224),
    discard_empty_images: bool = True,
    val_fraction: float = 0.05,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        discard_empty_images: Whether to discard images without annotations.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(
        torch_em.default_segmentation_dataset, **kwargs
    )
    # NOTE: val_fraction is not forwarded here, so the dataset default (0.05) applies.
    dataset = get_mitolab_dataset(
        path, split, patch_shape, download=download, discard_empty_images=discard_empty_images, **ds_kwargs
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader


def get_cem15m_loader(path):
    raise NotImplementedError


def get_benchmark_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    batch_size: int,
    patch_shape: Tuple[int, int],
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for one of the mitolab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(
        torch_em.default_segmentation_dataset, **kwargs
    )
    dataset = get_benchmark_dataset(
        path, dataset_id,
        patch_shape=patch_shape, download=download, **ds_kwargs
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
def get_mitolab_data(
    path: Union[os.PathLike, str],
    split: str,
    val_fraction: float,
    download: bool,
    discard_empty_images: bool
) -> Tuple[List[str], List[str]]:
Download the mitolab training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'val'. If None, all images are returned without splitting.
- val_fraction: The fraction of the data to use for validation.
- download: Whether to download the data if it is not present.
- discard_empty_images: Whether to discard images without annotations.
Returns:
- List of the image data paths.
- List of the label data paths.
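The train/val selection can be illustrated with a small stdlib-only sketch. Note that this swaps in `random.Random` for the `sklearn.model_selection.train_test_split` call (with `random_state=42`) that the module itself uses, so the concrete assignment of files will differ; only the principle of a seeded, reproducible split is the same. The helper name `split_paths` and the file names are hypothetical.

```python
import random


def split_paths(raw_paths, label_paths, val_fraction, split, seed=42):
    """Shuffle image/label pairs with a fixed seed and return one side of the split."""
    # Shuffle the pairs together so images and labels stay aligned.
    pairs = list(zip(raw_paths, label_paths))
    random.Random(seed).shuffle(pairs)
    n_val = max(1, round(len(pairs) * val_fraction))
    # The first n_val shuffled pairs form the validation set, the rest the training set.
    selected = pairs[n_val:] if split == "train" else pairs[:n_val]
    raw, labels = zip(*selected)
    return list(raw), list(labels)


# Hypothetical file names for illustration.
raw = [f"images/{i:03d}.tiff" for i in range(100)]
labels = [f"masks/{i:03d}.tiff" for i in range(100)]

train_raw, train_labels = split_paths(raw, labels, 0.05, "train")
val_raw, val_labels = split_paths(raw, labels, 0.05, "val")
print(len(train_raw), len(val_raw))
```

Because the seed is fixed, calling the function once with 'train' and once with 'val' yields two disjoint sets that together cover all images.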
def get_benchmark_data(
    path: Union[os.PathLike, str],
    dataset_id: int,
    download: bool
) -> Tuple[
    List[str], List[str], str, str, bool
]:
Download the mitolab benchmark data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: The id of the benchmark dataset to download.
- download: Whether to download the data if it is not present.
Returns:
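- The image data path: a single .tif volume for the 3d datasets (ids 1 to 6), or the image folder for the 2d dataset (id 7).
- The label data path, structured in the same way.
- The image data key: None for the 3d datasets, "*.tiff" for the 2d dataset.
- The label data key, with the same values as the image data key.
- Whether this is a segmentation dataset: True for the 3d datasets, False for the 2d dataset.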
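How the returned paths and keys depend on the dataset id can be sketched as follows. The mapping is copied from the module source; the helper name `resolve_benchmark_paths` and the example root `/data/10982` are hypothetical (the real root is whatever `util.download_source_empiar` returns), and `posixpath` is used only to keep the output platform independent.

```python
import posixpath

# Copied from the module source.
BENCHMARK_DATASETS = {
    1: "mito_benchmarks/c_elegans",
    2: "mito_benchmarks/fly_brain",
    3: "mito_benchmarks/glycolytic_muscle",
    4: "mito_benchmarks/hela_cell",
    5: "mito_benchmarks/lucchi_pp",
    6: "mito_benchmarks/salivary_gland",
    7: "tem_benchmark",
}


def resolve_benchmark_paths(data_path, dataset_id):
    """Mirror the path logic of get_benchmark_data without touching the disk."""
    dataset_path = posixpath.join(data_path, "data", BENCHMARK_DATASETS[dataset_id])
    if dataset_id in range(1, 7):
        # 3d datasets: one volume for the raw data and one for the labels.
        name = posixpath.basename(dataset_path)
        raw = posixpath.join(dataset_path, f"{name}_em.tif")
        labels = posixpath.join(dataset_path, f"{name}_mito.tif")
        return raw, labels, None, None, True
    # 2d dataset: folders of images, matched with a glob pattern as key.
    raw = posixpath.join(dataset_path, "images")
    labels = posixpath.join(dataset_path, "masks")
    return raw, labels, "*.tiff", "*.tiff", False


print(resolve_benchmark_paths("/data/10982", 5))
print(resolve_benchmark_paths("/data/10982", 7))
```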
def get_mitolab_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int] = (224, 224),
    val_fraction: float = 0.05,
    download: bool = False,
    discard_empty_images: bool = True,
    **kwargs
) -> Dataset:
Get the dataset for the mitolab training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'val'. If None, all images are used without splitting.
- patch_shape: The patch shape to use for training.
- val_fraction: The fraction of the data to use for validation.
- download: Whether to download the data if it is not present.
- discard_empty_images: Whether to discard images without annotations.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
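The criterion behind discard_empty_images can be shown on toy data. In the module source an image counts as non-empty when its mask contains more than one distinct value; the sketch below applies the same test to nested lists instead of TIFF files read with imageio. The helper name `is_non_empty` and the toy masks are hypothetical.

```python
def is_non_empty(label_image):
    """Return True if the mask contains more than one distinct value."""
    values = {value for row in label_image for value in row}
    return len(values) > 1


# Toy masks: 0 is background, positive integers are object ids.
empty_mask = [[0, 0], [0, 0]]
annotated_mask = [[0, 1], [2, 0]]

pairs = [("a.tiff", empty_mask), ("b.tiff", annotated_mask)]
kept = [name for name, mask in pairs if is_non_empty(mask)]
print(kept)
```

One consequence of this criterion is that a mask filled entirely with a single foreground value would also be treated as empty.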
def get_benchmark_dataset(
    path,
    dataset_id,
    patch_shape,
    download=False,
    **kwargs,
) -> Dataset:
Get the dataset for one of the mitolab benchmark datasets.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: The id of the benchmark dataset to download.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
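Before building a benchmark dataset it can help to check that the requested patch fits into the data. The sketch below uses the BENCHMARK_SHAPES table from the module source; the helper `patch_fits` is hypothetical and not part of torch_em, and the requirement that the patch shape has as many axes as the data is an assumption.

```python
# Copied from the module source.
BENCHMARK_SHAPES = {
    1: (256, 256, 256),
    2: (256, 255, 255),
    3: (302, 383, 765),
    4: (256, 256, 256),
    5: (165, 768, 1024),
    6: (1260, 1081, 1200),
    7: (224, 224),  # minimal square shape that fits
}


def patch_fits(dataset_id, patch_shape):
    """Check whether patch_shape fits into the given benchmark dataset."""
    if dataset_id not in BENCHMARK_SHAPES:
        raise ValueError(f"Invalid dataset id {dataset_id}, expected id in range [1, 7].")
    shape = BENCHMARK_SHAPES[dataset_id]
    # Assumption: the patch must have the same number of axes as the data.
    if len(patch_shape) != len(shape):
        raise ValueError(f"Expected a {len(shape)}d patch shape, got {patch_shape}.")
    return all(p <= s for p, s in zip(patch_shape, shape))


print(patch_fits(1, (32, 256, 256)))
print(patch_fits(2, (32, 256, 256)))
print(patch_fits(7, (224, 224)))
```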
def get_mitolab_loader(
    path: Union[os.PathLike, str],
    split: str,
    batch_size: int,
    patch_shape: Tuple[int, int] = (224, 224),
    discard_empty_images: bool = True,
    val_fraction: float = 0.05,
    download: bool = False,
    **kwargs
) -> DataLoader:
Get the dataloader for the mitolab training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'val'. If None, all images are used without splitting.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- discard_empty_images: Whether to discard images without annotations.
- val_fraction: The fraction of the data to use for validation. Note that the function body does not pass this value on to get_mitolab_dataset, so the dataset default of 0.05 applies.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.
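A typical call could look like the following sketch. It assumes that torch_em is installed and that the data was downloaded manually to `data_root`; the wrapper name `make_mitolab_loaders` and its default batch size are hypothetical. The import happens inside the function so that the sketch can be defined without torch_em being available.

```python
def make_mitolab_loaders(data_root, batch_size=8):
    """Build train and val loaders from manually downloaded data at data_root."""
    # Imported lazily so this sketch can be defined without torch_em installed.
    from torch_em.data.datasets.electron_microscopy.cem import get_mitolab_loader

    common = dict(batch_size=batch_size, patch_shape=(224, 224), download=False)
    # shuffle is not a dataset argument, so it is passed on to the PyTorch DataLoader.
    train_loader = get_mitolab_loader(data_root, "train", shuffle=True, **common)
    val_loader = get_mitolab_loader(data_root, "val", **common)
    return train_loader, val_loader
```

Here `data_root` plays the role of <PATH> from the download instructions.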
def get_benchmark_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    batch_size: int,
    patch_shape: Tuple[int, int],
    download: bool = False,
    **kwargs
) -> DataLoader:
Get the dataloader for one of the mitolab benchmark datasets.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: The id of the benchmark dataset to download.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
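Both loader functions accept a mixed set of keyword arguments and route them either to the dataset or to the DataLoader. One way such routing can work is to inspect the signature of the dataset function, as in the sketch below. This is an assumption about the behavior of `util.split_kwargs`, not its actual implementation; `split_kwargs_sketch` and `fake_dataset_factory` are hypothetical names.

```python
import inspect


def split_kwargs_sketch(function, **kwargs):
    """Split kwargs into those accepted by function and all remaining ones."""
    names = inspect.signature(function).parameters
    matched = {k: v for k, v in kwargs.items() if k in names}
    remaining = {k: v for k, v in kwargs.items() if k not in names}
    return matched, remaining


def fake_dataset_factory(raw_paths, label_paths, patch_shape, n_samples=None):
    """Stand-in for a dataset constructor; only its signature matters here."""
    return None


ds_kwargs, loader_kwargs = split_kwargs_sketch(
    fake_dataset_factory, n_samples=100, shuffle=True, num_workers=4
)
print(ds_kwargs)
print(loader_kwargs)
```

In this sketch, n_samples goes to the dataset because it appears in the factory signature, while shuffle and num_workers are left over for the DataLoader.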