torch_em.shallow2deep.shallow2deep_dataset
import os
import pickle
import warnings
from glob import glob
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch_em.segmentation import (check_paths, is_segmentation_dataset,
                                   get_data_loader, get_raw_transform,
                                   samples_to_datasets, _get_default_transform)
from torch_em.data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
from .prepare_shallow2deep import _get_filters, _apply_filters
from ..util import ensure_tensor_with_channels, ensure_spatial_array


class _Shallow2DeepBase:
    """Mixin that replaces the raw input of a sample by a random forest prediction.

    For every sample one of the random forests in `rf_paths` is chosen at random, the filters
    given by `filter_config` are applied to the raw data and the forest's class probabilities
    for `rf_channels` are returned as the network input.
    """
    _rf_paths = None
    _filter_config = None

    @property
    def rf_paths(self):
        """The file paths to the pickled random forests."""
        return self._rf_paths

    @rf_paths.setter
    def rf_paths(self, value):
        self._rf_paths = value

    @property
    def filter_config(self):
        """The filter configuration that is passed to `_get_filters`."""
        return self._filter_config

    @filter_config.setter
    def filter_config(self, value):
        self._filter_config = value

    @property
    def rf_channels(self):
        """The random forest output channels (class indices) used as enhancer input, as a tuple."""
        return self._rf_channels

    @rf_channels.setter
    def rf_channels(self, value):
        # Accept a single index or a list / tuple of indices; always store a tuple.
        if isinstance(value, int):
            value = (value,)
        elif isinstance(value, list):
            value = tuple(value)
        assert isinstance(value, tuple), f"Invalid rf_channels: {value}"
        self._rf_channels = value

    def _predict(self, raw, rf, filters_and_sigmas):
        """Predict the selected random forest channels for a single raw array.

        Args:
            raw: The raw data without a channel axis.
            rf: The random forest; must provide `n_features_in_` and `predict_proba`.
            filters_and_sigmas: The filters to compute the features, as returned by `_get_filters`.

        Returns:
            A float32 array of shape `(len(rf_channels),) + raw.shape`.
        """
        features = _apply_filters(raw, filters_and_sigmas)
        assert rf.n_features_in_ == features.shape[1], f"{rf.n_features_in_}, {features.shape[1]}"

        try:
            pred_ = rf.predict_proba(features)
            # NOTE(review): a failing assertion raises AssertionError, which the IndexError
            # fallback below does not catch.
            assert pred_.shape[1] > max(self.rf_channels), f"{pred_.shape}, {self.rf_channels}"
            pred_ = pred_[:, self.rf_channels]
        except IndexError:
            # Best-effort fallback: use an all-zero prediction instead of failing the sample.
            warnings.warn(f"Random forest prediction failed for input features of shape: {features.shape}")
            pred_shape = (len(features), len(self.rf_channels))
            pred_ = np.zeros(pred_shape, dtype="float32")

        # pred_ has shape (n_pixels, n_channels): move the channels first and restore the spatial shape.
        out_shape = (len(self.rf_channels),) + raw.shape
        prediction = pred_.T.reshape(out_shape).astype("float32")
        return prediction

    def _load_random_forest(self):
        """Load one of the random forests in `rf_paths`, chosen uniformly at random."""
        rf_path = self._rf_paths[np.random.randint(0, len(self._rf_paths))]
        # NOTE: unpickling can execute arbitrary code; only use trusted random forest files.
        with open(rf_path, "rb") as f:
            return pickle.load(f)

    def _predict_rf(self, raw):
        """Predict with a randomly chosen random forest using filters of dimensionality `self.ndim`."""
        rf = self._load_random_forest()
        filters_and_sigmas = _get_filters(self.ndim, self._filter_config)
        return self._predict(raw, rf, filters_and_sigmas)

    def _predict_rf_anisotropic(self, raw):
        """Predict a volume slice by slice (along the first axis) with 2d filters."""
        rf = self._load_random_forest()
        filters_and_sigmas = _get_filters(2, self._filter_config)

        n_channels = len(self.rf_channels)
        prediction = np.zeros((n_channels,) + raw.shape, dtype="float32")
        for z in range(raw.shape[0]):
            prediction[:, z] = self._predict(raw[z], rf, filters_and_sigmas)
        return prediction


class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase):
    """@private
    """
    def __getitem__(self, index):
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)
        if self.label_transform is not None:
            labels = self.label_transform(labels)
        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)
        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if raw.shape[0] > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        if getattr(self, "is_anisotropic", False):
            prediction = self._predict_rf_anisotropic(raw[0].numpy())
        else:
            prediction = self._predict_rf(raw[0].numpy())
        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels


class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase):
    """@private
    """
    def __getitem__(self, index):
        # Same guard as in Shallow2DeepDataset: random forests must be set before sampling.
        assert self._rf_paths is not None
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
            assert len(raw) == 1
            raw = raw[0]
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if raw.shape[0] > 1:
            raise NotImplementedError(
                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
            )

        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
        prediction = self._predict_rf(raw[0].numpy())
        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return prediction, labels


def _load_shallow2deep_segmentation_dataset(
    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, ndim, **kwargs
):
    """Create a (concatenated) Shallow2DeepDataset for one or multiple segmentation files.

    `ndim="anisotropic"` creates 3d datasets that predict the random forest slice by slice.
    """
    rois = kwargs.pop("rois", None)
    filter_config = kwargs.pop("filter_config", None)
    if ndim == "anisotropic":
        ndim = 3
        is_anisotropic = True
    else:
        is_anisotropic = False

    if isinstance(raw_paths, str):
        if rois is not None:
            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
        ds = Shallow2DeepDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, ndim=ndim, **kwargs)
        ds.rf_paths = rf_paths
        ds.filter_config = filter_config
        ds.rf_channels = rf_channels
        ds.is_anisotropic = is_anisotropic
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths), f"{len(rois)}, {len(label_paths)}"
            assert all(isinstance(roi, tuple) for roi in rois)
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = Shallow2DeepDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], ndim=ndim, **kwargs
            )
            dset.rf_paths = rf_paths
            dset.filter_config = filter_config
            dset.rf_channels = rf_channels
            dset.is_anisotropic = is_anisotropic
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _load_shallow2deep_image_collection_dataset(
    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape, **kwargs
):
    """Create a Shallow2DeepImageCollectionDataset from folders + glob patterns or from file lists."""
    # Pop the filter config before the dataset is constructed: it is not a dataset argument.
    filter_config = kwargs.pop("filter_config", None)

    if isinstance(raw_paths, str):
        assert isinstance(label_paths, str)
        raw_file_paths = glob(os.path.join(raw_paths, raw_key))
        raw_file_paths.sort()
        label_file_paths = glob(os.path.join(label_paths, label_key))
        label_file_paths.sort()
        ds = Shallow2DeepImageCollectionDataset(raw_file_paths, label_file_paths, patch_shape, **kwargs)
    elif isinstance(raw_paths, list) and raw_key is None:
        assert isinstance(label_paths, list)
        assert label_key is None
        assert all(os.path.exists(pp) for pp in raw_paths)
        assert all(os.path.exists(pp) for pp in label_paths)
        ds = Shallow2DeepImageCollectionDataset(raw_paths, label_paths, patch_shape, **kwargs)
    else:
        raise NotImplementedError(
            "Expected either folders with glob patterns as keys or lists of file paths with keys set to None."
        )

    ds.rf_paths = rf_paths
    ds.filter_config = filter_config
    ds.rf_channels = rf_channels
    return ds


def get_shallow2deep_dataset(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    rf_paths: Sequence[str],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[Union[int, str]] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    filter_config: Optional[Dict] = None,
    rf_channels: Tuple[int, ...] = (1,),
) -> torch.utils.data.Dataset:
    """Get a dataset for shallow2deep enhancer training.

    Args:
        raw_paths: The file paths to the image data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
        rf_paths: The file paths to the pretrained random forests.
        patch_shape: The patch shape to load for a sample.
        raw_transform: The transform to apply to the raw data.
        label_transform: The transform to apply to the label data.
        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
        dtype: The data type for the raw data.
        rois: The regions of interest for the data.
        n_samples: The length of this dataset.
        sampler: A sampler to reject samples based on a pre-defined criterion.
        ndim: The dimensionality of the data. For segmentation datasets, this can also be set to
            "anisotropic" to compute the random forest predictions slice-by-slice with 2d filters.
        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
            If set to None, this will be determined from the data.
        with_channels: Whether the raw data has channels.
        filter_config: The filter configuration for the random forest.
        rf_channels: The random forest output channel(s) to use as input for the enhancer model.

    Returns:
        The dataset.

    Raises:
        NotImplementedError: If `rois` are given for an image collection dataset.
    """
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset,
            3 if ndim == "anisotropic" else ndim
        )

    if is_seg_dataset:
        ds = _load_shallow2deep_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            rf_paths,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            with_channels=with_channels,
            filter_config=filter_config,
            rf_channels=rf_channels,
        )
    else:
        if rois is not None:
            raise NotImplementedError("rois are not supported for image collection datasets.")
        # Pass the filter config so that it is not silently dropped for image collections.
        ds = _load_shallow2deep_image_collection_dataset(
            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
            raw_transform=raw_transform, label_transform=label_transform,
            transform=transform, dtype=dtype, n_samples=n_samples,
            filter_config=filter_config,
        )
    return ds


def get_shallow2deep_loader(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    rf_paths: Sequence[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[Union[int, str]] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    filter_config: Optional[Dict] = None,
    rf_channels: Tuple[int, ...] = (1,),
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for shallow2deep enhancer training.

    Args:
        raw_paths: The file paths to the image data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
        rf_paths: The file paths to the pretrained random forests.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape to load for a sample.
        raw_transform: The transform to apply to the raw data.
        label_transform: The transform to apply to the label data.
        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
        dtype: The data type for the raw data.
        rois: The regions of interest for the data.
        n_samples: The length of this dataset.
        sampler: A sampler to reject samples based on a pre-defined criterion.
        ndim: The dimensionality of the data. For segmentation datasets, this can also be set to
            "anisotropic" to compute the random forest predictions slice-by-slice with 2d filters.
        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
            If set to None, this will be determined from the data.
        with_channels: Whether the raw data has channels.
        filter_config: The filter configuration for the random forest.
        rf_channels: The random forest output channel(s) to use as input for the enhancer model.
        loader_kwargs: The keyword arguments for the data loader.

    Returns:
        The dataloader.
    """
    ds = get_shallow2deep_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        rf_paths=rf_paths,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        # Forward dtype; previously it was accepted here but silently ignored.
        dtype=dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        filter_config=filter_config,
        rf_channels=rf_channels,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
def
get_shallow2deep_dataset( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], rf_paths: Sequence[str], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: Union[str, torch.dtype] = torch.float32, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, filter_config: Optional[Dict] = None, rf_channels: Tuple[int, ...] = (1,)) -> torch.utils.data.dataset.Dataset:
def get_shallow2deep_dataset(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    rf_paths: Sequence[str],
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[Union[int, str]] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    filter_config: Optional[Dict] = None,
    rf_channels: Tuple[int, ...] = (1,),
) -> torch.utils.data.Dataset:
    """Get a dataset for shallow2deep enhancer training.

    Args:
        raw_paths: The file paths to the image data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
        rf_paths: The file paths to the pretrained random forests.
        patch_shape: The patch shape to load for a sample.
        raw_transform: The transform to apply to the raw data.
        label_transform: The transform to apply to the label data.
        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
        dtype: The data type for the raw data.
        rois: The regions of interest for the data.
        n_samples: The length of this dataset.
        sampler: A sampler to reject samples based on a pre-defined criterion.
        ndim: The dimensionality of the data. For segmentation datasets, this can also be set to
            "anisotropic" to compute the random forest predictions slice-by-slice with 2d filters.
        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
            If set to None, this will be determined from the data.
        with_channels: Whether the raw data has channels.
        filter_config: The filter configuration for the random forest.
        rf_channels: The random forest output channel(s) to use as input for the enhancer model.

    Returns:
        The dataset.

    Raises:
        NotImplementedError: If `rois` are given for an image collection dataset.
    """
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset,
            3 if ndim == "anisotropic" else ndim
        )

    if is_seg_dataset:
        ds = _load_shallow2deep_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            rf_paths,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            with_channels=with_channels,
            filter_config=filter_config,
            rf_channels=rf_channels,
        )
    else:
        if rois is not None:
            raise NotImplementedError("rois are not supported for image collection datasets.")
        # Pass the filter config so that it is not silently dropped for image collections.
        ds = _load_shallow2deep_image_collection_dataset(
            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
            raw_transform=raw_transform, label_transform=label_transform,
            transform=transform, dtype=dtype, n_samples=n_samples,
            filter_config=filter_config,
        )
    return ds
Get a dataset for shallow2deep enhancer training.
Arguments:
- raw_paths: The file paths to the image data. May also be a single file.
- raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
- label_paths: The file paths to the label data. May also be a single file.
- label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
- rf_paths: The file paths to the pretrained random forests.
- patch_shape: The patch shape to load for a sample.
- raw_transform: The transform to apply to the raw data.
- label_transform: The transform to apply to the label data.
- transform: The transform to apply to raw and label data, e.g. to implement augmentations.
- dtype: The data type for the raw data.
- rois: The regions of interest for the data.
- n_samples: The length of this dataset.
- sampler: A sampler to reject samples based on a pre-defined criterion.
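- ndim: The dimensionality of the data. For segmentation datasets, this can also be set to "anisotropic" to compute the random forest predictions slice-by-slice with 2d filters.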
- is_seg_dataset: Whether this is a segmentation or an image collection dataset. If set to None, this will be determined from the data.
- with_channels: Whether the raw data has channels.
- filter_config: The filter configuration for the random forest.
- rf_channels: The random forest output channel(s) to use as input for the enhancer model.
Returns:
The dataset.
def
get_shallow2deep_loader( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], rf_paths: Sequence[str], batch_size: int, patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: Union[str, torch.dtype] = torch.float32, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, filter_config: Optional[Dict] = None, rf_channels: Tuple[int, ...] = (1,), **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_shallow2deep_loader(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    rf_paths: Sequence[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[Union[int, str]] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    filter_config: Optional[Dict] = None,
    rf_channels: Tuple[int, ...] = (1,),
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for shallow2deep enhancer training.

    Args:
        raw_paths: The file paths to the image data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
        rf_paths: The file paths to the pretrained random forests.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape to load for a sample.
        raw_transform: The transform to apply to the raw data.
        label_transform: The transform to apply to the label data.
        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
        dtype: The data type for the raw data.
        rois: The regions of interest for the data.
        n_samples: The length of this dataset.
        sampler: A sampler to reject samples based on a pre-defined criterion.
        ndim: The dimensionality of the data. For segmentation datasets, this can also be set to
            "anisotropic" to compute the random forest predictions slice-by-slice with 2d filters.
        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
            If set to None, this will be determined from the data.
        with_channels: Whether the raw data has channels.
        filter_config: The filter configuration for the random forest.
        rf_channels: The random forest output channel(s) to use as input for the enhancer model.
        loader_kwargs: The keyword arguments for the data loader.

    Returns:
        The dataloader.
    """
    ds = get_shallow2deep_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        rf_paths=rf_paths,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        # Forward dtype; previously it was accepted here but silently ignored.
        dtype=dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        filter_config=filter_config,
        rf_channels=rf_channels,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
Get a dataloader for shallow2deep enhancer training.
Arguments:
- raw_paths: The file paths to the image data. May also be a single file.
- raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
- label_paths: The file paths to the label data. May also be a single file.
- label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
- rf_paths: The file paths to the pretrained random forests.
- batch_size: The batch size for the data loader.
- patch_shape: The patch shape to load for a sample.
- raw_transform: The transform to apply to the raw data.
- label_transform: The transform to apply to the label data.
- transform: The transform to apply to raw and label data, e.g. to implement augmentations.
- dtype: The data type for the raw data.
- rois: The regions of interest for the data.
- n_samples: The length of this dataset.
- sampler: A sampler to reject samples based on a pre-defined criterion.
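- ndim: The dimensionality of the data. For segmentation datasets, this can also be set to "anisotropic" to compute the random forest predictions slice-by-slice with 2d filters.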
- is_seg_dataset: Whether this is a segmentation or an image collection dataset. If set to None, this will be determined from the data.
- with_channels: Whether the raw data has channels.
- filter_config: The filter configuration for the random forest.
- rf_channels: The random forest output channel(s) to use as input for the enhancer model.
- loader_kwargs: The keyword arguments for the data loader.
Returns:
The dataloader.