torch_em.data.datasets.electron_microscopy.cellmap
CellMap is a dataset for segmenting various organelles in electron microscopy.
It contains a large number of annotated crops from several species.
This dataset is released for the CellMap Segmentation Challenge: https://cellmapchallenge.janelia.org/.
- Official documentation: https://janelia-cellmap.github.io/cellmap-segmentation-challenge/.
- Original GitHub repository for the toolbox: https://github.com/janelia-cellmap/cellmap-segmentation-challenge.
- Associated collection DOI for the data: https://doi.org/10.25378/janelia.c.7456966.
Please cite them if you use this data for your research.
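As a quick orientation, a minimal usage sketch follows; the target folder, patch shape and batch size are illustrative values rather than requirements of the dataset, and it assumes the loader yields (raw, label) batches as torch_em's default segmentation loaders do.

from torch_em.data.datasets.electron_microscopy.cellmap import get_cellmap_loader

# Download (if missing) all CellMap crops containing mitochondria and build a 3D loader.
# "data/cellmap", the patch shape and the batch size are hypothetical example values.
loader = get_cellmap_loader(
    path="data/cellmap",
    batch_size=2,
    patch_shape=(32, 256, 256),
    organelles="mito",
    crops="all",
    download=True,
)
for raw, labels in loader:  # assumed (raw, label) batch structure
    print(raw.shape, labels.shape)
    break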
1"""CellMap is a dataset for segmenting various organelles in electron microscopy. 2It contains a large amount of annotation crops from several species. 3This dataset is released for the `CellMap Segmentation Challenge`: https://cellmapchallenge.janelia.org/. 4- Official documentation: https://janelia-cellmap.github.io/cellmap-segmentation-challenge/. 5- Original GitHub repository for the toolbox: https://github.com/janelia-cellmap/cellmap-segmentation-challenge. 6- And associated collection doi for the data: https://doi.org/10.25378/janelia.c.7456966. 7 8Please cite them if you use this data for your research. 9""" 10 11import os 12import time 13from pathlib import Path 14from threading import Lock 15from typing import Union, Optional, Tuple, List, Sequence 16from concurrent.futures import ThreadPoolExecutor, as_completed 17 18import h5py 19import numpy as np 20import pandas as pd 21from xarray import DataArray 22 23from torch.utils.data import Dataset, DataLoader 24 25import torch_em 26 27from .. import util 28 29 30def _download_cellmap_data(path, crops, resolution, padding, download=False): 31 """Download scripts for the CellMap data. 32 33 Inspired by https://github.com/janelia-cellmap/cellmap-segmentation-challenge/blob/main/src/cellmap_segmentation_challenge/cli/fetch_data.py 34 35 NOTE: The download scripts below are intended to stay as close to the original `fetch-data` CLI, 36 in order to ensure easy syncing with any changes to the original repository in future. 37 """ # noqa 38 39 # Importing packages locally. 40 # NOTE: Keeping the relevant imports here to avoid `torch-em` throwing missing module error. 41 42 try: 43 from cellmap_segmentation_challenge.utils.fetch_data import read_group, subset_to_slice 44 from cellmap_segmentation_challenge.utils.crops import fetch_crop_manifest, get_test_crops, TestCropRow 45 except ImportError: 46 raise ModuleNotFoundError( 47 "Please install 'cellmap_segmentation_challenge' package using " 48 "'pip install git+https://github.com/janelia-cellmap/cellmap-segmentation-challenge.git'." 49 ) 50 51 # The imports below will come with the above lines of 'csc' installation. 52 import structlog 53 from xarray_ome_ngff import read_multiscale_group 54 from xarray_ome_ngff.v04.multiscale import transforms_from_coords 55 56 # Some important stuff. 57 fetch_save_start = time.time() 58 log = structlog.get_logger() 59 array_wrapper = {"name": "dask_array", "config": {"chunks": "auto"}} 60 61 # Get the absolute path location to store crops. 62 dest_path_abs = Path(path).absolute() 63 dest_path_abs.mkdir(exist_ok=True) 64 65 # Get the entire crop manifest. 66 crops_from_manifest = fetch_crop_manifest() 67 68 # Get the desired crop info from the manifest. 69 if crops == "all": 70 crops_parsed = crops_from_manifest 71 elif crops == "test": 72 crops_parsed = get_test_crops() 73 log.info(f"Found '{len(crops_parsed)}' test crops.") 74 else: # Otherwise, custom crops are parsed. 75 crops_split = tuple(int(x) for x in crops.split(",")) 76 crops_parsed = tuple(filter(lambda v: v.id in crops_split, crops_from_manifest)) 77 78 # Now get the crop ids. 79 if len(crops_parsed) == 0: 80 log.info(f"No crops found matching '{crops}'. 
Doing nothing.") 81 return 82 83 crop_ids = tuple(c.id for c in crops_parsed) 84 log.info(f"Preparing to copy the following crops: '{crop_ids}'.") 85 log.info(f"Data will be saved to '{dest_path_abs}'.") 86 87 all_crops = [] 88 for crop in crops_parsed: 89 log = log.bind(crop_id=crop.id, dataset=crop.dataset) 90 91 # Get the crop id to a new list for forwarding them ahead. 92 all_crops.append(crop.id) 93 94 # Check whether the crop path has been downloaded already or not. 95 crop_path = dest_path_abs / f"crop_{crop.id}.h5" 96 if crop_path.exists(): 97 log.info(f"The crop '{crop.id}' is already saved at '{crop_path}'.") 98 log = log.unbind("crop_id", "dataset") 99 continue 100 101 # If 'download' is set to 'False', we do not go further from here. 102 if not download: 103 log.error(f"Cannot download the crop '{crop.id}' as 'download' is set to 'False'.") 104 return 105 106 # Check whether the crop is a part of the test crops, i.e. where GT masks is not available. 107 if isinstance(crop.gt_source, TestCropRow): 108 log.info(f"The test crop '{crop.id}' does not have GT data. Fetching em data only.") 109 else: 110 log.info(f"Fetching GT data for crop '{crop.id}' from '{crop.gt_source}'.") 111 112 # Get the ground-truth (gt) masks. 113 gt_source_group = read_group(str(crop.gt_source), storage_options={"anon": True}) 114 115 log.info(f"Found GT data at '{crop.gt_source}'.") 116 117 # Let's get all ground-truth hierarchies. 118 # NOTE: Following same as the original repo, relying on fs.find to avoid slowness in traversing online zarr. 119 fs = gt_source_group.store.fs 120 store_path = gt_source_group.store.path 121 gt_files = fs.find(store_path) 122 123 crop_group_inventory = tuple(fn.split(store_path)[-1] for fn in gt_files) 124 crop_group_inventory = tuple(curr_cg[1:].split("/")[0] for curr_cg in crop_group_inventory) 125 crop_group_inventory = np.unique(crop_group_inventory).tolist() 126 crop_group_inventory = [ 127 curr_cg for curr_cg in crop_group_inventory if curr_cg not in [".zattrs", ".zgroup"] 128 ] 129 130 # Get the offset values for the ground truth crops. 131 crop_multiscale_group = None 132 for _, group in gt_source_group.groups(): 133 try: # Get groups for all resolutions. 134 crop_multiscale_group = read_multiscale_group(group, array_wrapper=array_wrapper) 135 break 136 except (ValueError, TypeError): 137 continue 138 139 if crop_multiscale_group is None: 140 log.info(f"No multiscale groups found in '{crop.gt_source}'. No EM data can be fetched.") 141 continue 142 143 # Get the EM volume group. 144 em_source_group = read_group(str(crop.em_url), storage_options={"anon": True}) 145 log.info(f"Found EM data at '{crop.em_url}'.") 146 147 # Let's get the multiscale model of the source em group. 148 em_source_arrays = read_multiscale_group(em_source_group, array_wrapper) 149 150 # Next, we need to rely on the scales of each resolution to identify whether the resolution-level is same 151 # for the EM volume and corresponding ground-truth mask crops (if available). 152 153 # For this, we first extract the EM volume scales per resolution. 
154 em_resolutions = {} 155 for res_key, array in em_source_arrays.items(): 156 try: 157 _, (em_scale, em_translation) = transforms_from_coords(array.coords, transform_precision=4) 158 em_resolutions[res_key] = (em_scale.scale, em_translation.translation) 159 except Exception: 160 continue 161 162 if isinstance(crop.gt_source, TestCropRow): 163 # Choose the scale ratio threshold (from the original scripts) 164 ratio_threshold = 0.8 # NOTE: hard-coded atm to follow along the original data download code logic. 165 166 # Choose the matching resolution level with marked GT. 167 em_level = next( 168 ( 169 k for k, (scale, _) in em_resolutions.items() 170 if all(s / vs > ratio_threshold for s, vs in zip(scale, crop.gt_source.voxel_size)) 171 ), None 172 ) 173 174 assert em_level is not None, "There has to be a scale match for the EM volume. Something went wrong." 175 176 scale = em_resolutions[em_level][0] 177 em_array = em_source_arrays[em_level] 178 179 # Get the slices (NOTE: there is info for some crop logic stuff) 180 starts = crop.gt_source.translation 181 stops = tuple( 182 start + size * vs for start, size, vs in zip(starts, crop.gt_source.shape, crop.gt_source.voxel_size) 183 ) 184 coords = em_array.coords.copy() 185 for k, v in zip(em_array.coords.keys(), np.array((starts, stops)).T): 186 coords[k] = v 187 188 slices = subset_to_slice(outer_array=em_array, inner_array=DataArray(dims=em_array.dims, coords=coords)) 189 190 # Set 'gt_level' to 'None' for better handling of crops without labels. 191 gt_level = None 192 193 else: 194 # Next, we extract the ground-truth scales per resolution (for labeled crops). 195 gt_resolutions = {} 196 for res_key, array in crop_multiscale_group.items(): 197 try: 198 _, (gt_scale, gt_translation) = transforms_from_coords(array.coords, transform_precision=4) 199 gt_resolutions[res_key] = (gt_scale.scale, gt_translation.translation) 200 except Exception: 201 continue 202 203 # Now, we find the matching scales and use the respoective "resolution" keys. 204 matching_keys = [] 205 for gt_key, (gt_scale, gt_translation) in gt_resolutions.items(): 206 for em_key, (em_scale, em_translation) in em_resolutions.items(): 207 if np.allclose(gt_scale, em_scale, rtol=1e-3, atol=1e-6): 208 matching_keys.append((gt_key, em_key, gt_scale, gt_translation, em_translation)) 209 210 # If no match found, that is pretty weird. 211 if not matching_keys: 212 log.error(f"No EM resolution level matches any GT scale for crop ID '{crop.id}'.") 213 continue 214 215 # We get the desired resolution level for the EM volume, labels, and the scale of choice. 216 matching_keys.sort(key=lambda x: np.prod(x[2])) 217 gt_level, em_level, scale, gt_translation, em_translation = matching_keys[0] 218 219 # Get the desired values for the particular resolution level. 220 em_array = em_source_arrays[em_level] 221 gt_crop_shape = gt_source_group[f"all/{gt_level}"].shape # since "all" exists "al"ways, we rely on it. 222 223 log.info(f"Found a resolution match for EM data at level '{em_level}' and GT data at level '{gt_level}'.") 224 225 # Compute the input reference crop from the ground truth metadata. 226 starts = gt_translation 227 stops = [start + size * vs for start, size, vs in zip(starts, gt_crop_shape, scale)] 228 229 # Get the slices. 
230 em_starts = [int(round((p_start - em_translation[i]) / scale[i])) for i, p_start in enumerate(starts)] 231 em_stops = [int(round((p_stop - em_translation[i]) / scale[i])) for i, p_stop in enumerate(stops)] 232 slices = tuple(slice(start, stop) for start, stop in zip(em_starts, em_stops)) 233 234 # Pad the slices (in voxel space) 235 slices_padded = tuple( 236 slice(max(0, sl.start - padding), min(sl.stop + padding, dim), sl.step) 237 for sl, dim in zip(slices, em_array.shape) 238 ) 239 240 # Extract cropped EM volume from remote zarr files. 241 em_crop = em_array[tuple(slices_padded)].data.compute() 242 243 # Write all stuff in a crop-level h5 file. 244 write_lock = Lock() 245 with h5py.File(crop_path, "w") as f: 246 # Store metadata 247 f.attrs["crop_id"] = crop.id 248 f.attrs["scale"] = scale 249 f.attrs["em_level"] = em_level 250 251 if gt_level is not None: 252 f.attrs["translation"] = gt_translation 253 f.attrs["gt_level"] = gt_level 254 255 # Store inputs. 256 f.create_dataset(name="raw_crop", data=em_crop, dtype=em_crop.dtype, compression="gzip") 257 258 def _fetch_and_write_label(label_name): 259 gt_crop = gt_source_group[f"{label_name}/{gt_level}"][:] 260 261 # Next, pad the labels to match the input shape. 262 def _pad_to_shape(array): 263 return np.pad( 264 array=array.astype(np.int16), 265 pad_width=[ 266 (orig.start - padded.start, padded.stop - orig.stop) 267 for orig, padded in zip(slices, slices_padded) 268 ], 269 mode="constant", 270 constant_values=-1, 271 ) 272 273 gt_crop = _pad_to_shape(gt_crop) 274 275 # Write each label to their corresponding hierarchy names. 276 with write_lock: 277 f.create_dataset( 278 name=f"label_crop/{label_name}", data=gt_crop, dtype=gt_crop.dtype, compression="gzip" 279 ) 280 return label_name 281 282 if gt_level is not None: 283 with ThreadPoolExecutor() as pool: 284 futures = {pool.submit(_fetch_and_write_label, name): name for name in crop_group_inventory} 285 for future in as_completed(futures): 286 label_name = future.result() 287 log.info(f"Saved ground truth crop '{crop.id}' for '{label_name}'.") 288 289 log.info(f"Saved crop '{crop.id}' to '{crop_path}'.") 290 log = log.unbind("crop_id", "dataset") 291 292 log.info(f"Done after {time.time() - fetch_save_start:0.3f}s") 293 log.info(f"Data saved to '{dest_path_abs}'.") 294 295 return path, all_crops 296 297 298def get_cellmap_data( 299 path: Union[os.PathLike, str], 300 organelles: Optional[Union[str, List[str]]] = None, 301 crops: Union[str, Sequence[str]] = "all", 302 resolution: str = "s0", 303 padding: int = 64, 304 download: bool = False, 305) -> Tuple[str, List[str]]: 306 """Downloads the CellMap training data. 307 308 Args: 309 path: Filepath to a folder where the data will be downloaded for further processing 310 organelles: The choice of organelles to download. By default, loads all types of labels available. 311 For one for multiple organelles, specify either like 'mito' or ['mito', 'cell']. 312 crops: The choice of crops to download. By default, downloads `all` crops. 313 For multiple crops, provide the crop ids as a sequence of crop ids. 314 resolution: The choice of resolution. By default, downloads the highest resolution: `s0`. 315 padding: The choice of padding along each dimensions. 316 By default, it pads '64' pixels along all dimensions. 317 You can set it to '0' for no padding at all. 318 For pixel regions without annotations, it labels the masks with id '-1'. 319 download: Whether to download the data if it is not present. 
320 321 Returns: 322 Filepath where the data is stored for further processing. 323 List of crop ids. 324 """ 325 326 data_path = os.path.join(path, "data_crops") 327 os.makedirs(data_path, exist_ok=True) 328 329 # Get the crops in 'csc' desired format. 330 if isinstance(crops, Sequence) and not isinstance(crops, str): # for multiple values 331 crops = ",".join(str(c) for c in crops) 332 333 # NOTE: The function below is comparable to the CLI `csc fetch-data` from the original repo. 334 _data_path, final_crops = _download_cellmap_data( 335 path=data_path, 336 crops=crops, 337 resolution=resolution, 338 padding=padding, 339 download=download, 340 ) 341 342 # Get the organelle-crop mapping. 343 from cellmap_segmentation_challenge import utils 344 345 # There is a file named 'train_crop_manifest' in the 'utils' sub-module. We need to get that first 346 train_metadata_file = os.path.join(str(Path(utils.__file__).parent / "train_crop_manifest.csv")) 347 train_metadata = pd.read_csv(train_metadata_file) 348 349 # Let's get the label to crop mapping from the manifest file. 350 organelle_to_crops = train_metadata.groupby('class_label')['crop_name'].apply(list).to_dict() 351 352 # By default, 'organelles' set to 'None' will give you 'all' organelle types. 353 if organelles is not None: # The assumption here is that the user wants specific organelle(s). 354 # Validate whether the organelle exists in the desired crops at all. 355 if isinstance(organelles, str): 356 organelles = [organelles] 357 358 # Next, we check whether they match the crops. 359 for curr_organelle in organelles: 360 if curr_organelle not in organelle_to_crops: # Check whether the organelle is valid or not. 361 raise ValueError(f"The chosen organelle: '{curr_organelle}' seems to be an invalid choice.") 362 363 # Lastly, we check whether the final crops have the organelle(s) or not. 364 # Otherwise, we throw a warning and go ahead with the true valid choices. 365 # NOTE: The priority below is higher for organelles than crops. 366 for curr_crop in final_crops: 367 if curr_crop not in organelle_to_crops.get(curr_organelle): 368 raise ValueError(f"The crop '{curr_crop}' does not have the chosen organelle '{curr_organelle}'.") 369 370 if _data_path is None or len(_data_path) == 0: 371 raise RuntimeError("Something went wrong. Please read the information logged above.") 372 373 assert len(final_crops) > 0, "There seems to be no valid crops in the list." 374 375 return data_path, final_crops 376 377 378def get_cellmap_paths( 379 path: Union[os.PathLike, str], 380 organelles: Optional[Union[str, List[str]]] = None, 381 crops: Union[str, Sequence[str]] = "all", 382 resolution: str = "s0", 383 padding: int = 64, 384 download: bool = False, 385 return_test_crops: bool = False, 386) -> List[str]: 387 """Get the paths to CellMap training data. 388 389 Args: 390 path: Filepath to a folder where the data will be downloaded for further processing 391 organelles: The choice of organelles to download. By default, loads all types of labels available. 392 For one for multiple organelles, specify either like 'mito' or ['mito', 'cell']. 393 crops: The choice of crops to download. By default, downloads `all` crops. 394 For multiple crops, provide the crop ids as a sequence of crop ids. 395 resolution: The choice of resolution. By default, downloads the highest resolution: `s0`. 396 padding: The choice of padding along each dimensions. 397 By default, it pads '64' pixels along all dimensions. 398 You can set it to '0' for no padding at all. 
399 For pixel regions without annotations, it labels the masks with id '-1'. 400 download: Whether to download the data if it is not present. 401 return_test_crops: Whether to forcefully return the filepaths of the test crops for other analysis. 402 403 Returns: 404 List of the cropped volume data paths. 405 """ 406 407 if not return_test_crops and ("test" in crops if isinstance(crops, (List, Tuple)) else crops == "test"): 408 raise NotImplementedError("The 'test' crops cannot be used in the dataloader.") 409 410 # Get the CellMap data crops. 411 data_path, crops = get_cellmap_data( 412 path=path, organelles=organelles, crops=crops, resolution=resolution, padding=padding, download=download 413 ) 414 415 # Get all crops. 416 volume_paths = [os.path.join(data_path, f"crop_{c}.h5") for c in crops] 417 418 # Check whether all volume paths exist. 419 for volume_path in volume_paths: 420 if not os.path.exists(volume_path): 421 raise FileNotFoundError(f"The volume '{volume_path}' could not be found.") 422 423 return volume_paths 424 425 426def get_cellmap_dataset( 427 path: Union[os.PathLike, str], 428 patch_shape: Tuple[int, ...], 429 organelles: Optional[Union[str, List[str]]] = None, 430 crops: Union[str, Sequence[str]] = "all", 431 resolution: str = "s0", 432 padding: int = 64, 433 download: bool = False, 434 **kwargs, 435) -> Dataset: 436 """Get the dataset for the CellMap training data for organelle segmentation. 437 438 Args: 439 path: Filepath to a folder where the data will be downloaded for further processing. 440 patch_shape: The patch shape to use for training. 441 organelles: The choice of organelles to download. By default, loads all types of labels available. 442 For one for multiple organelles, specify either like 'mito' or ['mito', 'cell']. 443 crops: The choice of crops to download. By default, downloads `all` crops. 444 For multiple crops, provide the crop ids as a sequence of crop ids. 445 resolution: The choice of resolution. By default, downloads the highest resolution: `s0`. 446 padding: The choice of padding along each dimensions. 447 By default, it pads '64' pixels along all dimensions. 448 You can set it to '0' for no padding at all. 449 For pixel regions without annotations, it labels the masks with id '-1'. 450 download: Whether to download the data if it is not present. 451 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 452 453 Returns: 454 The segmentation dataset. 455 """ 456 volume_paths = get_cellmap_paths( 457 path=path, organelles=organelles, crops=crops, resolution=resolution, padding=padding, download=download 458 ) 459 460 # Arrange the organelle choices as expected for loading labels. 
461 if organelles is None: 462 organelles = "label_crop/all" 463 else: 464 if isinstance(organelles, str): 465 organelles = f"label_crop/{organelles}" 466 else: 467 organelles = [f"label_crop/{curr_organelle}" for curr_organelle in organelles] 468 kwargs = util.update_kwargs(kwargs, "with_label_channels", True) 469 470 return torch_em.default_segmentation_dataset( 471 raw_paths=volume_paths, 472 raw_key="raw_crop", 473 label_paths=volume_paths, 474 label_key=organelles, 475 patch_shape=patch_shape, 476 is_seg_dataset=True, 477 **kwargs 478 ) 479 480 481def get_cellmap_loader( 482 path: Union[os.PathLike, str], 483 batch_size: int, 484 patch_shape: Tuple[int, ...], 485 organelles: Optional[Union[str, List[str]]] = None, 486 crops: Union[str, Sequence[str]] = "all", 487 resolution: str = "s0", 488 padding: int = 64, 489 download: bool = False, 490 **kwargs, 491) -> DataLoader: 492 """Get the dataloader for the CellMap training data for organelle segmentation. 493 494 Args: 495 path: Filepath to a folder where the data will be downloaded for further processing. 496 batch_size: The batch size for training. 497 patch_shape: The patch shape to use for training. 498 organelles: The choice of organelles to download. By default, loads all types of labels available. 499 For one for multiple organelles, specify either like 'mito' or ['mito', 'cell']. 500 crops: The choice of crops to download. By default, downloads `all` crops. 501 For multiple crops, provide the crop ids as a sequence of crop ids. 502 resolution: The choice of resolution. By default, downloads the highest resolution: `s0`. 503 padding: The choice of padding along each dimensions. 504 By default, it pads '64' pixels along all dimensions. 505 You can set it to '0' for no padding at all. 506 For pixel regions without annotations, it labels the masks with id '-1'. 507 download: Whether to download the data if it is not present. 508 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 509 510 Returns: 511 The DataLoader. 512 """ 513 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 514 dataset = get_cellmap_dataset(path, patch_shape, organelles, crops, resolution, padding, download, **ds_kwargs) 515 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
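Based on what _download_cellmap_data writes above, a downloaded crop file can be inspected with h5py roughly as follows; the crop id in the path is a placeholder, and the exact attributes depend on whether the crop has ground-truth labels.

import h5py

# "data/cellmap/data_crops/crop_234.h5" is a placeholder path; substitute a crop id you downloaded.
with h5py.File("data/cellmap/data_crops/crop_234.h5", "r") as f:
    print(dict(f.attrs))          # crop_id, scale, em_level (plus translation and gt_level for labeled crops)
    print(f["raw_crop"].shape)    # padded EM crop
    print(list(f["label_crop"]))  # one dataset per annotated class, e.g. "all" or "mito", padded with -1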
def get_cellmap_data(path: Union[os.PathLike, str], organelles: Optional[Union[str, List[str]]] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False) -> Tuple[str, List[str]]:
Downloads the CellMap training data.
Arguments:
- path: Filepath to a folder where the data will be downloaded for further processing
- organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify e.g. 'mito' or ['mito', 'cell'].
- crops: The choice of crops to download. By default, downloads 'all' crops. For multiple crops, provide the crop ids as a sequence.
- resolution: The choice of resolution. By default, downloads the highest resolution: 's0'.
- padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. For pixel regions without annotations, it labels the masks with id '-1'.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is stored for further processing. List of crop ids.
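A sketch of an example call is shown below; the crop ids are placeholders and must exist in the challenge's crop manifest.

# Download two specific crops (placeholder ids) that contain mitochondria annotations.
data_path, crop_ids = get_cellmap_data(
    path="data/cellmap", organelles="mito", crops=["234", "236"], download=True
)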
def get_cellmap_paths(path: Union[os.PathLike, str], organelles: Optional[Union[str, List[str]]] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False, return_test_crops: bool = False) -> List[str]:
Get the paths to CellMap training data.
Arguments:
- path: Filepath to a folder where the data will be downloaded for further processing
- organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify e.g. 'mito' or ['mito', 'cell'].
- crops: The choice of crops to download. By default, downloads 'all' crops. For multiple crops, provide the crop ids as a sequence.
- resolution: The choice of resolution. By default, downloads the highest resolution: 's0'.
- padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. For pixel regions without annotations, it labels the masks with id '-1'.
- download: Whether to download the data if it is not present.
- return_test_crops: Whether to forcefully return the filepaths of the test crops for other analysis.
Returns:
List of the cropped volume data paths.
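For example (a sketch with an illustrative target folder):

# Download (if needed) and return one HDF5 path per crop, stored under "<path>/data_crops/crop_<id>.h5".
volume_paths = get_cellmap_paths(path="data/cellmap", organelles="mito", download=True)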
def get_cellmap_dataset(path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], organelles: Optional[Union[str, List[str]]] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
Get the dataset for the CellMap training data for organelle segmentation.
Arguments:
- path: Filepath to a folder where the data will be downloaded for further processing.
- patch_shape: The patch shape to use for training.
- organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify e.g. 'mito' or ['mito', 'cell'].
- crops: The choice of crops to download. By default, downloads 'all' crops. For multiple crops, provide the crop ids as a sequence.
- resolution: The choice of resolution. By default, downloads the highest resolution: 's0'.
- padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. For pixel regions without annotations, it labels the masks with id '-1'.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
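A short sketch, assuming an illustrative 3D patch shape:

# Build a 3D segmentation dataset over all crops that contain the requested organelles.
dataset = get_cellmap_dataset(
    path="data/cellmap", patch_shape=(32, 256, 256), organelles=["mito", "cell"], download=True
)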
def get_cellmap_loader(path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], organelles: Optional[Union[str, List[str]]] = None, crops: Union[str, Sequence[str]] = 'all', resolution: str = 's0', padding: int = 64, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
Get the dataloader for the CellMap training data for organelle segmentation.
Arguments:
- path: Filepath to a folder where the data will be downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- organelles: The choice of organelles to download. By default, loads all types of labels available. For one or multiple organelles, specify e.g. 'mito' or ['mito', 'cell'].
- crops: The choice of crops to download. By default, downloads 'all' crops. For multiple crops, provide the crop ids as a sequence.
- resolution: The choice of resolution. By default, downloads the highest resolution: 's0'.
- padding: The choice of padding along each dimension. By default, it pads '64' pixels along all dimensions. You can set it to '0' for no padding at all. For pixel regions without annotations, it labels the masks with id '-1'.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
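A short sketch; since extra keyword arguments are split between the dataset and the loader, standard PyTorch DataLoader options such as num_workers or shuffle can be passed directly (the values below are illustrative).

loader = get_cellmap_loader(
    path="data/cellmap", batch_size=2, patch_shape=(32, 256, 256),
    organelles="mito", download=True, num_workers=4, shuffle=True,
)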