torch_em.data.datasets.electron_microscopy.cellmap

CellMap is a dataset for segmenting various organelles in electron microscopy. It contains a large amount of annotation crops from several species. This dataset is released for the CellMap Segmentation Challenge: https://cellmapchallenge.janelia.org/.

Please cite them if you use this data for your research.

  1"""CellMap is a dataset for segmenting various organelles in electron microscopy.
  2It contains a large amount of annotation crops from several species.
  3This dataset is released for the `CellMap Segmentation Challenge`: https://cellmapchallenge.janelia.org/.
  4- Official documentation: https://janelia-cellmap.github.io/cellmap-segmentation-challenge/.
  5- Original GitHub repository for the toolbox: https://github.com/janelia-cellmap/cellmap-segmentation-challenge.
  6- And associated collection doi for the data: https://doi.org/10.25378/janelia.c.7456966.
  7
  8Please cite them if you use this data for your research.
  9"""
 10
 11import os
 12import time
 13from pathlib import Path
 14from threading import Lock
 15from typing import Union, Optional, Tuple, List, Sequence
 16from concurrent.futures import ThreadPoolExecutor, as_completed
 17
 18import h5py
 19import numpy as np
 20import pandas as pd
 21from xarray import DataArray
 22
 23from torch.utils.data import Dataset, DataLoader
 24
 25import torch_em
 26
 27from .. import util
 28
 29
 30def _download_cellmap_data(path, crops, resolution, padding, download=False):
 31    """Download scripts for the CellMap data.
 32    
 33    Inspired by https://github.com/janelia-cellmap/cellmap-segmentation-challenge/blob/main/src/cellmap_segmentation_challenge/cli/fetch_data.py
 34
 35    NOTE: The download scripts below are intended to stay as close to the original `fetch-data` CLI,
 36    in order to ensure easy syncing with any changes to the original repository in future.
 37    """  # noqa
 38
 39    # Importing packages locally.
 40    # NOTE: Keeping the relevant imports here to avoid `torch-em` throwing missing module error.
 41
 42    try:
 43        from cellmap_segmentation_challenge.utils.fetch_data import read_group, subset_to_slice
 44        from cellmap_segmentation_challenge.utils.crops import fetch_crop_manifest, get_test_crops, TestCropRow
 45    except ImportError:
 46        raise ModuleNotFoundError(
 47            "Please install 'cellmap_segmentation_challenge' package using "
 48            "'pip install git+https://github.com/janelia-cellmap/cellmap-segmentation-challenge.git'."
 49        )
 50
 51    # The imports below will come with the above lines of 'csc' installation.
 52    import structlog
 53    from xarray_ome_ngff import read_multiscale_group
 54    from xarray_ome_ngff.v04.multiscale import transforms_from_coords
 55
 56    # Some important stuff.
 57    fetch_save_start = time.time()
 58    log = structlog.get_logger()
 59    array_wrapper = {"name": "dask_array", "config": {"chunks": "auto"}}
 60
 61    # Get the absolute path location to store crops.
 62    dest_path_abs = Path(path).absolute()
 63    dest_path_abs.mkdir(exist_ok=True)
 64
 65    # Get the entire crop manifest.
 66    crops_from_manifest = fetch_crop_manifest()
 67
 68    # Get the desired crop info from the manifest.
 69    if crops == "all":
 70        crops_parsed = crops_from_manifest
 71    elif crops == "test":
 72        crops_parsed = get_test_crops()
 73        log.info(f"Found '{len(crops_parsed)}' test crops.")
 74    else:  # Otherwise, custom crops are parsed.
 75        crops_split = tuple(int(x) for x in crops.split(","))
 76        crops_parsed = tuple(filter(lambda v: v.id in crops_split, crops_from_manifest))
 77
 78    # Now get the crop ids.
 79    if len(crops_parsed) == 0:
 80        log.info(f"No crops found matching '{crops}'. Doing nothing.")
 81        return
 82
 83    crop_ids = tuple(c.id for c in crops_parsed)
 84    log.info(f"Preparing to copy the following crops: '{crop_ids}'.")
 85    log.info(f"Data will be saved to '{dest_path_abs}'.")
 86
 87    all_crops = []
 88    for crop in crops_parsed:
 89        log = log.bind(crop_id=crop.id, dataset=crop.dataset)
 90
 91        # Get the crop id to a new list for forwarding them ahead.
 92        all_crops.append(crop.id)
 93
 94        # Check whether the crop path has been downloaded already or not.
 95        crop_path = dest_path_abs / f"crop_{crop.id}.h5"
 96        if crop_path.exists():
 97            log.info(f"The crop '{crop.id}' is already saved at '{crop_path}'.")
 98            log = log.unbind("crop_id", "dataset")
 99            continue
100
101        # If 'download' is set to 'False', we do not go further from here.
102        if not download:
103            log.error(f"Cannot download the crop '{crop.id}' as 'download' is set to 'False'.")
104            return
105
106        # Check whether the crop is a part of the test crops, i.e. where GT masks is not available.
107        if isinstance(crop.gt_source, TestCropRow):
108            log.info(f"The test crop '{crop.id}' does not have GT data. Fetching em data only.")
109        else:
110            log.info(f"Fetching GT data for crop '{crop.id}' from '{crop.gt_source}'.")
111
112            # Get the ground-truth (gt) masks.
113            gt_source_group = read_group(str(crop.gt_source), storage_options={"anon": True})
114
115            log.info(f"Found GT data at '{crop.gt_source}'.")
116
117            # Let's get all ground-truth hierarchies.
118            # NOTE: Following same as the original repo, relying on fs.find to avoid slowness in traversing online zarr.
119            fs = gt_source_group.store.fs
120            store_path = gt_source_group.store.path
121            gt_files = fs.find(store_path)
122
123            crop_group_inventory = tuple(fn.split(store_path)[-1] for fn in gt_files)
124            crop_group_inventory = tuple(curr_cg[1:].split("/")[0] for curr_cg in crop_group_inventory)
125            crop_group_inventory = np.unique(crop_group_inventory).tolist()
126            crop_group_inventory = [
127                curr_cg for curr_cg in crop_group_inventory if curr_cg not in [".zattrs", ".zgroup"]
128            ]
129
130            # Get the offset values for the ground truth crops.
131            crop_multiscale_group = None
132            for _, group in gt_source_group.groups():
133                try:  # Get groups for all resolutions.
134                    crop_multiscale_group = read_multiscale_group(group, array_wrapper=array_wrapper)
135                    break
136                except (ValueError, TypeError):
137                    continue
138
139            if crop_multiscale_group is None:
140                log.info(f"No multiscale groups found in '{crop.gt_source}'. No EM data can be fetched.")
141                continue
142
143        # Get the EM volume group.
144        em_source_group = read_group(str(crop.em_url), storage_options={"anon": True})
145        log.info(f"Found EM data at '{crop.em_url}'.")
146
147        # Let's get the multiscale model of the source em group.
148        em_source_arrays = read_multiscale_group(em_source_group, array_wrapper)
149
150        # Next, we need to rely on the scales of each resolution to identify whether the resolution-level is same
151        # for the EM volume and corresponding ground-truth mask crops (if available).
152
153        # For this, we first extract the EM volume scales per resolution.
154        em_resolutions = {}
155        for res_key, array in em_source_arrays.items():
156            try:
157                _, (em_scale, em_translation) = transforms_from_coords(array.coords, transform_precision=4)
158                em_resolutions[res_key] = (em_scale.scale, em_translation.translation)
159            except Exception:
160                continue
161
162        if isinstance(crop.gt_source, TestCropRow):
163            # Choose the scale ratio threshold (from the original scripts)
164            ratio_threshold = 0.8  # NOTE: hard-coded atm to follow along the original data download code logic.
165
166            # Choose the matching resolution level with marked GT.
167            em_level = next(
168                (
169                    k for k, (scale, _) in em_resolutions.items()
170                    if all(s / vs > ratio_threshold for s, vs in zip(scale, crop.gt_source.voxel_size))
171                ), None
172            )
173
174            assert em_level is not None, "There has to be a scale match for the EM volume. Something went wrong."
175
176            scale = em_resolutions[em_level][0]
177            em_array = em_source_arrays[em_level]
178
179            # Get the slices (NOTE: there is info for some crop logic stuff)
180            starts = crop.gt_source.translation
181            stops = tuple(
182                start + size * vs for start, size, vs in zip(starts, crop.gt_source.shape, crop.gt_source.voxel_size)
183            )
184            coords = em_array.coords.copy()
185            for k, v in zip(em_array.coords.keys(), np.array((starts, stops)).T):
186                coords[k] = v
187
188            slices = subset_to_slice(outer_array=em_array, inner_array=DataArray(dims=em_array.dims, coords=coords))
189
190            # Set 'gt_level' to 'None' for better handling of crops without labels.
191            gt_level = None
192
193        else:
194            # Next, we extract the ground-truth scales per resolution (for labeled crops).
195            gt_resolutions = {}
196            for res_key, array in crop_multiscale_group.items():
197                try:
198                    _, (gt_scale, gt_translation) = transforms_from_coords(array.coords, transform_precision=4)
199                    gt_resolutions[res_key] = (gt_scale.scale, gt_translation.translation)
200                except Exception:
201                    continue
202
203            # Now, we find the matching scales and use the respoective "resolution" keys.
204            matching_keys = []
205            for gt_key, (gt_scale, gt_translation) in gt_resolutions.items():
206                for em_key, (em_scale, em_translation) in em_resolutions.items():
207                    if np.allclose(gt_scale, em_scale, rtol=1e-3, atol=1e-6):
208                        matching_keys.append((gt_key, em_key, gt_scale, gt_translation, em_translation))
209
210            # If no match found, that is pretty weird.
211            if not matching_keys:
212                log.error(f"No EM resolution level matches any GT scale for crop ID '{crop.id}'.")
213                continue
214
215            # We get the desired resolution level for the EM volume, labels, and the scale of choice.
216            matching_keys.sort(key=lambda x: np.prod(x[2]))
217            gt_level, em_level, scale, gt_translation, em_translation = matching_keys[0]
218
219            # Get the desired values for the particular resolution level.
220            em_array = em_source_arrays[em_level]
221            gt_crop_shape = gt_source_group[f"all/{gt_level}"].shape  # since "all" exists "al"ways, we rely on it.
222
223            log.info(f"Found a resolution match for EM data at level '{em_level}' and GT data at level '{gt_level}'.")
224
225            # Compute the input reference crop from the ground truth metadata.
226            starts = gt_translation
227            stops = [start + size * vs for start, size, vs in zip(starts, gt_crop_shape, scale)]
228
229            # Get the slices.
230            em_starts = [int(round((p_start - em_translation[i]) / scale[i])) for i, p_start in enumerate(starts)]
231            em_stops = [int(round((p_stop - em_translation[i]) / scale[i])) for i, p_stop in enumerate(stops)]
232            slices = tuple(slice(start, stop) for start, stop in zip(em_starts, em_stops))
233
234        # Pad the slices (in voxel space)
235        slices_padded = tuple(
236            slice(max(0, sl.start - padding), min(sl.stop + padding, dim), sl.step)
237            for sl, dim in zip(slices, em_array.shape)
238        )
239
240        # Extract cropped EM volume from remote zarr files.
241        em_crop = em_array[tuple(slices_padded)].data.compute()
242
243        # Write all stuff in a crop-level h5 file.
244        write_lock = Lock()
245        with h5py.File(crop_path, "w") as f:
246            # Store metadata
247            f.attrs["crop_id"] = crop.id
248            f.attrs["scale"] = scale
249            f.attrs["em_level"] = em_level
250
251            if gt_level is not None:
252                f.attrs["translation"] = gt_translation
253                f.attrs["gt_level"] = gt_level
254
255            # Store inputs.
256            f.create_dataset(name="raw_crop", data=em_crop, dtype=em_crop.dtype, compression="gzip")
257
258            def _fetch_and_write_label(label_name):
259                gt_crop = gt_source_group[f"{label_name}/{gt_level}"][:]
260
261                # Next, pad the labels to match the input shape.
262                def _pad_to_shape(array):
263                    return np.pad(
264                        array=array.astype(np.int16),
265                        pad_width=[
266                            (orig.start - padded.start, padded.stop - orig.stop)
267                            for orig, padded in zip(slices, slices_padded)
268                        ],
269                        mode="constant",
270                        constant_values=-1,
271                    )
272
273                gt_crop = _pad_to_shape(gt_crop)
274
275                # Write each label to their corresponding hierarchy names.
276                with write_lock:
277                    f.create_dataset(
278                        name=f"label_crop/{label_name}", data=gt_crop, dtype=gt_crop.dtype, compression="gzip"
279                    )
280                return label_name
281
282            if gt_level is not None:
283                with ThreadPoolExecutor() as pool:
284                    futures = {pool.submit(_fetch_and_write_label, name): name for name in crop_group_inventory}
285                    for future in as_completed(futures):
286                        label_name = future.result()
287                        log.info(f"Saved ground truth crop '{crop.id}' for '{label_name}'.")
288
289        log.info(f"Saved crop '{crop.id}' to '{crop_path}'.")
290        log = log.unbind("crop_id", "dataset")
291
292    log.info(f"Done after {time.time() - fetch_save_start:0.3f}s")
293    log.info(f"Data saved to '{dest_path_abs}'.")
294
295    return path, all_crops
296
297
def get_cellmap_data(
    path: Union[os.PathLike, str],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
) -> Tuple[str, List[str]]:
    """Downloads the CellMap training data.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing.
            The crops are stored in its 'data_crops' sub-folder.
        organelles: The choice of organelles to download. By default, loads all types of labels available.
            For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
        crops: The choice of crops to download. By default, downloads `all` crops.
            For multiple crops, provide the crop ids as a sequence of crop ids.
        resolution: The choice of resolution. By default, downloads the highest resolution: `s0`.
            NOTE: This value is forwarded to the download step, which currently does not use it.
        padding: The choice of padding along each dimension.
            By default, it pads '64' pixels along all dimensions.
            You can set it to '0' for no padding at all.
            For pixel regions without annotations, it labels the masks with id '-1'.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is stored for further processing.
        List of crop ids.

    Raises:
        RuntimeError: If the download step did not produce any data (see the messages logged by it).
        ValueError: If a chosen organelle is unknown, or if a selected crop does not contain it.
    """
    data_path = os.path.join(path, "data_crops")
    os.makedirs(data_path, exist_ok=True)

    # Get the crops in 'csc' desired format, i.e. a comma-separated string of crop ids.
    if isinstance(crops, Sequence) and not isinstance(crops, str):  # for multiple values
        crops = ",".join(str(c) for c in crops)

    # NOTE: The function below is comparable to the CLI `csc fetch-data` from the original repo.
    download_result = _download_cellmap_data(
        path=data_path,
        crops=crops,
        resolution=resolution,
        padding=padding,
        download=download,
    )

    # The download step signals failure with 'None' (nothing to unpack) or with a 'None' data path.
    _data_path, final_crops = download_result if download_result is not None else (None, [])
    if not _data_path:
        raise RuntimeError("Something went wrong. Please read the information logged above.")

    assert len(final_crops) > 0, "There seems to be no valid crops in the list."

    # Get the organelle-crop mapping.
    from cellmap_segmentation_challenge import utils

    # There is a file named 'train_crop_manifest' in the 'utils' sub-module. We need to get that first.
    train_metadata_file = Path(utils.__file__).parent / "train_crop_manifest.csv"
    train_metadata = pd.read_csv(train_metadata_file)

    # Let's get the label to crop mapping from the manifest file.
    organelle_to_crops = train_metadata.groupby("class_label")["crop_name"].apply(list).to_dict()

    # By default, 'organelles' set to 'None' will give you 'all' organelle types.
    if organelles is not None:  # The assumption here is that the user wants specific organelle(s).
        if isinstance(organelles, str):
            organelles = [organelles]

        for curr_organelle in organelles:
            if curr_organelle not in organelle_to_crops:  # Check whether the organelle is valid or not.
                raise ValueError(f"The chosen organelle: '{curr_organelle}' seems to be an invalid choice.")

            # Every selected crop must contain the organelle, otherwise an error is raised.
            # NOTE: The priority below is higher for organelles than crops.
            crops_with_organelle = set(organelle_to_crops[curr_organelle])
            for curr_crop in final_crops:
                if curr_crop not in crops_with_organelle:
                    raise ValueError(f"The crop '{curr_crop}' does not have the chosen organelle '{curr_organelle}'.")

    return data_path, final_crops
376
377
def get_cellmap_paths(
    path: Union[os.PathLike, str],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
    return_test_crops: bool = False,
) -> List[str]:
    """Get the paths to CellMap training data.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing.
        organelles: The choice of organelles to download. By default, loads all types of labels available.
            For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
        crops: The choice of crops to download. By default, downloads `all` crops.
            For multiple crops, provide the crop ids as a sequence of crop ids.
        resolution: The choice of resolution. By default, downloads the highest resolution: `s0`.
        padding: The choice of padding along each dimension.
            By default, it pads '64' pixels along all dimensions.
            You can set it to '0' for no padding at all.
            For pixel regions without annotations, it labels the masks with id '-1'.
        download: Whether to download the data if it is not present.
        return_test_crops: Whether to forcefully return the filepaths of the test crops for other analysis.

    Returns:
        List of the cropped volume data paths.

    Raises:
        NotImplementedError: If the 'test' crops are requested without `return_test_crops`.
        FileNotFoundError: If a crop volume is missing after the data has been fetched.
    """
    # 'crops' is either a single string (e.g. 'all' or 'test') or a collection of crop ids.
    if isinstance(crops, str):
        requests_test_crops = crops == "test"
    else:
        crops = list(crops)
        # Compare element-wise against strings only, so e.g. integer or numpy ids are handled safely.
        requests_test_crops = any(isinstance(c, str) and c == "test" for c in crops)

    if requests_test_crops and not return_test_crops:
        raise NotImplementedError("The 'test' crops cannot be used in the dataloader.")

    # Get the CellMap data crops.
    data_path, crop_ids = get_cellmap_data(
        path=path, organelles=organelles, crops=crops, resolution=resolution, padding=padding, download=download
    )

    volume_paths = [os.path.join(data_path, f"crop_{c}.h5") for c in crop_ids]

    # Check whether all volume paths exist.
    for volume_path in volume_paths:
        if not os.path.exists(volume_path):
            raise FileNotFoundError(f"The volume '{volume_path}' could not be found.")

    return volume_paths
424
425
def get_cellmap_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the dataset for the CellMap training data for organelle segmentation.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing.
        patch_shape: The patch shape to use for training.
        organelles: The choice of organelles to download. By default, loads all types of labels available.
            For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
        crops: The choice of crops to download. By default, downloads `all` crops.
            For multiple crops, provide the crop ids as a sequence of crop ids.
        resolution: The choice of resolution. By default, downloads the highest resolution: `s0`.
        padding: The choice of padding along each dimension.
            By default, it pads '64' pixels along all dimensions.
            You can set it to '0' for no padding at all.
            For pixel regions without annotations, it labels the masks with id '-1'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    volume_paths = get_cellmap_paths(
        path=path, organelles=organelles, crops=crops, resolution=resolution, padding=padding, download=download
    )

    # Labels live under 'label_crop/<name>' in each crop file; 'all' is used when no organelle is chosen.
    if organelles is None:
        label_key = "label_crop/all"
    elif isinstance(organelles, str):
        label_key = f"label_crop/{organelles}"
    else:
        # Several organelles: one label key per organelle, loaded with 'with_label_channels' enabled.
        label_key = [f"label_crop/{name}" for name in organelles]
        kwargs = util.update_kwargs(kwargs, "with_label_channels", True)

    # Raw data and labels are stored in the same h5 file per crop.
    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key="raw_crop",
        label_paths=volume_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
479
480
def get_cellmap_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the dataloader for the CellMap training data for organelle segmentation.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        organelles: The choice of organelles to download. By default, loads all types of labels available.
            For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
        crops: The choice of crops to download. By default, downloads `all` crops.
            For multiple crops, provide the crop ids as a sequence of crop ids.
        resolution: The choice of resolution. By default, downloads the highest resolution: `s0`.
        padding: The choice of padding along each dimension.
            By default, it pads '64' pixels along all dimensions.
            You can set it to '0' for no padding at all.
            For pixel regions without annotations, it labels the masks with id '-1'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route the keyword arguments accepted by the dataset factory there; the rest go to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    cellmap_dataset = get_cellmap_dataset(
        path=path,
        patch_shape=patch_shape,
        organelles=organelles,
        crops=crops,
        resolution=resolution,
        padding=padding,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(cellmap_dataset, batch_size, **dataloader_kwargs)
def get_cellmap_data( path: Union[os.PathLike, str], organelles: Union[List[str], str, NoneType] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False) -> Tuple[str, List[str]]:
def get_cellmap_data(
    path: Union[os.PathLike, str],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
) -> Tuple[str, List[str]]:
    """Downloads the CellMap training data.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing.
            The crops are stored in its 'data_crops' sub-folder.
        organelles: The choice of organelles to download. By default, loads all types of labels available.
            For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
        crops: The choice of crops to download. By default, downloads `all` crops.
            For multiple crops, provide the crop ids as a sequence of crop ids.
        resolution: The choice of resolution. By default, downloads the highest resolution: `s0`.
            NOTE: This value is forwarded to the download step, which currently does not use it.
        padding: The choice of padding along each dimension.
            By default, it pads '64' pixels along all dimensions.
            You can set it to '0' for no padding at all.
            For pixel regions without annotations, it labels the masks with id '-1'.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is stored for further processing.
        List of crop ids.

    Raises:
        RuntimeError: If the download step did not produce any data (see the messages logged by it).
        ValueError: If a chosen organelle is unknown, or if a selected crop does not contain it.
    """
    data_path = os.path.join(path, "data_crops")
    os.makedirs(data_path, exist_ok=True)

    # Get the crops in 'csc' desired format, i.e. a comma-separated string of crop ids.
    if isinstance(crops, Sequence) and not isinstance(crops, str):  # for multiple values
        crops = ",".join(str(c) for c in crops)

    # NOTE: The function below is comparable to the CLI `csc fetch-data` from the original repo.
    download_result = _download_cellmap_data(
        path=data_path,
        crops=crops,
        resolution=resolution,
        padding=padding,
        download=download,
    )

    # The download step signals failure with 'None' (nothing to unpack) or with a 'None' data path.
    _data_path, final_crops = download_result if download_result is not None else (None, [])
    if not _data_path:
        raise RuntimeError("Something went wrong. Please read the information logged above.")

    assert len(final_crops) > 0, "There seems to be no valid crops in the list."

    # Get the organelle-crop mapping.
    from cellmap_segmentation_challenge import utils

    # There is a file named 'train_crop_manifest' in the 'utils' sub-module. We need to get that first.
    train_metadata_file = Path(utils.__file__).parent / "train_crop_manifest.csv"
    train_metadata = pd.read_csv(train_metadata_file)

    # Let's get the label to crop mapping from the manifest file.
    organelle_to_crops = train_metadata.groupby("class_label")["crop_name"].apply(list).to_dict()

    # By default, 'organelles' set to 'None' will give you 'all' organelle types.
    if organelles is not None:  # The assumption here is that the user wants specific organelle(s).
        if isinstance(organelles, str):
            organelles = [organelles]

        for curr_organelle in organelles:
            if curr_organelle not in organelle_to_crops:  # Check whether the organelle is valid or not.
                raise ValueError(f"The chosen organelle: '{curr_organelle}' seems to be an invalid choice.")

            # Every selected crop must contain the organelle, otherwise an error is raised.
            # NOTE: The priority below is higher for organelles than crops.
            crops_with_organelle = set(organelle_to_crops[curr_organelle])
            for curr_crop in final_crops:
                if curr_crop not in crops_with_organelle:
                    raise ValueError(f"The crop '{curr_crop}' does not have the chosen organelle '{curr_organelle}'.")

    return data_path, final_crops

Downloads the CellMap training data.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing
  • organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify them either as 'mito' or as ['mito', 'cell'].
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide the crop ids as a sequence of crop ids.
  • resolution: The choice of resolution. By default, downloads the highest resolution: s0.
  • padding: The padding along each dimension. By default, 64 pixels are padded along all dimensions. Set it to 0 for no padding at all. Pixels in the padded region, which have no annotations, are labeled with id '-1'.
  • download: Whether to download the data if it is not present.
Returns:

The filepath where the data is stored for further processing, and the list of crop ids.

def get_cellmap_paths( path: Union[os.PathLike, str], organelles: Union[List[str], str, NoneType] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False, return_test_crops: bool = False) -> List[str]:
def get_cellmap_paths(
    path: Union[os.PathLike, str],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
    return_test_crops: bool = False,
) -> List[str]:
    """Get the paths to CellMap training data.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing
        organelles: The choice of organelles to download. By default, loads all types of labels available.
            For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
        crops: The choice of crops to download. By default, downloads `all` crops.
            For multiple crops, provide the crop ids as a sequence of crop ids
            (e.g. a list, a tuple, a numpy array or any other iterable of crop ids).
        resolution: The choice of resolution. By default, downloads the highest resolution: `s0`.
        padding: The choice of padding along each dimension.
            By default, it pads '64' pixels along all dimensions.
            You can set it to '0' for no padding at all.
            For pixel regions without annotations, it labels the masks with id '-1'.
        download: Whether to download the data if it is not present.
        return_test_crops: Whether to forcefully return the filepaths of the test crops for other analysis.

    Returns:
        List of the cropped volume data paths.

    Raises:
        NotImplementedError: If the 'test' crops are requested while `return_test_crops` is False.
        FileNotFoundError: If any of the expected cropped volumes does not exist after fetching the data.
    """
    from collections.abc import Iterable

    # Decide whether the 'test' crops are requested.
    # A string (or any non-iterable value) is a single choice and is compared directly.
    # Any other iterable is checked by membership. Iterables that are not a list / tuple
    # (numpy arrays, pandas Series, generators, dict views, ...) are materialized into a list first.
    # This avoids an ambiguous elementwise `==` on array-likes, and keeps one-shot iterators
    # from being exhausted before they are forwarded below.
    if isinstance(crops, str) or not isinstance(crops, Iterable):
        requests_test_crops = bool(crops == "test")
    else:
        if not isinstance(crops, (list, tuple)):
            crops = list(crops)
        requests_test_crops = "test" in crops

    if requests_test_crops and not return_test_crops:
        raise NotImplementedError("The 'test' crops cannot be used in the dataloader.")

    # Get the CellMap data crops.
    data_path, crops = get_cellmap_data(
        path=path, organelles=organelles, crops=crops, resolution=resolution, padding=padding, download=download
    )

    # Get all crops.
    volume_paths = [os.path.join(data_path, f"crop_{c}.h5") for c in crops]

    # Check whether all volume paths exist.
    for volume_path in volume_paths:
        if not os.path.exists(volume_path):
            raise FileNotFoundError(f"The volume '{volume_path}' could not be found.")

    return volume_paths

Get the paths to CellMap training data.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing
  • organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide the crop ids as a sequence of crop ids.
  • resolution: The choice of resolution. By default, downloads the highest resolution: s0.
  • padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. For pixel regions without annotations, it labels the masks with id '-1'.
  • download: Whether to download the data if it is not present.
  • return_test_crops: Whether to forcefully return the filepaths of the test crops for other analysis.
Returns:

List of the cropped volume data paths.

def get_cellmap_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], organelles: Union[List[str], str, NoneType] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_cellmap_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the dataset for the CellMap training data for organelle segmentation.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing.
        patch_shape: The patch shape to use for training.
        organelles: The organelle(s) whose labels are loaded. By default, all available label types are loaded.
            Pass a single name like 'mito', or several names like ['mito', 'cell'].
        crops: The crops to download. By default, `all` crops are downloaded.
            For multiple crops, provide the crop ids as a sequence of crop ids.
        resolution: The resolution to download. By default, the highest resolution: `s0`.
        padding: The padding along each dimension. By default, '64' pixels are padded along all dimensions.
            Set it to '0' for no padding at all.
            For pixel regions without annotations, the masks are labeled with id '-1'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    volume_paths = get_cellmap_paths(
        path=path, organelles=organelles, crops=crops, resolution=resolution, padding=padding, download=download
    )

    # Translate the organelle choice(s) into the label key(s) inside each crop file.
    if organelles is None:
        label_key = "label_crop/all"
    elif isinstance(organelles, str):
        label_key = f"label_crop/{organelles}"
    else:
        label_key = [f"label_crop/{organelle_name}" for organelle_name in organelles]
        # Several label keys are requested, so enable `with_label_channels` for the dataset.
        kwargs = util.update_kwargs(kwargs, "with_label_channels", True)

    # The raw data and the labels live in the same files, under different keys.
    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        label_paths=volume_paths,
        raw_key="raw_crop",
        label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

Get the dataset for the CellMap training data for organelle segmentation.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide the crop ids as a sequence of crop ids.
  • resolution: The choice of resolution. By default, downloads the highest resolution: s0.
  • padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. For pixel regions without annotations, it labels the masks with id '-1'.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_cellmap_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], organelles: Union[List[str], str, NoneType] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_cellmap_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    organelles: Optional[Union[str, List[str]]] = None,
    crops: Union[str, Sequence[str]] = "all",
    resolution: str = "s0",
    padding: int = 64,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the dataloader for the CellMap training data for organelle segmentation.

    Args:
        path: Filepath to a folder where the data will be downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        organelles: The organelle(s) whose labels are loaded. By default, all available label types are loaded.
            Pass a single name like 'mito', or several names like ['mito', 'cell'].
        crops: The crops to download. By default, `all` crops are downloaded.
            For multiple crops, provide the crop ids as a sequence of crop ids.
        resolution: The resolution to download. By default, the highest resolution: `s0`.
        padding: The padding along each dimension. By default, '64' pixels are padded along all dimensions.
            Set it to '0' for no padding at all.
            For pixel regions without annotations, the masks are labeled with id '-1'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each extra keyword either to the dataset constructor or to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    cellmap_dataset = get_cellmap_dataset(
        path=path,
        patch_shape=patch_shape,
        organelles=organelles,
        crops=crops,
        resolution=resolution,
        padding=padding,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(cellmap_dataset, batch_size, **dataloader_kwargs)

Get the dataloader for the CellMap training data for organelle segmentation.

Arguments:
  • path: Filepath to a folder where the data will be downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify either like 'mito' or ['mito', 'cell'].
  • crops: The choice of crops to download. By default, downloads all crops. For multiple crops, provide the crop ids as a sequence of crop ids.
  • resolution: The choice of resolution. By default, downloads the highest resolution: s0.
  • padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. For pixel regions without annotations, it labels the masks with id '-1'.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.