torch_em.data.datasets.electron_microscopy.cem

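The CEM, or MitoLab, dataset is a collection of data for training mitochondria generalist models. It consists of:

  • CEM-MitoLab: annotated 2d data for training mitochondria segmentation models
    • https://www.ebi.ac.uk/empiar/EMPIAR-11037/
  • CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation
    • https://www.ebi.ac.uk/empiar/EMPIAR-10982/
  • CEM-1.5M: unlabeled EM images for pretraining: (Not yet implemented)
    • https://www.ebi.ac.uk/empiar/EMPIAR-11035/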

These datasets are from the publication https://doi.org/10.1016/j.cels.2022.12.006. Please cite this publication if you use this data in your research.

The data itself can be downloaded from EMPIAR via aspera.

  • You can install aspera via mamba. We recommend to do this in a separate environment to avoid dependency issues:
    • $ mamba create -c conda-forge -c hcc -n aspera aspera-cli
  • After this you can run $ mamba activate aspera to have an environment with aspera installed.
  • You can then download the data for one of the three datasets like this:
    • ascp -QT -l 200m -P33001 -i <PREFIX>/etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/<EMPIAR_ID> <PATH>
    • Where <PREFIX> is the path to the mamba environment, <EMPIAR_ID> is the id of one of the three datasets and <PATH> is where you want to download the data.
  • After this you can use the functions in this file if you use <PATH> as location for the data.

Note that we have implemented automatic download, but this leads to dependency issues, so we recommend to download the data manually and then run the loaders with the correct path.

  1"""The CEM, or MitoLab, dataset is a collection of data for
  2training mitochondria generalist models. It consists of:
  3- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models
  4  - https://www.ebi.ac.uk/empiar/EMPIAR-11037/
  5- CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation
  6  - https://www.ebi.ac.uk/empiar/EMPIAR-10982/
  7- CEM-1.5M: unlabeled EM images for pretraining: (Not yet implemented)
  8  - https://www.ebi.ac.uk/empiar/EMPIAR-11035/
  9
 10These datasets are from the publication https://doi.org/10.1016/j.cels.2022.12.006.
 11Please cite this publication if you use this data in your research.
 12
 13The data itself can be downloaded from EMPIAR via aspera.
 14- You can install aspera via mamba. We recommend to do this in a separate environment
 15  to avoid dependency issues:
 16    - `$ mamba create -c conda-forge -c hcc -n aspera aspera-cli`
 17- After this you can run `$ mamba activate aspera` to have an environment with aspera installed.
 18- You can then download the data for one of the three datasets like this:
 19    - ascp -QT -l 200m -P33001 -i <PREFIX>/etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/<EMPIAR_ID> <PATH>
 20    - Where <PREFIX> is the path to the mamba environment, <EMPIAR_ID> the id of one of the three datasets
 21      and <PATH> where you want to download the data.
 22- After this you can use the functions in this file if you use <PATH> as location for the data.
 23
 24Note that we have implemented automatic download, but this leads to dependency
 25issues, so we recommend to download the data manually and then run the loaders with the correct path.
 26"""
 27
 28import json
 29import os
 30from glob import glob
 31from typing import List, Tuple, Union
 32
 33import imageio.v3 as imageio
 34import numpy as np
 35import torch_em
 36from sklearn.model_selection import train_test_split
 37from torch.utils.data import Dataset, DataLoader
 38
 39from .. import util
 40
 41BENCHMARK_DATASETS = {
 42    1: "mito_benchmarks/c_elegans",
 43    2: "mito_benchmarks/fly_brain",
 44    3: "mito_benchmarks/glycolytic_muscle",
 45    4: "mito_benchmarks/hela_cell",
 46    5: "mito_benchmarks/lucchi_pp",
 47    6: "mito_benchmarks/salivary_gland",
 48    7: "tem_benchmark",
 49}
 50BENCHMARK_SHAPES = {
 51    1: (256, 256, 256),
 52    2: (256, 255, 255),
 53    3: (302, 383, 765),
 54    4: (256, 256, 256),
 55    5: (165, 768, 1024),
 56    6: (1260, 1081, 1200),
 57    7: (224, 224),  # NOTE: this is the minimal square shape that fits
 58}
 59
 60
 61def _get_mitolab_data(path, download):
 62    access_id = "11037"
 63    data_path = util.download_source_empiar(path, access_id, download)
 64
 65    zip_path = os.path.join(data_path, "data/cem_mitolab.zip")
 66    if os.path.exists(zip_path):
 67        util.unzip(zip_path, data_path, remove=True)
 68
 69    data_root = os.path.join(data_path, "cem_mitolab")
 70    assert os.path.exists(data_root)
 71
 72    return data_root
 73
 74
 75def _get_all_images(path):
 76    raw_paths, label_paths = [], []
 77    folders = glob(os.path.join(path, "*"))
 78    assert all(os.path.isdir(folder) for folder in folders)
 79    for folder in folders:
 80        images = sorted(glob(os.path.join(folder, "images", "*.tiff")))
 81        assert len(images) > 0
 82        labels = sorted(glob(os.path.join(folder, "masks", "*.tiff")))
 83        assert len(images) == len(labels)
 84        raw_paths.extend(images)
 85        label_paths.extend(labels)
 86    return raw_paths, label_paths
 87
 88
 89def _get_non_empty_images(path):
 90    save_path = os.path.join(path, "non_empty_images.json")
 91
 92    if os.path.exists(save_path):
 93        with open(save_path, "r") as f:
 94            saved_images = json.load(f)
 95        raw_paths, label_paths = saved_images["images"], saved_images["labels"]
 96        raw_paths = [os.path.join(path, rp) for rp in raw_paths]
 97        label_paths = [os.path.join(path, lp) for lp in label_paths]
 98        return raw_paths, label_paths
 99
100    folders = glob(os.path.join(path, "*"))
101    assert all(os.path.isdir(folder) for folder in folders)
102
103    raw_paths, label_paths = [], []
104    for folder in folders:
105        images = sorted(glob(os.path.join(folder, "images", "*.tiff")))
106        labels = sorted(glob(os.path.join(folder, "masks", "*.tiff")))
107        assert len(images) > 0
108        assert len(images) == len(labels)
109
110        for im, lab in zip(images, labels):
111            n_labels = len(np.unique(imageio.imread(lab)))
112            if n_labels > 1:
113                raw_paths.append(im)
114                label_paths.append(lab)
115
116    raw_paths_rel = [os.path.relpath(rp, path) for rp in raw_paths]
117    label_paths_rel = [os.path.relpath(lp, path) for lp in label_paths]
118
119    with open(save_path, "w") as f:
120        json.dump({"images": raw_paths_rel, "labels": label_paths_rel}, f)
121
122    return raw_paths, label_paths
123
124
def get_mitolab_data(
    path: Union[os.PathLike, str],
    split: Union[str, None],
    val_fraction: float,
    download: bool,
    discard_empty_images: bool
) -> Tuple[List[str], List[str]]:
    """Download the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'. Pass None to get all data without splitting.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.

    Returns:
        List of the image data paths.
        List of the label data paths.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or None.
    """
    # Reject unknown splits up front; previously any value other than 'train'
    # silently selected the validation split.
    if split not in ("train", "val", None):
        raise ValueError(f"Invalid split '{split}', expected one of 'train', 'val' or None.")

    data_path = _get_mitolab_data(path, download)
    if discard_empty_images:
        raw_paths, label_paths = _get_non_empty_images(data_path)
    else:
        raw_paths, label_paths = _get_all_images(data_path)

    if split is not None:
        # The fixed random state keeps the train / val assignment stable across calls.
        raw_train, raw_val, labels_train, labels_val = train_test_split(
            raw_paths, label_paths, test_size=val_fraction, random_state=42,
        )
        if split == "train":
            raw_paths, label_paths = raw_train, labels_train
        else:
            raw_paths, label_paths = raw_val, labels_val

    assert len(raw_paths) > 0
    assert len(raw_paths) == len(label_paths)
    return raw_paths, label_paths
163
164
def get_benchmark_data(
    path: Union[os.PathLike, str],
    dataset_id: int,
    download: bool
) -> Tuple[
    str, str, Union[str, None], Union[str, None], bool
]:
    """Download the mitolab benchmark data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download. Must be a key of `BENCHMARK_DATASETS`.
        download: Whether to download the data if it is not present.

    Returns:
        The image data path: a single tif file for the 3d datasets (ids 1-6),
            the "images" folder for the 2d dataset (id 7).
        The label data path: a single tif file for the 3d datasets,
            the "masks" folder for the 2d dataset.
        The image data key: None for the 3d datasets, the file pattern "*.tiff" for the 2d dataset.
        The label data key: None for the 3d datasets, the file pattern "*.tiff" for the 2d dataset.
        Whether this is a segmentation dataset.

    Raises:
        KeyError: If `dataset_id` is not a key of `BENCHMARK_DATASETS`.
    """
    # Resolve the dataset folder first, so that an invalid id fails before anything is downloaded.
    dataset_folder = BENCHMARK_DATASETS[dataset_id]

    access_id = "10982"
    data_path = util.download_source_empiar(path, access_id, download)
    dataset_path = os.path.join(data_path, "data", dataset_folder)

    # these are the 3d datasets
    if dataset_id in range(1, 7):
        dataset_name = os.path.basename(dataset_path)
        raw_paths = os.path.join(dataset_path, f"{dataset_name}_em.tif")
        label_paths = os.path.join(dataset_path, f"{dataset_name}_mito.tif")
        raw_key, label_key = None, None
        is_seg_dataset = True

    # this is the 2d dataset
    else:
        raw_paths = os.path.join(dataset_path, "images")
        label_paths = os.path.join(dataset_path, "masks")
        raw_key, label_key = "*.tiff", "*.tiff"
        is_seg_dataset = False

    return raw_paths, label_paths, raw_key, label_key, is_seg_dataset
206
207
208#
209# Datasets
210#
211
212
def get_mitolab_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int] = (224, 224),
    val_fraction: float = 0.05,
    download: bool = False,
    discard_empty_images: bool = True,
    **kwargs
) -> Dataset:
    """Create the segmentation dataset for the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved. The folder must exist.
        split: The data split. Either 'train' or 'val'; None is accepted as well.
        patch_shape: The patch shape to use for training.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    valid_splits = ("train", "val", None)
    assert split in valid_splits
    assert os.path.exists(path)

    image_paths, mask_paths = get_mitolab_data(path, split, val_fraction, download, discard_empty_images)

    # The data consists of individual 2d image files, hence no keys and `is_seg_dataset=False`.
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=mask_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        ndim=2,
        **kwargs
    )
    return dataset
244
245
def get_cem15m_dataset(path):
    """Placeholder for the CEM-1.5M dataset, which is not implemented yet.

    Args:
        path: Unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
248
249
def get_benchmark_dataset(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, ...],
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the dataset for one of the mitolab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download. Must be in the range [1, 7].
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `dataset_id` is not in the range [1, 7].
    """
    # Validate the id before anything is looked up or downloaded.
    if dataset_id not in range(1, 8):
        raise ValueError(f"Invalid dataset id {dataset_id}, expected id in range [1, 7].")
    raw_paths, label_paths, raw_key, label_key, is_seg_dataset = get_benchmark_data(path, dataset_id, download)
    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths, raw_key=raw_key,
        label_paths=label_paths, label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=is_seg_dataset, **kwargs,
    )
278
279
280#
281# DataLoaders
282#
283
284
def get_mitolab_loader(
    path: Union[os.PathLike, str],
    split: str,
    batch_size: int,
    patch_shape: Tuple[int, int] = (224, 224),
    discard_empty_images: bool = True,
    val_fraction: float = 0.05,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        discard_empty_images: Whether to discard images without annotations.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments of the dataset from the remaining ones, which go to the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(
        torch_em.default_segmentation_dataset, **kwargs
    )
    # `val_fraction` has to be forwarded explicitly; otherwise the dataset
    # silently falls back to its own default and the argument has no effect.
    dataset = get_mitolab_dataset(
        path, split, patch_shape,
        val_fraction=val_fraction, download=download, discard_empty_images=discard_empty_images, **ds_kwargs
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
318
319
def get_cem15m_loader(path):
    """Placeholder for the CEM-1.5M dataloader, which is not implemented yet.

    Args:
        path: Unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
322
323
def get_benchmark_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    batch_size: int,
    patch_shape: Tuple[int, int],
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Create the dataloader for one of the mitolab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Keyword arguments of the dataset function go to the dataset, the remaining ones to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    benchmark_dataset = get_benchmark_dataset(
        path,
        dataset_id,
        patch_shape=patch_shape,
        download=download,
        **dataset_kwargs
    )
    return torch_em.get_data_loader(benchmark_dataset, batch_size, **dataloader_kwargs)
BENCHMARK_DATASETS = {1: 'mito_benchmarks/c_elegans', 2: 'mito_benchmarks/fly_brain', 3: 'mito_benchmarks/glycolytic_muscle', 4: 'mito_benchmarks/hela_cell', 5: 'mito_benchmarks/lucchi_pp', 6: 'mito_benchmarks/salivary_gland', 7: 'tem_benchmark'}
BENCHMARK_SHAPES = {1: (256, 256, 256), 2: (256, 255, 255), 3: (302, 383, 765), 4: (256, 256, 256), 5: (165, 768, 1024), 6: (1260, 1081, 1200), 7: (224, 224)}
def get_mitolab_data( path: Union[os.PathLike, str], split: str, val_fraction: float, download: bool, discard_empty_images: bool) -> Tuple[List[str], List[str]]:
def get_mitolab_data(
    path: Union[os.PathLike, str],
    split: Union[str, None],
    val_fraction: float,
    download: bool,
    discard_empty_images: bool
) -> Tuple[List[str], List[str]]:
    """Download the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'. Pass None to get all data without splitting.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.

    Returns:
        List of the image data paths.
        List of the label data paths.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or None.
    """
    # Reject unknown splits up front; previously any value other than 'train'
    # silently selected the validation split.
    if split not in ("train", "val", None):
        raise ValueError(f"Invalid split '{split}', expected one of 'train', 'val' or None.")

    data_path = _get_mitolab_data(path, download)
    if discard_empty_images:
        raw_paths, label_paths = _get_non_empty_images(data_path)
    else:
        raw_paths, label_paths = _get_all_images(data_path)

    if split is not None:
        # The fixed random state keeps the train / val assignment stable across calls.
        raw_train, raw_val, labels_train, labels_val = train_test_split(
            raw_paths, label_paths, test_size=val_fraction, random_state=42,
        )
        if split == "train":
            raw_paths, label_paths = raw_train, labels_train
        else:
            raw_paths, label_paths = raw_val, labels_val

    assert len(raw_paths) > 0
    assert len(raw_paths) == len(label_paths)
    return raw_paths, label_paths

Download the mitolab training data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'val'.
  • val_fraction: The fraction of the data to use for validation.
  • download: Whether to download the data if it is not present.
  • discard_empty_images: Whether to discard images without annotations.
Returns:

List of the image data paths. List of the label data paths.

def get_benchmark_data( path: Union[os.PathLike, str], dataset_id: int, download: bool) -> Tuple[List[str], List[str], str, str, bool]:
def get_benchmark_data(
    path: Union[os.PathLike, str],
    dataset_id: int,
    download: bool
) -> Tuple[
    str, str, Union[str, None], Union[str, None], bool
]:
    """Download the mitolab benchmark data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download. Must be a key of `BENCHMARK_DATASETS`.
        download: Whether to download the data if it is not present.

    Returns:
        The image data path: a single tif file for the 3d datasets (ids 1-6),
            the "images" folder for the 2d dataset (id 7).
        The label data path: a single tif file for the 3d datasets,
            the "masks" folder for the 2d dataset.
        The image data key: None for the 3d datasets, the file pattern "*.tiff" for the 2d dataset.
        The label data key: None for the 3d datasets, the file pattern "*.tiff" for the 2d dataset.
        Whether this is a segmentation dataset.

    Raises:
        KeyError: If `dataset_id` is not a key of `BENCHMARK_DATASETS`.
    """
    # Resolve the dataset folder first, so that an invalid id fails before anything is downloaded.
    dataset_folder = BENCHMARK_DATASETS[dataset_id]

    access_id = "10982"
    data_path = util.download_source_empiar(path, access_id, download)
    dataset_path = os.path.join(data_path, "data", dataset_folder)

    # these are the 3d datasets
    if dataset_id in range(1, 7):
        dataset_name = os.path.basename(dataset_path)
        raw_paths = os.path.join(dataset_path, f"{dataset_name}_em.tif")
        label_paths = os.path.join(dataset_path, f"{dataset_name}_mito.tif")
        raw_key, label_key = None, None
        is_seg_dataset = True

    # this is the 2d dataset
    else:
        raw_paths = os.path.join(dataset_path, "images")
        label_paths = os.path.join(dataset_path, "masks")
        raw_key, label_key = "*.tiff", "*.tiff"
        is_seg_dataset = False

    return raw_paths, label_paths, raw_key, label_key, is_seg_dataset

Download the mitolab benchmark data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • dataset_id: The id of the benchmark dataset to download.
  • download: Whether to download the data if it is not present.
Returns:

List of the image data paths. List of the label data paths. The image data key. The label data key. Whether this is a segmentation dataset.

def get_mitolab_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int] = (224, 224), val_fraction: float = 0.05, download: bool = False, discard_empty_images: bool = True, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_mitolab_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int] = (224, 224),
    val_fraction: float = 0.05,
    download: bool = False,
    discard_empty_images: bool = True,
    **kwargs
) -> Dataset:
    """Create the segmentation dataset for the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved. The folder must exist.
        split: The data split. Either 'train' or 'val'; None is accepted as well.
        patch_shape: The patch shape to use for training.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        discard_empty_images: Whether to discard images without annotations.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    valid_splits = ("train", "val", None)
    assert split in valid_splits
    assert os.path.exists(path)

    image_paths, mask_paths = get_mitolab_data(path, split, val_fraction, download, discard_empty_images)

    # The data consists of individual 2d image files, hence no keys and `is_seg_dataset=False`.
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=mask_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        ndim=2,
        **kwargs
    )
    return dataset

Get the dataset for the mitolab training data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'val'.
  • patch_shape: The patch shape to use for training.
  • val_fraction: The fraction of the data to use for validation.
  • download: Whether to download the data if it is not present.
  • discard_empty_images: Whether to discard images without annotations.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_cem15m_dataset(path):
def get_cem15m_dataset(path):
    """Placeholder for the CEM-1.5M dataset, which is not implemented yet.

    Args:
        path: Unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def get_benchmark_dataset( path, dataset_id, patch_shape, download=False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_benchmark_dataset(
    path: Union[os.PathLike, str],
    dataset_id: int,
    patch_shape: Tuple[int, ...],
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the dataset for one of the mitolab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download. Must be in the range [1, 7].
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `dataset_id` is not in the range [1, 7].
    """
    # Validate the id before anything is looked up or downloaded.
    if dataset_id not in range(1, 8):
        raise ValueError(f"Invalid dataset id {dataset_id}, expected id in range [1, 7].")
    raw_paths, label_paths, raw_key, label_key, is_seg_dataset = get_benchmark_data(path, dataset_id, download)
    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths, raw_key=raw_key,
        label_paths=label_paths, label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=is_seg_dataset, **kwargs,
    )

Get the dataset for one of the mitolab benchmark datasets.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • dataset_id: The id of the benchmark dataset to download.
  • patch_shape: The patch shape to use for training.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_mitolab_loader( path: Union[os.PathLike, str], split: str, batch_size: int, patch_shape: Tuple[int, int] = (224, 224), discard_empty_images: bool = True, val_fraction: float = 0.05, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_mitolab_loader(
    path: Union[os.PathLike, str],
    split: str,
    batch_size: int,
    patch_shape: Tuple[int, int] = (224, 224),
    discard_empty_images: bool = True,
    val_fraction: float = 0.05,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the dataloader for the mitolab training data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'val'.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        discard_empty_images: Whether to discard images without annotations.
        val_fraction: The fraction of the data to use for validation.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The PyTorch DataLoader.
    """
    # Separate the keyword arguments of the dataset from the remaining ones, which go to the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(
        torch_em.default_segmentation_dataset, **kwargs
    )
    # `val_fraction` has to be forwarded explicitly; otherwise the dataset
    # silently falls back to its own default and the argument has no effect.
    dataset = get_mitolab_dataset(
        path, split, patch_shape,
        val_fraction=val_fraction, download=download, discard_empty_images=discard_empty_images, **ds_kwargs
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader

Get the dataloader for the mitolab training data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split. Either 'train' or 'val'.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • discard_empty_images: Whether to discard images without annotations.
  • val_fraction: The fraction of the data to use for validation.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The PyTorch DataLoader.

def get_cem15m_loader(path):
def get_cem15m_loader(path):
    """Placeholder for the CEM-1.5M dataloader, which is not implemented yet.

    Args:
        path: Unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def get_benchmark_loader( path: Union[os.PathLike, str], dataset_id: int, batch_size: int, patch_shape: Tuple[int, int], download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_benchmark_loader(
    path: Union[os.PathLike, str],
    dataset_id: int,
    batch_size: int,
    patch_shape: Tuple[int, int],
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Create the dataloader for one of the mitolab benchmark datasets.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        dataset_id: The id of the benchmark dataset to download.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Keyword arguments of the dataset function go to the dataset, the remaining ones to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    benchmark_dataset = get_benchmark_dataset(
        path,
        dataset_id,
        patch_shape=patch_shape,
        download=download,
        **dataset_kwargs
    )
    return torch_em.get_data_loader(benchmark_dataset, batch_size, **dataloader_kwargs)

Get the dataloader for one of the mitolab benchmark datasets.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • dataset_id: The id of the benchmark dataset to download.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.