torch_em.data.datasets.electron_microscopy.cem
The CEM, or MitoLab, dataset is a collection of data for training mitochondria generalist models. It consists of:
- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models (https://www.ebi.ac.uk/empiar/EMPIAR-11037/)
- CEM-Mito-Benchmark: 7 benchmark datasets for mitochondria segmentation (https://www.ebi.ac.uk/empiar/EMPIAR-10982/)
- CEM-1.5M: unlabeled EM images for pretraining; not yet implemented (https://www.ebi.ac.uk/empiar/EMPIAR-11035/)

These datasets are from the publication https://doi.org/10.1016/j.cels.2022.12.006. Please cite this publication if you use this data in your research.

The data itself can be downloaded from EMPIAR via aspera:
- You can install aspera via mamba. We recommend doing this in a separate environment to avoid dependency issues:
  $ mamba create -c conda-forge -c hcc -n aspera aspera-cli
- After this you can run
  $ mamba activate aspera
  to get an environment with aspera installed.
- You can then download the data for one of the three datasets like this:
  $ ascp -QT -l 200m -P33001 -i <PREFIX>/etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/<EMPIAR_ID> <PATH>
  Here, <PREFIX> is the path to the mamba environment, <EMPIAR_ID> the id of one of the three datasets, and <PATH> where you want to download the data.
- After this you can use the functions in this file if you use <PATH> as the location for the data.

Note that we have implemented automatic download, but this leads to dependency issues, so we recommend downloading the data manually and then running the loaders with the correct path.
"""The CEM, or MitoLab, dataset is a collection of data for
training mitochondria generalist models. It consists of:
- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models
  - https://www.ebi.ac.uk/empiar/EMPIAR-11037/
- CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation
  - https://www.ebi.ac.uk/empiar/EMPIAR-10982/
- CEM-1.5M: unlabeled EM images for pretraining: (Not yet implemented)
  - https://www.ebi.ac.uk/empiar/EMPIAR-11035/

These datasets are from the publication https://doi.org/10.1016/j.cels.2022.12.006.
Please cite this publication if you use this data in your research.

The data itself can be downloaded from EMPIAR via aspera.
- You can install aspera via mamba. We recommend to do this in a separate environment
  to avoid dependency issues:
  - `$ mamba create -c conda-forge -c hcc -n aspera aspera-cli`
- After this you can run `$ mamba activate aspera` to have an environment with aspera installed.
- You can then download the data for one of the three datasets like this:
  - ascp -QT -l 200m -P33001 -i <PREFIX>/etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/<EMPIAR_ID> <PATH>
  - Where <PREFIX> is the path to the mamba environment, <EMPIAR_ID> the id of one of the three datasets
    and <PATH> where you want to download the data.
- After this you can use the functions in this file if you use <PATH> as location for the data.

Note that we have implemented automatic download, but this leads to dependency
issues, so we recommend to download the data manually and then run the loaders with the correct path.
"""

import os
import json
from glob import glob
from typing import List, Tuple, Union, Literal

import numpy as np
import imageio.v3 as imageio
from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


BENCHMARK_DATASETS = {
    1: "mito_benchmarks/c_elegans",
    2: "mito_benchmarks/fly_brain",
    3: "mito_benchmarks/glycolytic_muscle",
    4: "mito_benchmarks/hela_cell",
    5: "mito_benchmarks/lucchi_pp",
    6: "mito_benchmarks/salivary_gland",
    7: "tem_benchmark",
}
BENCHMARK_SHAPES = {
    1: (256, 256, 256),
    2: (256, 255, 255),
    3: (302, 383, 765),
    4: (256, 256, 256),
    5: (165, 768, 1024),
    6: (1260, 1081, 1200),
    7: (224, 224),  # NOTE: this is the minimal square shape that fits
}


def _get_all_images(path):
    raw_paths, label_paths = [], []
    folders = glob(os.path.join(path, "*"))
    assert all(os.path.isdir(folder) for folder in folders)
    for folder in folders:
        images = sorted(glob(os.path.join(folder, "images", "*.tiff")))
        assert len(images) > 0
        labels = sorted(glob(os.path.join(folder, "masks", "*.tiff")))
        assert len(images) == len(labels)
        raw_paths.extend(images)
        label_paths.extend(labels)
    return raw_paths, label_paths


def _get_non_empty_images(path):
    save_path = os.path.join(path, "non_empty_images.json")

    if os.path.exists(save_path):
        with open(save_path, "r") as f:
            saved_images = json.load(f)
        raw_paths, label_paths = saved_images["images"], saved_images["labels"]
        raw_paths = [os.path.join(path, rp) for rp in raw_paths]
        label_paths = [os.path.join(path, lp) for lp in label_paths]
        return raw_paths, label_paths

    folders = glob(os.path.join(path, "*"))
    assert all(os.path.isdir(folder) for folder in folders)

    raw_paths, label_paths = [], []
    for folder in folders:
        images = sorted(glob(os.path.join(folder, "images", "*.tiff")))
        labels = sorted(glob(os.path.join(folder, "masks", "*.tiff")))
        assert len(images) > 0
        assert len(images) == len(labels)

        for im, lab in zip(images, labels):
            n_labels = len(np.unique(imageio.imread(lab)))
            if n_labels > 1:
                raw_paths.append(im)
                label_paths.append(lab)

    raw_paths_rel = [os.path.relpath(rp, path) for rp in raw_paths]
    label_paths_rel = [os.path.relpath(lp, path) for lp in label_paths]

    with open(save_path, "w") as f:
        json.dump({"images": raw_paths_rel, "labels": label_paths_rel}, f)

    return raw_paths, label_paths


def get_mitolab_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the MitoLab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    access_id = "11037"
    data_path = util.download_source_empiar(path, access_id, download)

    zip_path = os.path.join(data_path, "data/cem_mitolab.zip")
    if os.path.exists(zip_path):
        util.unzip(zip_path, data_path, remove=True)

    data_root = os.path.join(data_path, "cem_mitolab")
    assert os.path.exists(data_root)

    return data_root


def get_mitolab_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val'],
    val_fraction: float = 0.05,
    download: bool = False,
    discard_empty_images: bool = True,
) -> Tuple[List[str], List[str]]:
    """Get the paths to MitoLab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.

    Returns:
        List of the image data paths.
        List of the label data paths.
    """
    data_path = get_mitolab_data(path, download)

    if discard_empty_images:
        raw_paths, label_paths = _get_non_empty_images(data_path)
    else:
        raw_paths, label_paths = _get_all_images(data_path)

    if split is not None:
        raw_train, raw_val, labels_train, labels_val = train_test_split(
            raw_paths, label_paths, test_size=val_fraction, random_state=42,
        )
        if split == "train":
            raw_paths, label_paths = raw_train, labels_train
        else:
            raw_paths, label_paths = raw_val, labels_val

    assert len(raw_paths) > 0
    assert len(raw_paths) == len(label_paths)
    return raw_paths, label_paths


def get_benchmark_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str:
    """Download the MitoLab benchmark data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the stored data.
    """
    access_id = "10982"
    data_path = util.download_source_empiar(path, access_id, download)
    dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[dataset_id])
    return dataset_path


def get_benchmark_paths(
    path: Union[os.PathLike, str], dataset_id: int, download: bool = False
) -> Tuple[List[str], List[str], str, str, bool]:
    """Get paths to the MitoLab benchmark data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        download: Whether to download the data if it is not present.

    Returns:
        List of the image data paths.
        List of the label data paths.
        The image data key.
        The label data key.
        Whether this is a segmentation dataset.
    """
    dataset_path = get_benchmark_data(path, dataset_id, download)

    # these are the 3d datasets
    if dataset_id in range(1, 7):
        dataset_name = os.path.basename(dataset_path)
        raw_paths = os.path.join(dataset_path, f"{dataset_name}_em.tif")
        label_paths = os.path.join(dataset_path, f"{dataset_name}_mito.tif")
        raw_key, label_key = None, None
        is_seg_dataset = True

    # this is the 2d dataset
    else:
        raw_paths = os.path.join(dataset_path, "images")
        label_paths = os.path.join(dataset_path, "masks")
        raw_key, label_key = "*.tiff", "*.tiff"
        is_seg_dataset = False

    return raw_paths, label_paths, raw_key, label_key, is_seg_dataset


#
# Datasets
#


def get_mitolab_dataset(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val'],
    patch_shape: Tuple[int, int] = (224, 224),
    val_fraction: float = 0.05,
    download: bool = False,
    discard_empty_images: bool = True,
    **kwargs
) -> Dataset:
    """Get the dataset for the MitoLab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        patch_shape: The patch shape to use for training.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "val", None)
    assert os.path.exists(path)

    raw_paths, label_paths = get_mitolab_paths(path, split, val_fraction, download, discard_empty_images)

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        ndim=2,
        **kwargs
    )


def get_cem15m_dataset(path):
    raise NotImplementedError


def get_benchmark_dataset(
    path: Union[os.PathLike, str], dataset_id: int, patch_shape: Tuple[int, int], download: bool = False, **kwargs
) -> Dataset:
    """Get the dataset for one of the mitolab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    if dataset_id not in range(1, 8):
        raise ValueError(f"Invalid dataset id {dataset_id}, expected id in range [1, 7].")

    raw_paths, label_paths, raw_key, label_key, is_seg_dataset = get_benchmark_paths(path, dataset_id, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )


#
# DataLoaders
#


def get_mitolab_loader(
    path: Union[os.PathLike, str],
    split: str,
    batch_size: int,
    patch_shape: Tuple[int, int] = (224, 224),
    discard_empty_images: bool = True,
    val_fraction: float = 0.05,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for the MitoLab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        discard_empty_images: Whether to discard images without annotations.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_mitolab_dataset(
        path=path,
        split=split,
        patch_shape=patch_shape,
        val_fraction=val_fraction,
        download=download,
        discard_empty_images=discard_empty_images,
        **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)


def get_cem15m_loader(path):
    raise NotImplementedError


def get_benchmark_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    batch_size: int,
    patch_shape: Tuple[int, int],
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for one of the MitoLab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_benchmark_dataset(path, dataset_id, patch_shape=patch_shape, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
get_mitolab_data(path: Union[os.PathLike, str], download: bool = False) -> str
Download the MitoLab training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the downloaded data.
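get_mitolab_data extracts data/cem_mitolab.zip once and afterwards reuses the extracted cem_mitolab folder; the existence check on the zip makes repeated calls cheap. A stdlib sketch of this unzip-once pattern (unzip_once is a hypothetical stand-in for util.unzip, which is defined elsewhere in torch_em):

```python
import os
import tempfile
import zipfile


def unzip_once(zip_path, out_dir, remove=True):
    # Skip extraction if the archive was already unpacked (and removed) earlier.
    if not os.path.exists(zip_path):
        return False
    with zipfile.ZipFile(zip_path) as f:
        f.extractall(out_dir)
    if remove:
        os.remove(zip_path)
    return True


# Demonstrate with a throw-away archive.
with tempfile.TemporaryDirectory() as tmp:
    zip_path = os.path.join(tmp, "cem_mitolab.zip")
    with zipfile.ZipFile(zip_path, "w") as f:
        f.writestr("cem_mitolab/readme.txt", "demo")

    did_extract = unzip_once(zip_path, tmp)        # extracts, then removes the zip
    extracted = os.path.exists(os.path.join(tmp, "cem_mitolab", "readme.txt"))
    did_extract_again = unzip_once(zip_path, tmp)  # no-op: zip already gone
```

On the second call the zip is gone, so the previously extracted data is reused rather than unpacked again.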
get_mitolab_paths(path: Union[os.PathLike, str], split: Literal['train', 'val'], val_fraction: float = 0.05, download: bool = False, discard_empty_images: bool = True) -> Tuple[List[str], List[str]]
Get the paths to MitoLab training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'val'.
- val_fraction: The fraction of the data to use for validation.
- download: Whether to download the data if it is not present.
- discard_empty_images: Whether to discard images without annotations.
Returns:
List of the image data paths.
List of the label data paths.
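The split is derived with sklearn's train_test_split and a fixed random_state, so 'train' and 'val' are disjoint and reproducible across calls. A minimal sketch with hypothetical path lists standing in for the real MitoLab files:

```python
from sklearn.model_selection import train_test_split

# Hypothetical image/mask path lists.
raw_paths = [f"im_{i}.tiff" for i in range(100)]
label_paths = [f"mask_{i}.tiff" for i in range(100)]

# The same call get_mitolab_paths makes: 5% validation, fixed seed.
raw_train, raw_val, labels_train, labels_val = train_test_split(
    raw_paths, label_paths, test_size=0.05, random_state=42,
)

print(len(raw_train), len(raw_val))  # 95 5
```

Because both lists are split with the same permutation, image/label pairs stay aligned, and requesting 'train' and 'val' in separate runs yields complementary subsets.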
get_benchmark_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str
Download the MitoLab benchmark data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: The id of the benchmark dataset to download.
- download: Whether to download the data if it is not present.
Returns:
The filepath for the stored data.
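The dataset_id simply indexes the BENCHMARK_DATASETS table defined at the top of this module, and the returned path is the download root joined with data/ and the benchmark folder. A small sketch (the root path here is hypothetical):

```python
import os

# The id -> folder mapping from this module.
BENCHMARK_DATASETS = {
    1: "mito_benchmarks/c_elegans",
    2: "mito_benchmarks/fly_brain",
    3: "mito_benchmarks/glycolytic_muscle",
    4: "mito_benchmarks/hela_cell",
    5: "mito_benchmarks/lucchi_pp",
    6: "mito_benchmarks/salivary_gland",
    7: "tem_benchmark",
}

data_path = "/data/cem"  # hypothetical download root
dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[5])
print(dataset_path)  # e.g. /data/cem/data/mito_benchmarks/lucchi_pp on Linux
```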
get_benchmark_paths(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> Tuple[List[str], List[str], str, str, bool]
Get paths to the MitoLab benchmark data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: The id of the benchmark dataset to download.
- download: Whether to download the data if it is not present.
Returns:
List of the image data paths.
List of the label data paths.
The image data key.
The label data key.
Whether this is a segmentation dataset.
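Benchmark datasets 1-6 are 3d volumes stored as single tif stacks, while dataset 7 (tem_benchmark) is a folder of 2d tiffs, which is why the returned keys and the is_seg_dataset flag differ. The branch can be sketched as follows (paths are hypothetical):

```python
import os


def benchmark_layout(dataset_path, dataset_id):
    # Same return order as get_benchmark_paths:
    # raw_paths, label_paths, raw_key, label_key, is_seg_dataset
    if dataset_id in range(1, 7):  # the 3d volume datasets
        name = os.path.basename(dataset_path)
        return (
            os.path.join(dataset_path, f"{name}_em.tif"),    # raw volume
            os.path.join(dataset_path, f"{name}_mito.tif"),  # label volume
            None, None, True,
        )
    # the 2d dataset: image/mask folders matched by a glob key
    return (
        os.path.join(dataset_path, "images"),
        os.path.join(dataset_path, "masks"),
        "*.tiff", "*.tiff", False,
    )


raw, labels, raw_key, label_key, is_seg = benchmark_layout("/data/lucchi_pp", 5)
```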
get_mitolab_dataset(path: Union[os.PathLike, str], split: Literal['train', 'val'], patch_shape: Tuple[int, int] = (224, 224), val_fraction: float = 0.05, download: bool = False, discard_empty_images: bool = True, **kwargs) -> Dataset
Get the dataset for the MitoLab training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'val'.
- patch_shape: The patch shape to use for training.
- val_fraction: The fraction of the data to use for validation.
- download: Whether to download the data if it is not present.
- discard_empty_images: Whether to discard images without annotations.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
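When discard_empty_images is enabled, the internal _get_non_empty_images helper keeps only masks with more than one unique value, i.e. at least one annotated mitochondrion besides background. A minimal sketch of that criterion with hypothetical toy masks (assuming numpy is available):

```python
import numpy as np


def is_non_empty(mask):
    # The filter used by _get_non_empty_images: keep a mask only if it
    # contains more than one unique value (background plus >= 1 label).
    return len(np.unique(mask)) > 1


empty = np.zeros((4, 4), dtype=np.uint8)  # background only -> discarded
labeled = empty.copy()
labeled[1:3, 1:3] = 1                     # one annotated instance -> kept

print(is_non_empty(empty), is_non_empty(labeled))  # False True
```

Note that the real helper caches its result in a non_empty_images.json file inside the data folder, so the masks are only scanned once.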
get_benchmark_dataset(path: Union[os.PathLike, str], dataset_id: int, patch_shape: Tuple[int, int], download: bool = False, **kwargs) -> Dataset
Get the dataset for one of the MitoLab benchmark datasets.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: The id of the benchmark dataset to download.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
get_mitolab_loader(path: Union[os.PathLike, str], split: str, batch_size: int, patch_shape: Tuple[int, int] = (224, 224), discard_empty_images: bool = True, val_fraction: float = 0.05, download: bool = False, **kwargs) -> DataLoader
Get the dataloader for the MitoLab training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The data split. Either 'train' or 'val'.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- discard_empty_images: Whether to discard images without annotations.
- val_fraction: The fraction of the data to use for validation.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The PyTorch DataLoader.
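The loader functions pass the mixed **kwargs through util.split_kwargs, which routes each keyword either to torch_em.default_segmentation_dataset or to the PyTorch DataLoader. Its actual implementation is not shown in this file; a plausible stdlib sketch of the idea, using a hypothetical stand-in function:

```python
import inspect


def split_kwargs(function, **kwargs):
    # Split kwargs into those accepted by `function` and the rest.
    params = set(inspect.signature(function).parameters)
    fn_kwargs = {k: v for k, v in kwargs.items() if k in params}
    rest = {k: v for k, v in kwargs.items() if k not in params}
    return fn_kwargs, rest


def make_dataset(raw_paths, label_paths, patch_shape, ndim=2):
    ...  # stand-in for torch_em.default_segmentation_dataset


ds_kwargs, loader_kwargs = split_kwargs(make_dataset, ndim=2, num_workers=4, shuffle=True)
print(ds_kwargs, loader_kwargs)  # {'ndim': 2} {'num_workers': 4, 'shuffle': True}
```

This is why a single kwargs dict can carry both dataset options (e.g. ndim) and DataLoader options (e.g. num_workers, shuffle).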
get_benchmark_loader(path: Union[os.PathLike, str], dataset_id: int, batch_size: int, patch_shape: Tuple[int, int], download: bool = False, **kwargs) -> DataLoader
Get the dataloader for one of the MitoLab benchmark datasets.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- dataset_id: The id of the benchmark dataset to download.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.