torch_em.shallow2deep.shallow2deep_dataset

  1import os
  2import pickle
  3import warnings
  4from glob import glob
  5from typing import Callable, Dict, Optional, Sequence, Tuple, Union
  6
  7import numpy as np
  8import torch
  9from torch_em.segmentation import (check_paths, is_segmentation_dataset,
 10                                   get_data_loader, get_raw_transform,
 11                                   samples_to_datasets, _get_default_transform)
 12from torch_em.data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
 13from .prepare_shallow2deep import _get_filters, _apply_filters
 14from ..util import ensure_tensor_with_channels, ensure_spatial_array
 15
 16
 17class _Shallow2DeepBase:
 18    _rf_paths = None
 19    _filter_config = None
 20
 21    @property
 22    def rf_paths(self):
 23        return self._rf_paths
 24
 25    @rf_paths.setter
 26    def rf_paths(self, value):
 27        self._rf_paths = value
 28
 29    @property
 30    def filter_config(self):
 31        return self._filter_config
 32
 33    @filter_config.setter
 34    def filter_config(self, value):
 35        self._filter_config = value
 36
 37    @property
 38    def rf_channels(self):
 39        return self._rf_channels
 40
 41    @rf_channels.setter
 42    def rf_channels(self, value):
 43        if isinstance(value, int):
 44            self.rf_channels = (value,)
 45        else:
 46            assert isinstance(value, tuple)
 47            self._rf_channels = value
 48
 49    def _predict(self, raw, rf, filters_and_sigmas):
 50        features = _apply_filters(raw, filters_and_sigmas)
 51        assert rf.n_features_in_ == features.shape[1], f"{rf.n_features_in_}, {features.shape[1]}"
 52
 53        try:
 54            pred_ = rf.predict_proba(features)
 55            assert pred_.shape[1] > max(self.rf_channels), f"{pred_.shape}, {self.rf_channels}"
 56            pred_ = pred_[:, self.rf_channels]
 57        except IndexError:
 58            warnings.warn(f"Random forest prediction failed for input features of shape: {features.shape}")
 59            pred_shape = (len(features), len(self.rf_channels))
 60            pred_ = np.zeros(pred_shape, dtype="float32")
 61
 62        spatial_shape = raw.shape
 63        out_shape = (len(self.rf_channels),) + spatial_shape
 64        prediction = np.zeros(out_shape, dtype="float32")
 65        for chan in range(pred_.shape[1]):
 66            prediction[chan] = pred_[:, chan].reshape(spatial_shape)
 67
 68        return prediction
 69
 70    def _predict_rf(self, raw):
 71        n_rfs = len(self._rf_paths)
 72        rf_path = self._rf_paths[np.random.randint(0, n_rfs)]
 73        with open(rf_path, "rb") as f:
 74            rf = pickle.load(f)
 75        filters_and_sigmas = _get_filters(self.ndim, self._filter_config)
 76        return self._predict(raw, rf, filters_and_sigmas)
 77
 78    def _predict_rf_anisotropic(self, raw):
 79        n_rfs = len(self._rf_paths)
 80        rf_path = self._rf_paths[np.random.randint(0, n_rfs)]
 81        with open(rf_path, "rb") as f:
 82            rf = pickle.load(f)
 83        filters_and_sigmas = _get_filters(2, self._filter_config)
 84
 85        n_channels = len(self.rf_channels)
 86        prediction = np.zeros((n_channels,) + raw.shape, dtype="float32")
 87        for z in range(raw.shape[0]):
 88            pred = self._predict(raw[z], rf, filters_and_sigmas)
 89            prediction[:, z] = pred
 90
 91        return prediction
 92
 93
 94class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase):
 95    """@private
 96    """
 97    def __getitem__(self, index):
 98        assert self._rf_paths is not None
 99        raw, labels = self._get_sample(index)
100        initial_label_dtype = labels.dtype
101
102        if self.raw_transform is not None:
103            raw = self.raw_transform(raw)
104        if self.label_transform is not None:
105            labels = self.label_transform(labels)
106        if self.transform is not None:
107            raw, labels = self.transform(raw, labels)
108            if self.trafo_halo is not None:
109                raw = self.crop(raw)
110                labels = self.crop(labels)
111        # support enlarging bounding box here as well (for affinity transform) ?
112        if self.label_transform2 is not None:
113            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
114            labels = self.label_transform2(labels)
115
116        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
117            assert len(raw) == 1
118            raw = raw[0]
119        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
120        if raw.shape[0] > 1:
121            raise NotImplementedError(
122                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
123            )
124
125        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
126        if getattr(self, "is_anisotropic", False):
127            prediction = self._predict_rf_anisotropic(raw[0].numpy())
128        else:
129            prediction = self._predict_rf(raw[0].numpy())
130        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
131        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
132        return prediction, labels
133
134
135class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase):
136    """@private
137    """
138    def __getitem__(self, index):
139        raw, labels = self._get_sample(index)
140        initial_label_dtype = labels.dtype
141
142        if self.raw_transform is not None:
143            raw = self.raw_transform(raw)
144
145        if self.label_transform is not None:
146            labels = self.label_transform(labels)
147
148        if self.transform is not None:
149            raw, labels = self.transform(raw, labels)
150
151        # support enlarging bounding box here as well (for affinity transform) ?
152        if self.label_transform2 is not None:
153            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
154            labels = self.label_transform2(labels)
155
156        if isinstance(raw, (list, tuple)):  # this can be a list or tuple due to transforms
157            assert len(raw) == 1
158            raw = raw[0]
159        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
160        if raw.shape[0] > 1:
161            raise NotImplementedError(
162                f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels"
163            )
164
165        # NOTE we assume single channel raw data here; this needs to be changed for multi-channel
166        prediction = self._predict_rf(raw[0].numpy())
167        prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype)
168        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
169        return prediction, labels
170
171
172def _load_shallow2deep_segmentation_dataset(
173    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, ndim, **kwargs
174):
175    rois = kwargs.pop("rois", None)
176    filter_config = kwargs.pop("filter_config", None)
177    if ndim == "anisotropic":
178        ndim = 3
179        is_anisotropic = True
180    else:
181        is_anisotropic = False
182
183    if isinstance(raw_paths, str):
184        if rois is not None:
185            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
186        ds = Shallow2DeepDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, ndim=ndim, **kwargs)
187        ds.rf_paths = rf_paths
188        ds.filter_config = filter_config
189        ds.rf_channels = rf_channels
190        ds.is_anisotropic = is_anisotropic
191    else:
192        assert len(raw_paths) > 0
193        if rois is not None:
194            assert len(rois) == len(label_paths), f"{len(rois)}, {len(label_paths)}"
195            assert all(isinstance(roi, tuple) for roi in rois)
196        n_samples = kwargs.pop("n_samples", None)
197
198        samples_per_ds = (
199            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
200        )
201        ds = []
202        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
203            roi = None if rois is None else rois[i]
204            dset = Shallow2DeepDataset(
205                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], ndim=ndim, **kwargs
206            )
207            dset.rf_paths = rf_paths
208            dset.filter_config = filter_config
209            dset.rf_channels = rf_channels
210            dset.is_anisotropic = is_anisotropic
211            ds.append(dset)
212        ds = ConcatDataset(*ds)
213    return ds
214
215
216def _load_shallow2deep_image_collection_dataset(
217    raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape, **kwargs
218):
219    if isinstance(raw_paths, str):
220        assert isinstance(label_paths, str)
221        raw_file_paths = glob(os.path.join(raw_paths, raw_key))
222        raw_file_paths.sort()
223        label_file_paths = glob(os.path.join(label_paths, label_key))
224        label_file_paths.sort()
225        ds = Shallow2DeepImageCollectionDataset(raw_file_paths, label_file_paths, patch_shape, **kwargs)
226    elif isinstance(raw_paths, list) and raw_key is None:
227        assert isinstance(label_paths, list)
228        assert label_key is None
229        assert all(os.path.exists(pp) for pp in raw_paths)
230        assert all(os.path.exists(pp) for pp in label_paths)
231        ds = Shallow2DeepImageCollectionDataset(raw_paths, label_paths, patch_shape, **kwargs)
232    else:
233        raise NotImplementedError
234
235    filter_config = kwargs.pop("filter_config", None)
236    ds.rf_paths = rf_paths
237    ds.filter_config = filter_config
238    ds.rf_channels = rf_channels
239    return ds
240
241
242def get_shallow2deep_dataset(
243    raw_paths: Union[str, Sequence[str]],
244    raw_key: Optional[str],
245    label_paths: Union[str, Sequence[str]],
246    label_key: Optional[str],
247    rf_paths: Sequence[str],
248    patch_shape: Tuple[int, ...],
249    raw_transform: Optional[Callable] = None,
250    label_transform: Optional[Callable] = None,
251    transform: Optional[Callable] = None,
252    dtype: Union[str, torch.dtype] = torch.float32,
253    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
254    n_samples: Optional[int] = None,
255    sampler: Optional[Callable] = None,
256    ndim: Optional[int] = None,
257    is_seg_dataset: Optional[bool] = None,
258    with_channels: bool = False,
259    filter_config: Optional[Dict] = None,
260    rf_channels: Tuple[int, ...] = (1,),
261) -> torch.utils.data.Dataset:
262    """Get a dataset for shallow2deep enhancer training.
263
264    Args:
265        raw_paths: The file paths to the image data. May also be a single file.
266        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
267        label_paths: The file paths to the label data. May also be a single file.
268        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
269        rf_paths: The file paths to the pretrained random forests.
270        patch_shape: The patch shape to load for a sample.
271        raw_transform: The transform to apply to the raw data.
272        label_transform: The transform to apply to the label data.
273        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
274        dtype: The data type for the raw data.
275        rois: The regions of interest for the data.
276        n_samples: The length of this dataset.
277        sampler: A sampler to reject samples based on a pre-defined criterion.
278        ndim: The dimensionality of the data.
279        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
280            If set to None, this will be determined from the data.
281        with_channels: Whether the raw data has channels.
282        filter_config: The filter configuration for the random forest.
283        rf_channels: The random forest channel to use as input for the enhancer model.
284
285    Returns:
286        The dataset.
287    """
288    check_paths(raw_paths, label_paths)
289    if is_seg_dataset is None:
290        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
291
292    # we always use a raw transform in the convenience function
293    if raw_transform is None:
294        raw_transform = get_raw_transform()
295
296    # we always use augmentations in the convenience function
297    if transform is None:
298        transform = _get_default_transform(
299            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset,
300            3 if ndim == "anisotropic" else ndim
301        )
302
303    if is_seg_dataset:
304        ds = _load_shallow2deep_segmentation_dataset(
305            raw_paths,
306            raw_key,
307            label_paths,
308            label_key,
309            rf_paths,
310            patch_shape=patch_shape,
311            raw_transform=raw_transform,
312            label_transform=label_transform,
313            transform=transform,
314            rois=rois,
315            n_samples=n_samples,
316            sampler=sampler,
317            ndim=ndim,
318            dtype=dtype,
319            with_channels=with_channels,
320            filter_config=filter_config,
321            rf_channels=rf_channels,
322        )
323    else:
324        if rois is not None:
325            raise NotImplementedError
326        ds = _load_shallow2deep_image_collection_dataset(
327            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
328            raw_transform=raw_transform, label_transform=label_transform,
329            transform=transform, dtype=dtype, n_samples=n_samples,
330        )
331    return ds
332
333
334def get_shallow2deep_loader(
335    raw_paths: Union[str, Sequence[str]],
336    raw_key: Optional[str],
337    label_paths: Union[str, Sequence[str]],
338    label_key: Optional[str],
339    rf_paths: Sequence[str],
340    batch_size: int,
341    patch_shape: Tuple[int, ...],
342    raw_transform: Optional[Callable] = None,
343    label_transform: Optional[Callable] = None,
344    transform: Optional[Callable] = None,
345    dtype: Union[str, torch.dtype] = torch.float32,
346    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
347    n_samples: Optional[int] = None,
348    sampler: Optional[Callable] = None,
349    ndim: Optional[int] = None,
350    is_seg_dataset: Optional[bool] = None,
351    with_channels: bool = False,
352    filter_config: Optional[Dict] = None,
353    rf_channels: Tuple[int, ...] = (1,),
354    **loader_kwargs,
355) -> torch.utils.data.DataLoader:
356    """Get a dataloader for shallow2deep enhancer training.
357
358    Args:
359        raw_paths: The file paths to the image data. May also be a single file.
360        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
361        label_paths: The file paths to the label data. May also be a single file.
362        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
363        rf_paths: The file paths to the pretrained random forests.
364        batch_size: The batch size for the data loader.
365        patch_shape: The patch shape to load for a sample.
366        raw_transform: The transform to apply to the raw data.
367        label_transform: The transform to apply to the label data.
368        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
369        dtype: The data type for the raw data.
370        rois: The regions of interest for the data.
371        n_samples: The length of this dataset.
372        sampler: A sampler to reject samples based on a pre-defined criterion.
373        ndim: The dimensionality of the data.
374        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
375            If set to None, this will be determined from the data.
376        with_channels: Whether the raw data has channels.
377        filter_config: The filter configuration for the random forest.
378        rf_channels: The random forest channel to use as input for the enhancer model.
379        loader_kwargs: The keyword arguments for the data loader.
380
381    Returns:
382        The dataloader
383    """
384    ds = get_shallow2deep_dataset(
385        raw_paths=raw_paths,
386        raw_key=raw_key,
387        label_paths=label_paths,
388        label_key=label_key,
389        rf_paths=rf_paths,
390        patch_shape=patch_shape,
391        raw_transform=raw_transform,
392        label_transform=label_transform,
393        transform=transform,
394        rois=rois,
395        n_samples=n_samples,
396        sampler=sampler,
397        ndim=ndim,
398        is_seg_dataset=is_seg_dataset,
399        with_channels=with_channels,
400        filter_config=filter_config,
401        rf_channels=rf_channels,
402    )
403    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
def get_shallow2deep_dataset( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], rf_paths: Sequence[str], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: Union[str, torch.dtype] = torch.float32, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, filter_config: Optional[Dict] = None, rf_channels: Tuple[int, ...] = (1,)) -> torch.utils.data.dataset.Dataset:
243def get_shallow2deep_dataset(
244    raw_paths: Union[str, Sequence[str]],
245    raw_key: Optional[str],
246    label_paths: Union[str, Sequence[str]],
247    label_key: Optional[str],
248    rf_paths: Sequence[str],
249    patch_shape: Tuple[int, ...],
250    raw_transform: Optional[Callable] = None,
251    label_transform: Optional[Callable] = None,
252    transform: Optional[Callable] = None,
253    dtype: Union[str, torch.dtype] = torch.float32,
254    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
255    n_samples: Optional[int] = None,
256    sampler: Optional[Callable] = None,
257    ndim: Optional[int] = None,
258    is_seg_dataset: Optional[bool] = None,
259    with_channels: bool = False,
260    filter_config: Optional[Dict] = None,
261    rf_channels: Tuple[int, ...] = (1,),
262) -> torch.utils.data.Dataset:
263    """Get a dataset for shallow2deep enhancer training.
264
265    Args:
266        raw_paths: The file paths to the image data. May also be a single file.
267        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
268        label_paths: The file paths to the label data. May also be a single file.
269        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
270        rf_paths: The file paths to the pretrained random forests.
271        patch_shape: The patch shape to load for a sample.
272        raw_transform: The transform to apply to the raw data.
273        label_transform: The transform to apply to the label data.
274        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
275        dtype: The data type for the raw data.
276        rois: The regions of interest for the data.
277        n_samples: The length of this dataset.
278        sampler: A sampler to reject samples based on a pre-defined criterion.
279        ndim: The dimensionality of the data.
280        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
281            If set to None, this will be determined from the data.
282        with_channels: Whether the raw data has channels.
283        filter_config: The filter configuration for the random forest.
284        rf_channels: The random forest channel to use as input for the enhancer model.
285
286    Returns:
287        The dataset.
288    """
289    check_paths(raw_paths, label_paths)
290    if is_seg_dataset is None:
291        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
292
293    # we always use a raw transform in the convenience function
294    if raw_transform is None:
295        raw_transform = get_raw_transform()
296
297    # we always use augmentations in the convenience function
298    if transform is None:
299        transform = _get_default_transform(
300            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset,
301            3 if ndim == "anisotropic" else ndim
302        )
303
304    if is_seg_dataset:
305        ds = _load_shallow2deep_segmentation_dataset(
306            raw_paths,
307            raw_key,
308            label_paths,
309            label_key,
310            rf_paths,
311            patch_shape=patch_shape,
312            raw_transform=raw_transform,
313            label_transform=label_transform,
314            transform=transform,
315            rois=rois,
316            n_samples=n_samples,
317            sampler=sampler,
318            ndim=ndim,
319            dtype=dtype,
320            with_channels=with_channels,
321            filter_config=filter_config,
322            rf_channels=rf_channels,
323        )
324    else:
325        if rois is not None:
326            raise NotImplementedError
327        ds = _load_shallow2deep_image_collection_dataset(
328            raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape,
329            raw_transform=raw_transform, label_transform=label_transform,
330            transform=transform, dtype=dtype, n_samples=n_samples,
331        )
332    return ds

Get a dataset for shallow2deep enhancer training.

Arguments:
  • raw_paths: The file paths to the image data. May also be a single file.
  • raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
  • label_paths: The file paths to the label data. May also be a single file.
  • label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
  • rf_paths: The file paths to the pretrained random forests.
  • patch_shape: The patch shape to load for a sample.
  • raw_transform: The transform to apply to the raw data.
  • label_transform: The transform to apply to the label data.
  • transform: The transform to apply to raw and label data, e.g. to implement augmentations.
  • dtype: The data type for the raw data.
  • rois: The regions of interest for the data.
  • n_samples: The length of this dataset.
  • sampler: A sampler to reject samples based on a pre-defined criterion.
  • ndim: The dimensionality of the data.
  • is_seg_dataset: Whether this is a segmentation or an image collection dataset. If set to None, this will be determined from the data.
  • with_channels: Whether the raw data has channels.
  • filter_config: The filter configuration for the random forest.
  • rf_channels: The random forest channel to use as input for the enhancer model.
Returns:

The dataset.

def get_shallow2deep_loader( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], rf_paths: Sequence[str], batch_size: int, patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: Union[str, torch.dtype] = torch.float32, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, filter_config: Optional[Dict] = None, rf_channels: Tuple[int, ...] = (1,), **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
335def get_shallow2deep_loader(
336    raw_paths: Union[str, Sequence[str]],
337    raw_key: Optional[str],
338    label_paths: Union[str, Sequence[str]],
339    label_key: Optional[str],
340    rf_paths: Sequence[str],
341    batch_size: int,
342    patch_shape: Tuple[int, ...],
343    raw_transform: Optional[Callable] = None,
344    label_transform: Optional[Callable] = None,
345    transform: Optional[Callable] = None,
346    dtype: Union[str, torch.dtype] = torch.float32,
347    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
348    n_samples: Optional[int] = None,
349    sampler: Optional[Callable] = None,
350    ndim: Optional[int] = None,
351    is_seg_dataset: Optional[bool] = None,
352    with_channels: bool = False,
353    filter_config: Optional[Dict] = None,
354    rf_channels: Tuple[int, ...] = (1,),
355    **loader_kwargs,
356) -> torch.utils.data.DataLoader:
357    """Get a dataloader for shallow2deep enhancer training.
358
359    Args:
360        raw_paths: The file paths to the image data. May also be a single file.
361        raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
362        label_paths: The file paths to the label data. May also be a single file.
363        label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
364        rf_paths: The file paths to the pretrained random forests.
365        batch_size: The batch size for the data loader.
366        patch_shape: The patch shape to load for a sample.
367        raw_transform: The transform to apply to the raw data.
368        label_transform: The transform to apply to the label data.
369        transform: The transform to apply to raw and label data, e.g. to implement augmentations.
370        dtype: The data type for the raw data.
371        rois: The regions of interest for the data.
372        n_samples: The length of this dataset.
373        sampler: A sampler to reject samples based on a pre-defined criterion.
374        ndim: The dimensionality of the data.
375        is_seg_dataset: Whether this is a segmentation or an image collection dataset.
376            If set to None, this will be determined from the data.
377        with_channels: Whether the raw data has channels.
378        filter_config: The filter configuration for the random forest.
379        rf_channels: The random forest channel to use as input for the enhancer model.
380        loader_kwargs: The keyword arguments for the data loader.
381
382    Returns:
383        The dataloader
384    """
385    ds = get_shallow2deep_dataset(
386        raw_paths=raw_paths,
387        raw_key=raw_key,
388        label_paths=label_paths,
389        label_key=label_key,
390        rf_paths=rf_paths,
391        patch_shape=patch_shape,
392        raw_transform=raw_transform,
393        label_transform=label_transform,
394        transform=transform,
395        rois=rois,
396        n_samples=n_samples,
397        sampler=sampler,
398        ndim=ndim,
399        is_seg_dataset=is_seg_dataset,
400        with_channels=with_channels,
401        filter_config=filter_config,
402        rf_channels=rf_channels,
403    )
404    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)

Get a dataloader for shallow2deep enhancer training.

Arguments:
  • raw_paths: The file paths to the image data. May also be a single file.
  • raw_key: The name of the internal dataset for the raw data. Set to None for a regular image file, like tif.
  • label_paths: The file paths to the label data. May also be a single file.
  • label_key: The name of the internal dataset for the label data. Set to None for a regular image file, like tif.
  • rf_paths: The file paths to the pretrained random forests.
  • batch_size: The batch size for the data loader.
  • patch_shape: The patch shape to load for a sample.
  • raw_transform: The transform to apply to the raw data.
  • label_transform: The transform to apply to the label data.
  • transform: The transform to apply to raw and label data, e.g. to implement augmentations.
  • dtype: The data type for the raw data.
  • rois: The regions of interest for the data.
  • n_samples: The length of this dataset.
  • sampler: A sampler to reject samples based on a pre-defined criterion.
  • ndim: The dimensionality of the data.
  • is_seg_dataset: Whether this is a segmentation or an image collection dataset. If set to None, this will be determined from the data.
  • with_channels: Whether the raw data has channels.
  • filter_config: The filter configuration for the random forest.
  • rf_channels: The random forest channel to use as input for the enhancer model.
  • loader_kwargs: The keyword arguments for the data loader.
Returns:

The dataloader