torch_em.data.datasets.electron_microscopy.cellmap

CellMap is a dataset for segmenting various organelles in electron microscopy. It contains a large number of annotated crops from several species. This dataset is released for the CellMap Segmentation Challenge: https://cellmapchallenge.janelia.org/.

Please cite them if you use this data for your research.

  1"""CellMap is a dataset for segmenting various organelles in electron microscopy.
  2It contains a large number of annotated crops from several species.
  3This dataset is released for the `CellMap Segmentation Challenge`: https://cellmapchallenge.janelia.org/.
  4- Official documentation: https://janelia-cellmap.github.io/cellmap-segmentation-challenge/.
  5- Original GitHub repository for the toolbox: https://github.com/janelia-cellmap/cellmap-segmentation-challenge.
  6- Associated collection DOI for the data: https://doi.org/10.25378/janelia.c.7456966.
  7
  8Please cite them if you use this data for your research.
  9"""
 10
 11import os
 12import time
 13import warnings
 14from pathlib import Path
 15from threading import Lock
 16from typing import Union, Optional, Tuple, List, Sequence
 17from concurrent.futures import ThreadPoolExecutor, as_completed
 18
 19import h5py
 20import numpy as np
 21from xarray import DataArray
 22
 23from torch.utils.data import Dataset, DataLoader
 24
 25import torch_em
 26
 27from elf.io import open_file
 28
 29from .. import util
 30
 31
 32def _download_cellmap_data(path, crops, resolution, padding, download=False):
 33    """Download scripts for the CellMap data.
 34    
 35    Inspired by https://github.com/janelia-cellmap/cellmap-segmentation-challenge/blob/main/src/cellmap_segmentation_challenge/cli/fetch_data.py
 36
 37    NOTE: The download code below is intended to stay as close as possible to the original `fetch-data` CLI,
 38    in order to make it easy to sync with any future changes to the original repository.
 39    """  # noqa
 40
 41    # Importing packages locally.
 42    # NOTE: Keeping the relevant imports here avoids `torch-em` throwing a missing module error.
 43
 44    try:
 45        from cellmap_segmentation_challenge.utils.fetch_data import read_group, subset_to_slice
 46        from cellmap_segmentation_challenge.utils.crops import fetch_crop_manifest, get_test_crops, TestCropRow
 47    except ImportError:
 48        raise ModuleNotFoundError(
 49            "Please install 'cellmap_segmentation_challenge' package using "
 50            "'pip install git+https://github.com/janelia-cellmap/cellmap-segmentation-challenge.git'."
 51        )
 52
 53    # The imports below are installed together with the 'cellmap_segmentation_challenge' package.
 54    import structlog
 55    from xarray_ome_ngff import read_multiscale_group
 56    from xarray_ome_ngff.v04.multiscale import transforms_from_coords
 57
 58    # Set up the timer, the logger and the dask array wrapper.
 59    fetch_save_start = time.time()
 60    log = structlog.get_logger()
 61    array_wrapper = {"name": "dask_array", "config": {"chunks": "auto"}}
 62
 63    # Get the absolute path location to store crops.
 64    dest_path_abs = Path(path).absolute()
 65    dest_path_abs.mkdir(exist_ok=True)
 66
 67    # Get the entire crop manifest.
 68    crops_from_manifest = fetch_crop_manifest()
 69
 70    # Get the desired crop info from the manifest.
 71    if crops == "all":
 72        crops_parsed = crops_from_manifest
 73    elif crops == "test":
 74        crops_parsed = get_test_crops()
 75        log.info(f"Found '{len(crops_parsed)}' test crops.")
 76    else:  # Otherwise, custom crops are parsed.
 77        crops_split = tuple(int(x) for x in crops.split(","))
 78        crops_parsed = tuple(filter(lambda v: v.id in crops_split, crops_from_manifest))
 79
 80    # Now get the crop ids.
 81    if len(crops_parsed) == 0:
 82        log.info(f"No crops found matching '{crops}'. Doing nothing.")
 83        return
 84
 85    crop_ids = tuple(c.id for c in crops_parsed)
 86    log.info(f"Preparing to copy the following crops: '{crop_ids}'.")
 87    log.info(f"Data will be saved to '{dest_path_abs}'.")
 88
 89    all_crops = []
 90    for crop in crops_parsed:
 91        log = log.bind(crop_id=crop.id, dataset=crop.dataset)
 92
 93        # Add the crop id to a list, to forward it later.
 94        all_crops.append(crop.id)
 95
 96        # Check whether the crop path has been downloaded already or not.
 97        crop_path = dest_path_abs / f"crop_{crop.id}.h5"
 98        if crop_path.exists():
 99            log.info(f"The crop '{crop.id}' is already saved at '{crop_path}'.")
100            log = log.unbind("crop_id", "dataset")
101            continue
102
103        # If 'download' is set to 'False', we do not go further from here.
104        if not download:
105            log.error(f"Cannot download the crop '{crop.id}' as 'download' is set to 'False'.")
106            return
107
108        # Check whether the crop is a part of the test crops, i.e. where GT masks are not available.
109        if isinstance(crop.gt_source, TestCropRow):
110            log.info(f"The test crop '{crop.id}' does not have GT data. Fetching em data only.")
111        else:
112            log.info(f"Fetching GT data for crop '{crop.id}' from '{crop.gt_source}'.")
113
114            # Get the ground-truth (gt) masks.
115            gt_source_group = read_group(str(crop.gt_source), storage_options={"anon": True})
116
117            log.info(f"Found GT data at '{crop.gt_source}'.")
118
119            # Let's get all ground-truth hierarchies.
120            # NOTE: Following the original repo, we rely on fs.find to avoid slow traversal of the online zarr.
121            fs = gt_source_group.store.fs
122            store_path = gt_source_group.store.path
123            gt_files = fs.find(store_path)
124
125            crop_group_inventory = tuple(fn.split(store_path)[-1] for fn in gt_files)
126            crop_group_inventory = tuple(curr_cg[1:].split("/")[0] for curr_cg in crop_group_inventory)
127            crop_group_inventory = np.unique(crop_group_inventory).tolist()
128            crop_group_inventory = [
129                curr_cg for curr_cg in crop_group_inventory if curr_cg not in [".zattrs", ".zgroup"]
130            ]
131
132            # Get the offset values for the ground truth crops.
133            crop_multiscale_group = None
134            for _, group in gt_source_group.groups():
135                try:  # Get groups for all resolutions.
136                    crop_multiscale_group = read_multiscale_group(group, array_wrapper=array_wrapper)
137                    break
138                except (ValueError, TypeError):
139                    continue
140
141            if crop_multiscale_group is None:
142                log.info(f"No multiscale groups found in '{crop.gt_source}'. No EM data can be fetched.")
143                continue
144
145        # Get the EM volume group.
146        em_source_group = read_group(str(crop.em_url), storage_options={"anon": True})
147        log.info(f"Found EM data at '{crop.em_url}'.")
148
149        # Let's get the multiscale model of the source em group.
150        em_source_arrays = read_multiscale_group(em_source_group, array_wrapper)
151
152        # Next, we rely on the scales of each resolution to identify whether the resolution level is the same
153        # for the EM volume and the corresponding ground-truth mask crops (if available).
154
155        # For this, we first extract the EM volume scales per resolution.
156        em_resolutions = {}
157        for res_key, array in em_source_arrays.items():
158            try:
159                _, (em_scale, em_translation) = transforms_from_coords(array.coords, transform_precision=4)
160                em_resolutions[res_key] = (em_scale.scale, em_translation.translation)
161            except Exception:
162                continue
163
164        if isinstance(crop.gt_source, TestCropRow):
165            # Choose the scale ratio threshold (from the original scripts).
166            ratio_threshold = 0.8  # NOTE: hard-coded to follow the original data download logic.
167
168            # Choose the matching resolution level with marked GT.
169            em_level = next(
170                (
171                    k for k, (scale, _) in em_resolutions.items()
172                    if all(s / vs > ratio_threshold for s, vs in zip(scale, crop.gt_source.voxel_size))
173                ), None
174            )
175
176            assert em_level is not None, "There has to be a scale match for the EM volume. Something went wrong."
177
178            scale = em_resolutions[em_level][0]
179            em_array = em_source_arrays[em_level]
180
181            # Get the slices from the test crop metadata (translation, shape and voxel size).
182            starts = crop.gt_source.translation
183            stops = tuple(
184                start + size * vs for start, size, vs in zip(starts, crop.gt_source.shape, crop.gt_source.voxel_size)
185            )
186            coords = em_array.coords.copy()
187            for k, v in zip(em_array.coords.keys(), np.array((starts, stops)).T):
188                coords[k] = v
189
190            slices = subset_to_slice(outer_array=em_array, inner_array=DataArray(dims=em_array.dims, coords=coords))
191
192            # Set 'gt_level' to 'None' for better handling of crops without labels.
193            gt_level = None
194
195        else:
196            # Next, we extract the ground-truth scales per resolution (for labeled crops).
197            gt_resolutions = {}
198            for res_key, array in crop_multiscale_group.items():
199                try:
200                    _, (gt_scale, gt_translation) = transforms_from_coords(array.coords, transform_precision=4)
201                    gt_resolutions[res_key] = (gt_scale.scale, gt_translation.translation)
202                except Exception:
203                    continue
204
205            # Now, we find the matching scales and use the respective "resolution" keys.
206            matching_keys = []
207            for gt_key, (gt_scale, gt_translation) in gt_resolutions.items():
208                for em_key, (em_scale, em_translation) in em_resolutions.items():
209                    if np.allclose(gt_scale, em_scale, rtol=1e-3, atol=1e-6):
210                        matching_keys.append((gt_key, em_key, gt_scale, gt_translation, em_translation))
211
212            # If no match is found, we skip this crop.
213            if not matching_keys:
214                log.error(f"No EM resolution level matches any GT scale for crop ID '{crop.id}'.")
215                continue
216
217            # We get the desired resolution level for the EM volume, labels, and the scale of choice.
218            matching_keys.sort(key=lambda x: np.prod(x[2]))
219            gt_level, em_level, scale, gt_translation, em_translation = matching_keys[0]
220
221            # Get the desired values for the particular resolution level.
222            em_array = em_source_arrays[em_level]
223            gt_crop_shape = gt_source_group[f"all/{gt_level}"].shape  # the "all" group always exists, so we rely on it.
224
225            log.info(f"Found a resolution match for EM data at level '{em_level}' and GT data at level '{gt_level}'.")
226
227            # Compute the input reference crop from the ground truth metadata.
228            starts = gt_translation
229            stops = [start + size * vs for start, size, vs in zip(starts, gt_crop_shape, scale)]
230
231            # Get the slices.
232            em_starts = [int(round((p_start - em_translation[i]) / scale[i])) for i, p_start in enumerate(starts)]
233            em_stops = [int(round((p_stop - em_translation[i]) / scale[i])) for i, p_stop in enumerate(stops)]
234            slices = tuple(slice(start, stop) for start, stop in zip(em_starts, em_stops))
235
236        # Pad the slices (in voxel space)
237        slices_padded = tuple(
238            slice(max(0, sl.start - padding), min(sl.stop + padding, dim), sl.step)
239            for sl, dim in zip(slices, em_array.shape)
240        )
241
242        # Extract cropped EM volume from remote zarr files.
243        em_crop = em_array[tuple(slices_padded)].data.compute()
244
245        # Write everything to a crop-level h5 file.
246        write_lock = Lock()
247        with h5py.File(crop_path, "w") as f:
248            # Store metadata
249            f.attrs["crop_id"] = crop.id
250            f.attrs["scale"] = scale
251            f.attrs["em_level"] = em_level
252
253            if gt_level is not None:
254                f.attrs["translation"] = gt_translation
255                f.attrs["gt_level"] = gt_level
256
257            # Store inputs.
258            f.create_dataset(name="raw_crop", data=em_crop, dtype=em_crop.dtype, compression="gzip")
259            log.info(f"Saved EM data crop for crop '{crop.id}'.")
260
261            def _fetch_and_write_label(label_name):
262                gt_crop = gt_source_group[f"{label_name}/{gt_level}"][:]
263
264                # Next, pad the labels to match the input shape.
265                def _pad_to_shape(array):
266                    return np.pad(
267                        array=array.astype(np.int16),
268                        pad_width=[
269                            (orig.start - padded.start, padded.stop - orig.stop)
270                            for orig, padded in zip(slices, slices_padded)
271                        ],
272                        mode="constant",
273                        constant_values=-1,
274                    )
275
276                gt_crop = _pad_to_shape(gt_crop)
277
278                # Write each label to its corresponding hierarchy name.
279                with write_lock:
280                    f.create_dataset(
281                        name=f"label_crop/{label_name}", data=gt_crop, dtype=gt_crop.dtype, compression="gzip"
282                    )
283                return label_name
284
285            if gt_level is not None:
286                # For this one (large) crop in particular, we store labels serially,
287                # because multiple threads cannot handle it and crash silently.
288                if crop.id == 247:
289                    for name in crop_group_inventory:
290                        _fetch_and_write_label(name)
291                        log.info(f"Saved ground truth crop '{crop.id}' for '{name}'.")
292                else:
293                    with ThreadPoolExecutor() as pool:
294                        futures = {pool.submit(_fetch_and_write_label, name): name for name in crop_group_inventory}
295                        for future in as_completed(futures):
296                            label_name = future.result()
297                            log.info(f"Saved ground truth crop '{crop.id}' for '{label_name}'.")
298
299        log.info(f"Saved crop '{crop.id}' to '{crop_path}'.")
300        log = log.unbind("crop_id", "dataset")
301
302    log.info(f"Done after {time.time() - fetch_save_start:0.3f}s")
303    log.info(f"Data saved to '{dest_path_abs}'.")
304
305    return path, all_crops
306
307
308def get_cellmap_data(
309    path: Union[os.PathLike, str],
310    crops: Union[str, Sequence[str]] = "all",
311    resolution: str = "s0",
312    padding: int = 64,
313    download: bool = False,
314) -> Tuple[str, List[str]]:
315    """Downloads the CellMap training data.
316
317    Args:
318        path: Filepath to a folder where the data will be downloaded for further processing.
319        crops: The choice of crops to download. By default, downloads `all` crops.
320            For multiple crops, provide a sequence of crop ids.
321        resolution: The choice of resolution in the original volumes.
322            By default, downloads the highest resolution: `s0`.
323        padding: The choice of padding along each dimension.
324            By default, it pads '64' pixels along all dimensions.
325            You can set it to '0' for no padding at all.
326            Pixel regions without annotations are labeled with id '-1' in the masks.
327        download: Whether to download the data if it is not present.
328
329    Returns:
330        Filepath where the data is stored for further processing.
331        List of crop ids.
332    """
333
334    data_path = os.path.join(path, "data_crops")
335    os.makedirs(data_path, exist_ok=True)
336
337    # Get the crops in 'csc' desired format.
338    if isinstance(crops, Sequence) and not isinstance(crops, str):  # for multiple values
339        crops = ",".join(str(c) for c in crops)
340
341    # NOTE: The function below is comparable to the CLI `csc fetch-data` from the original repo.
342    _data_path, final_crops = _download_cellmap_data(
343        path=data_path,
344        crops=crops,
345        resolution=resolution,
346        padding=padding,
347        download=download,
348    )
349
350    if _data_path is None or len(_data_path) == 0:
351        raise RuntimeError("Something went wrong. Please read the information logged above.")
352
354    assert len(final_crops) > 0, "There seem to be no valid crops in the list."
354
355    return data_path, final_crops
356
357
358def get_cellmap_paths(
359    path: Union[os.PathLike, str],
360    organelles: Optional[Union[str, List[str]]] = None,
361    crops: Union[str, Sequence[str]] = "all",
362    resolution: str = "s0",
363    voxel_size: Optional[Tuple[float]] = None,
364    padding: int = 64,
365    download: bool = False,
366    return_test_crops: bool = False,
367) -> List[str]:
368    """Get the paths to CellMap training data.
369
370    Args:
371        path: Filepath to a folder where the data will be downloaded for further processing.
372        organelles: The choice of organelles to download. By default, loads all types of labels available.
373            For one or multiple organelles, specify either as 'mito' or as ['mito', 'cell'].
374        crops: The choice of crops to download. By default, downloads `all` crops.
375            For multiple crops, provide a sequence of crop ids.
376        resolution: The choice of resolution in the original volumes.
377            By default, downloads the highest resolution: `s0`.
378        voxel_size: The choice of voxel size for the preprocessed crops to prepare the dataset.
379            By default, no voxel size filtering is applied and all crops in scope are used.
380        padding: The choice of padding along each dimension.
381            By default, it pads '64' pixels along all dimensions.
382            You can set it to '0' for no padding at all.
383            Pixel regions without annotations are labeled with id '-1' in the masks.
384        download: Whether to download the data if it is not present.
385        return_test_crops: Whether to forcefully return the filepaths of the test crops for other analysis.
386
387    Returns:
388        List of the cropped volume data paths.
389    """
390
391    if not return_test_crops and ("test" in crops if isinstance(crops, (List, Tuple)) else crops == "test"):
392        raise NotImplementedError("The 'test' crops cannot be used in the dataloader.")
393
394    # Get the CellMap data crops.
395    data_path, crops = get_cellmap_data(
396        path=path, crops=crops, resolution=resolution, padding=padding, download=download,
397    )
398
399    # Get all crops.
400    volume_paths = [os.path.join(data_path, f"crop_{c}.h5") for c in crops]
401
402    # Check for valid organelles list to filter crops.
403    if organelles is None:
404        organelles = "all"
405
406    if isinstance(organelles, str):
407        organelles = [organelles]
408
409    other_volume_paths = []
410    for organelle in organelles:
411
412        if organelle != "all":
413            warnings.warn(
414                "You have chosen a different organelle annotations than 'all'. Please keep in mind "
415                f"that it is not guaranteed to provide you the correct masks for '{organelle}'. "
416                "We suggest sticking to 'all' labels and use the corresponding label ids."
417            )
418
419        for vpath in volume_paths:
420            if f"label_crop/{organelle}" in open_file(vpath) and vpath not in other_volume_paths:
421                other_volume_paths.append(vpath)
422
423    if len(other_volume_paths) == 0:
424        raise ValueError(f"'{organelles}' are not valid organelle(s) found in the crops: '{crops}'.")
425
426    # Next, we check for valid voxel size to filter crops.
427    if voxel_size is None:  # no filtering required.
428        another_volume_paths = other_volume_paths
429    else:
430        another_volume_paths = []
431        for vpath in other_volume_paths:
432            if all(np.array(voxel_size) == open_file(vpath).attrs["scale"]) and vpath not in another_volume_paths:
433                another_volume_paths.append(vpath)
434
435    if len(another_volume_paths) == 0:
436        raise ValueError(f"'{voxel_size}' is not a valid voxel size found in the crops: '{crops}'.")
437
438    # Check whether all volume paths exist.
439    for volume_path in another_volume_paths:
440        if not os.path.exists(volume_path):
441            raise FileNotFoundError(f"The volume '{volume_path}' could not be found.")
442
443    return another_volume_paths
444
445
446def get_cellmap_dataset(
447    path: Union[os.PathLike, str],
448    patch_shape: Tuple[int, ...],
449    organelles: Optional[Union[str, List[str]]] = None,
450    crops: Union[str, Sequence[str]] = "all",
451    resolution: str = "s0",
452    voxel_size: Optional[Tuple[float]] = None,
453    padding: int = 64,
454    download: bool = False,
455    **kwargs,
456) -> Dataset:
457    """Get the dataset for the CellMap training data for organelle segmentation.
458
459    Args:
460        path: Filepath to a folder where the data will be downloaded for further processing.
461        patch_shape: The patch shape to use for training.
462        organelles: The choice of organelles to download. By default, loads all types of labels available.
463            For one or multiple organelles, specify either as 'mito' or as ['mito', 'cell'].
464        crops: The choice of crops to download. By default, downloads `all` crops.
465            For multiple crops, provide a sequence of crop ids.
466        resolution: The choice of resolution in the original volumes.
467            By default, downloads the highest resolution: `s0`.
468        voxel_size: The choice of voxel size for the preprocessed crops to prepare the dataset.
469            By default, no voxel size filtering is applied and all crops in scope are used.
470        padding: The choice of padding along each dimension.
471            By default, it pads '64' pixels along all dimensions.
472            You can set it to '0' for no padding at all.
473            Pixel regions without annotations are labeled with id '-1' in the masks.
474        download: Whether to download the data if it is not present.
475        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
476
477    Returns:
478        The segmentation dataset.
479    """
480    volume_paths = get_cellmap_paths(
481        path=path,
482        organelles=organelles,
483        crops=crops,
484        resolution=resolution,
485        voxel_size=voxel_size,
486        padding=padding, download=download
487    )
488
489    # Arrange the organelle choices as expected for loading labels.
490    if organelles is None:
491        organelles = "label_crop/all"
492    else:
493        if isinstance(organelles, str):
494            organelles = f"label_crop/{organelles}"
495        else:
496            organelles = [f"label_crop/{curr_organelle}" for curr_organelle in organelles]
497            kwargs = util.update_kwargs(kwargs, "with_label_channels", True)
498
499    return torch_em.default_segmentation_dataset(
500        raw_paths=volume_paths,
501        raw_key="raw_crop",
502        label_paths=volume_paths,
503        label_key=organelles,
504        patch_shape=patch_shape,
505        is_seg_dataset=True,
506        **kwargs
507    )
508
509
510def get_cellmap_loader(
511    path: Union[os.PathLike, str],
512    batch_size: int,
513    patch_shape: Tuple[int, ...],
514    organelles: Optional[Union[str, List[str]]] = None,
515    crops: Union[str, Sequence[str]] = "all",
516    resolution: str = "s0",
517    voxel_size: Optional[Tuple[float]] = None,
518    padding: int = 64,
519    download: bool = False,
520    **kwargs,
521) -> DataLoader:
522    """Get the dataloader for the CellMap training data for organelle segmentation.
523
524    Args:
525        path: Filepath to a folder where the data will be downloaded for further processing.
526        batch_size: The batch size for training.
527        patch_shape: The patch shape to use for training.
528        organelles: The choice of organelles to download. By default, loads all types of labels available.
529            For one or multiple organelles, specify either as 'mito' or as ['mito', 'cell'].
530        crops: The choice of crops to download. By default, downloads `all` crops.
531            For multiple crops, provide a sequence of crop ids.
532        resolution: The choice of resolution in the original volumes.
533            By default, downloads the highest resolution: `s0`.
534        voxel_size: The choice of voxel size for the preprocessed crops to prepare the dataset.
535            By default, no voxel size filtering is applied and all crops in scope are used.
536        padding: The choice of padding along each dimension.
537            By default, it pads '64' pixels along all dimensions.
538            You can set it to '0' for no padding at all.
539            Pixel regions without annotations are labeled with id '-1' in the masks.
540        download: Whether to download the data if it is not present.
541        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
542
543    Returns:
544        The DataLoader.
545    """
546    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
547    dataset = get_cellmap_dataset(
548        path, patch_shape, organelles, crops, resolution, voxel_size, padding, download, **ds_kwargs
549    )
550    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
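
A note on the padding logic in `_download_cellmap_data` above: the EM crop is padded in voxel space, and each ground-truth crop is then padded back to the same shape with a fill value of -1, so that voxels without annotations can be identified (and masked) later. The following is a minimal numpy sketch of that padding step per axis, using made-up slice values:

import numpy as np

# Hypothetical 1d example of the padding performed in '_fetch_and_write_label'.
padding = 4
orig = slice(10, 20)                                               # voxel extent of the annotated crop
padded = slice(max(0, orig.start - padding), orig.stop + padding)  # padded EM extent

labels = np.ones(orig.stop - orig.start, dtype=np.int16)           # annotated voxels
labels_padded = np.pad(
    labels,
    pad_width=(orig.start - padded.start, padded.stop - orig.stop),
    mode="constant",
    constant_values=-1,  # voxels without annotations get the id '-1'
)
print(labels_padded.shape)  # (18,) -> 10 annotated voxels plus 4 padded voxels on each side
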
def get_cellmap_data( path: Union[os.PathLike, str], crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False) -> Tuple[str, List[str]]:

Downloads the CellMap training data.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing.
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide a sequence of crop ids.
  • resolution: The choice of resolution in the original volumes. By default, downloads the highest resolution: s0.
  • padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. Pixel regions without annotations are labeled with id '-1' in the masks.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is stored for further processing. List of crop ids.
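
A minimal usage sketch; the path, crop ids and padding below are placeholder values, not values prescribed by the dataset:

from torch_em.data.datasets.electron_microscopy.cellmap import get_cellmap_data

# Download two hypothetical crops with a padding of 32 voxels per dimension.
data_path, crop_ids = get_cellmap_data(
    path="./cellmap", crops=[234, 236], padding=32, download=True,
)
print(data_path)  # './cellmap/data_crops', containing one 'crop_<id>.h5' file per crop
print(crop_ids)   # the ids of the crops that were actually fetched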

def get_cellmap_paths( path: Union[os.PathLike, str], organelles: Union[List[str], str, NoneType] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', voxel_size: Optional[Tuple[float]] = None, padding: int = 64, download: bool = False, return_test_crops: bool = False) -> List[str]:

Get the paths to CellMap training data.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing
  • organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify either as 'mito' or as ['mito', 'cell'].
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide a sequence of crop ids.
  • resolution: The choice of resolution in the original volumes. By default, downloads the highest resolution: s0.
  • voxel_size: The choice of voxel size for the preprocessed crops to prepare the dataset. By default, no voxel size filtering is applied and all crops in scope are used.
  • padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. Pixel regions without annotations are labeled with id '-1' in the masks.
  • download: Whether to download the data if it is not present.
  • return_test_crops: Whether to forcefully return the filepaths of the test crops for other analysis.
Returns:

List of the cropped volume data paths.
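
A minimal usage sketch; the organelle name and voxel size are placeholder choices, and the HDF5 layout ('raw_crop' and 'label_crop/<name>') follows the download code above:

from elf.io import open_file
from torch_em.data.datasets.electron_microscopy.cellmap import get_cellmap_paths

# Keep only crops that contain a 'mito' label hierarchy and were stored at 8 nm voxels (hypothetical choice).
volume_paths = get_cellmap_paths(
    path="./cellmap", organelles="mito", voxel_size=(8.0, 8.0, 8.0), download=True,
)
for vpath in volume_paths:
    with open_file(vpath, "r") as f:
        print(vpath, f["raw_crop"].shape, list(f["label_crop"].keys()))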

def get_cellmap_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], organelles: Union[List[str], str, NoneType] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', voxel_size: Optional[Tuple[float]] = None, padding: int = 64, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:

Get the dataset for the CellMap training data for organelle segmentation.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify either as 'mito' or as ['mito', 'cell'].
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide a sequence of crop ids.
  • resolution: The choice of resolution in the original volumes. By default, downloads the highest resolution: s0.
  • voxel_size: The choice of voxel size for the preprocessed crops to prepare the dataset. By default, no voxel size filtering is applied and all crops in scope are used.
  • padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. Pixel regions without annotations are labeled with id '-1' in the masks.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
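
A minimal sketch, assuming hypothetical crop ids and the default 'all' label hierarchy:

from torch_em.data.datasets.electron_microscopy.cellmap import get_cellmap_dataset

dataset = get_cellmap_dataset(
    path="./cellmap",
    patch_shape=(32, 256, 256),   # (Z, Y, X) patch shape used for training
    crops=[234, 236],             # hypothetical crop ids
    download=True,
)
raw, labels = dataset[0]          # one raw patch and the matching label patch
print(raw.shape, labels.shape)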

def get_cellmap_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], organelles: Union[List[str], str, NoneType] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', voxel_size: Optional[Tuple[float]] = None, padding: int = 64, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:

Get the dataloader for the CellMap training data for organelle segmentation.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify either as 'mito' or as ['mito', 'cell'].
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide a sequence of crop ids.
  • resolution: The choice of resolution in the original volumes. By default, downloads the highest resolution: s0.
  • voxel_size: The choice of voxel size for the preprocessed crops to prepare the dataset. By default, no voxel size filtering is applied and all crops in scope are used.
  • padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. Pixel regions without annotations are labeled with id '-1' in the masks.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
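
A minimal sketch with hypothetical crop ids; DataLoader arguments such as 'num_workers' and 'shuffle' are split off via `util.split_kwargs` and forwarded to the PyTorch DataLoader:

from torch_em.data.datasets.electron_microscopy.cellmap import get_cellmap_loader

loader = get_cellmap_loader(
    path="./cellmap",
    batch_size=2,
    patch_shape=(32, 256, 256),
    organelles="all",
    crops=[234, 236],       # hypothetical crop ids
    download=True,
    num_workers=4,          # forwarded to the PyTorch DataLoader
    shuffle=True,
)
raw_batch, label_batch = next(iter(loader))
print(raw_batch.shape, label_batch.shape)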