torch_em.shallow2deep.prepare_shallow2deep

  1import os
  2import copy
  3import pickle
  4import warnings
  5from concurrent import futures
  6from glob import glob
  7from functools import partial
  8from typing import Callable, Dict, Optional, Sequence, Tuple, Union
  9
 10import numpy as np
 11import torch_em
 12from scipy.ndimage import gaussian_filter, convolve
 13from skimage.feature import peak_local_max
 14from sklearn.ensemble import RandomForestClassifier
 15from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
 16from tqdm import tqdm
 17
 18import vigra
 19try:
 20    import fastfilters as filter_impl
 21except ImportError:
 22    import vigra.filters as filter_impl
 23
 24
 25class RFSegmentationDataset(torch_em.data.SegmentationDataset):
 26    """@private
 27    """
 28    _patch_shape_min = None
 29    _patch_shape_max = None
 30
 31    @property
 32    def patch_shape_min(self):
 33        return self._patch_shape_min
 34
 35    @patch_shape_min.setter
 36    def patch_shape_min(self, value):
 37        self._patch_shape_min = value
 38
 39    @property
 40    def patch_shape_max(self):
 41        return self._patch_shape_max
 42
 43    @patch_shape_max.setter
 44    def patch_shape_max(self, value):
 45        self._patch_shape_max = value
 46
 47    def _sample_bounding_box(self):
 48        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 49        sample_shape = [
 50            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 51            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 52        ]
 53        bb_start = [
 54            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
 55            for sh, psh in zip(self.shape, sample_shape)
 56        ]
 57        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, sample_shape))
 58
 59
 60class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
 61    """@private
 62    """
 63    _patch_shape_min = None
 64    _patch_shape_max = None
 65
 66    @property
 67    def patch_shape_min(self):
 68        return self._patch_shape_min
 69
 70    @patch_shape_min.setter
 71    def patch_shape_min(self, value):
 72        self._patch_shape_min = value
 73
 74    @property
 75    def patch_shape_max(self):
 76        return self._patch_shape_max
 77
 78    @patch_shape_max.setter
 79    def patch_shape_max(self, value):
 80        self._patch_shape_max = value
 81
 82    def _sample_bounding_box(self, shape):
 83        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
 84            raise NotImplementedError("Image padding is not supported yet.")
 85        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 86        patch_shape = [
 87            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 88            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 89        ]
 90        bb_start = [
 91            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
 92        ]
 93        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))
 94
 95
 96def _load_rf_segmentation_dataset(
 97    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
 98):
 99    rois = kwargs.pop("rois", None)
100    sampler = kwargs.pop("sampler", None)
101    sampler = sampler if sampler else torch_em.data.MinForegroundSampler(min_fraction=0.01)
102    if isinstance(raw_paths, str):
103        if rois is not None:
104            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
105        ds = RFSegmentationDataset(
106            raw_paths, raw_key, label_paths, label_key, roi=rois, patch_shape=patch_shape_min, sampler=sampler, **kwargs
107        )
108        ds.patch_shape_min = patch_shape_min
109        ds.patch_shape_max = patch_shape_max
110    else:
111        assert len(raw_paths) > 0
112        if rois is not None:
113            assert len(rois) == len(label_paths)
114            assert all(isinstance(roi, tuple) for roi in rois)
115        n_samples = kwargs.pop("n_samples", None)
116
117        samples_per_ds = (
118            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
119        )
120        ds = []
121        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
122            roi = None if rois is None else rois[i]
123            dset = RFSegmentationDataset(
124                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i],
125                patch_shape=patch_shape_min, sampler=sampler, **kwargs
126            )
127            dset.patch_shape_min = patch_shape_min
128            dset.patch_shape_max = patch_shape_max
129            ds.append(dset)
130        ds = torch_em.data.ConcatDataset(*ds)
131    return ds
132
133
134def _load_rf_image_collection_dataset(
135    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, roi, **kwargs
136):
137    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
138        pattern = os.path.join(rpath, rkey)
139        rpath = sorted(glob(pattern))
140        if len(rpath) == 0:
141            raise ValueError(f"Could not find any images for pattern {pattern}")
142        lpath = glob(os.path.join(lpath, lkey))
143        lpath.sort()
144        if len(rpath) != len(lpath):
145            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")
146
147        if this_roi is not None:
148            rpath, lpath = rpath[this_roi], lpath[this_roi]
149
150        return rpath, lpath
151
152    def _check_patch(patch_shape):
153        if len(patch_shape) == 3:
154            if patch_shape[0] != 1:
155                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
156            patch_shape = patch_shape[1:]
157        assert len(patch_shape) == 2
158        return patch_shape
159
160    patch_shape_min = _check_patch(patch_shape_min)
161    patch_shape_max = _check_patch(patch_shape_max)
162
163    if isinstance(raw_paths, str):
164        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
165        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
166        ds.patch_shape_min = patch_shape_min
167        ds.patch_shape_max = patch_shape_max
168    elif raw_key is None:
169        assert label_key is None
170        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
171        assert len(raw_paths) == len(label_paths)
172        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
173        ds.patch_shape_min = patch_shape_min
174        ds.patch_shape_max = patch_shape_max
175    else:
176        ds = []
177        n_samples = kwargs.pop("n_samples", None)
178        samples_per_ds = (
179            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
180        )
181        if roi is None:
182            roi = len(raw_paths) * [None]
183        assert len(roi) == len(raw_paths)
184        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
185            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
186            dset = RFImageCollectionDataset(
187                rpath, lpath, patch_shape=patch_shape_min, n_samples=samples_per_ds[i], **kwargs
188            )
189            dset.patch_shape_min = patch_shape_min
190            dset.patch_shape_max = patch_shape_max
191            ds.append(dset)
192        ds = torch_em.data.ConcatDataset(*ds)
193    return ds
194
195
196def _get_filters(ndim, filters_and_sigmas):
197    # subset of ilastik default features
198    if filters_and_sigmas is None:
199        filters = [filter_impl.gaussianSmoothing,
200                   filter_impl.laplacianOfGaussian,
201                   filter_impl.gaussianGradientMagnitude,
202                   filter_impl.hessianOfGaussianEigenvalues,
203                   filter_impl.structureTensorEigenvalues]
204        sigmas = [0.7, 1.6, 3.5, 5.0]
205        filters_and_sigmas = [
206            (filt, sigma) if i != len(filters) - 1 else (partial(filt, outerScale=0.5*sigma), sigma)
207            for i, filt in enumerate(filters) for sigma in sigmas
208        ]
209    # validate the filter config
210    assert isinstance(filters_and_sigmas, (list, tuple))
211    for filt_and_sig in filters_and_sigmas:
212        filt, sig = filt_and_sig
213        assert callable(filt) or (isinstance(filt, str) and hasattr(filter_impl, filt))
214        assert isinstance(sig, (float, tuple))
215        if isinstance(sig, tuple):
216            assert ndim is not None and len(sig) == ndim
217            assert all(isinstance(sigg, float) for sigg in sig)
218    return filters_and_sigmas
219
220
221def _calculate_response(raw, filter_, sigma):
222    if callable(filter_):
223        return filter_(raw, sigma)
224
225    # filter_ is still a string, convert it to a function;
226    # fastfilters does not support passing sigma as a tuple
227    func = getattr(vigra.filters, filter_) if isinstance(sigma, tuple) else getattr(filter_impl, filter_)
228
229    # special case: the additional argument outerScale
230    # is needed for the structureTensorEigenvalues function
231    if filter_ == "structureTensorEigenvalues":
232        outerScale = tuple([s*2 for s in sigma]) if isinstance(sigma, tuple) else 2*sigma
233        return func(raw, sigma, outerScale=outerScale)
234
235    return func(raw, sigma)
236
237
238def _apply_filters(raw, filters_and_sigmas):
239    features = []
240    for filter_, sigma in filters_and_sigmas:
241        response = _calculate_response(raw, filter_, sigma)
242        if response.ndim > raw.ndim:
243            for c in range(response.shape[-1]):
244                features.append(response[..., c].flatten())
245        else:
246            features.append(response.flatten())
247    features = np.concatenate([ff[:, None] for ff in features], axis=1)
248    return features
249
250
251def _apply_filters_with_mask(raw, filters_and_sigmas, mask):
252    features = []
253    for filter_, sigma in filters_and_sigmas:
254        response = _calculate_response(raw, filter_, sigma)
255        if response.ndim > raw.ndim:
256            for c in range(response.shape[-1]):
257                features.append(response[..., c][mask])
258        else:
259            features.append(response[mask])
260    features = np.concatenate([ff[:, None] for ff in features], axis=1)
261    return features
262
263
264def _balance_labels(labels, mask):
265    class_ids, label_counts = np.unique(labels[mask], return_counts=True)
266    n_classes = len(class_ids)
267    assert class_ids.tolist() == list(range(n_classes))
268
269    min_class = class_ids[np.argmin(label_counts)]
270    n_labels = label_counts[min_class]
271
272    for class_id in class_ids:
273        if class_id == min_class:
274            continue
275        n_discard = label_counts[class_id] - n_labels
276        # sample from the current class
277        # shuffle the positions and only keep up to n_labels in the mask
278        label_pos = np.where(labels == class_id)
279        discard_ids = np.arange(len(label_pos[0]))
280        np.random.shuffle(discard_ids)
281        discard_ids = discard_ids[:n_discard]
282        discard_mask = tuple(pos[discard_ids] for pos in label_pos)
283        mask[discard_mask] = False
284
285    assert mask.sum() == n_classes * n_labels
286    return mask
287
288
289def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False):
290    # find the mask for where we compute filters and labels
291    # by default we exclude everything that has label -1
292    assert labels.shape == raw.shape
293    mask = labels != -1
294    if balance_labels:
295        mask = _balance_labels(labels, mask)
296    labels = labels[mask]
297    assert labels.ndim == 1
298    features = _apply_filters_with_mask(raw, filters_and_sigmas, mask)
299    assert features.ndim == 2
300    assert len(features) == len(labels)
301    if return_mask:
302        return features, labels, mask
303    else:
304        return features, labels
305
306
307def _prepare_shallow2deep(
308    raw_paths,
309    raw_key,
310    label_paths,
311    label_key,
312    patch_shape_min,
313    patch_shape_max,
314    n_forests,
315    ndim,
316    raw_transform,
317    label_transform,
318    rois,
319    is_seg_dataset,
320    filter_config,
321    sampler,
322):
323    assert len(patch_shape_min) == len(patch_shape_max)
324    assert all(maxs >= mins for maxs, mins in zip(patch_shape_max, patch_shape_min))
325    check_paths(raw_paths, label_paths)
326
327    # get the correct dataset
328    if is_seg_dataset is None:
329        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
330    if is_seg_dataset:
331        ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
332                                           patch_shape_min, patch_shape_max,
333                                           raw_transform=raw_transform, label_transform=label_transform,
334                                           rois=rois, n_samples=n_forests, sampler=sampler)
335    else:
336        ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
337                                               patch_shape_min, patch_shape_max, roi=rois,
338                                               raw_transform=raw_transform, label_transform=label_transform,
339                                               n_samples=n_forests)
340
341    assert len(ds) == n_forests, f"{len(ds)}, {n_forests}"
342    filters_and_sigmas = _get_filters(ndim, filter_config)
343    return ds, filters_and_sigmas
344
345
346def _serialize_feature_config(filters_and_sigmas):
347    feature_config = [
348        (filt if isinstance(filt, str) else (filt.func.__name__ if isinstance(filt, partial) else filt.__name__), sigma)
349        for filt, sigma in filters_and_sigmas
350    ]
351    return feature_config
352
353
354def prepare_shallow2deep(
355    raw_paths: Union[str, Sequence[str]],
356    raw_key: Optional[str],
357    label_paths: Union[str, Sequence[str]],
358    label_key: Optional[str],
359    patch_shape_min: Tuple[int, ...],
360    patch_shape_max: Tuple[int, ...],
361    n_forests: int,
362    n_threads: int,
363    output_folder: str,
364    ndim: int,
365    raw_transform: Optional[Callable] = None,
366    label_transform: Optional[Callable] = None,
367    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
368    is_seg_dataset: Optional[bool] = None,
369    balance_labels: bool = True,
370    filter_config: Optional[Dict] = None,
371    sampler: Optional[Callable] = None,
372    **rf_kwargs,
373) -> None:
374    """Prepare shallow2deep enhancer training by pre-training random forests.
375
376    Args:
377        raw_paths: The file paths to the raw data. May also be a single file.
378        raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
379        label_paths: The file paths to the label data. May also be a single file.
380        label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
381        patch_shape_min: The minimal patch shape loaded for training a random forest.
382        patch_shape_max: The maximal patch shape loaded for training a random forest.
383        n_forests: The number of random forests to train.
384        n_threads: The number of threads for parallelizing the training.
385        output_folder: The folder for saving the random forests.
386        ndim: The dimensionality of the data.
387        raw_transform: A transform to apply to the raw data before computing features on it.
388        label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
389        rois: Regions of interest for the training data.
390        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
391            If None, this will be determined from the data.
392        balance_labels: Whether to balance the training labels for the random forest.
393        filter_config: The configuration for the image filters that are used to compute features for the random forest.
394        sampler: A sampler to reject samples from training.
395        rf_kwargs: Keyword arguments for creating the random forest.
396    """
397    os.makedirs(output_folder, exist_ok=True)
398    ds, filters_and_sigmas = _prepare_shallow2deep(
399        raw_paths, raw_key, label_paths, label_key,
400        patch_shape_min, patch_shape_max, n_forests, ndim,
401        raw_transform, label_transform, rois, is_seg_dataset,
402        filter_config, sampler,
403    )
404    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)
405
406    def _train_rf(rf_id):
407        # Sample random patch with dataset.
408        raw, labels = ds[rf_id]
409        # Cast to numpy and remove channel axis.
410        # Need to update this to support multi-channel input data and/or multi class prediction.
411        raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
412        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
413        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
414        rf = RandomForestClassifier(**rf_kwargs)
415        rf.fit(features, labels)
416        # Monkey patch these so that we know the feature config and dimensionality.
417        rf.feature_ndim = ndim
418        rf.feature_config = serialized_feature_config
419        out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
420        with open(out_path, "wb") as f:
421            pickle.dump(rf, f)
422
423    with futures.ThreadPoolExecutor(n_threads) as tp:
424        list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests))
425
426
427def _score_based_points(
428    score_function,
429    features, labels, rf_id,
430    forests, forests_per_stage,
431    sample_fraction_per_stage,
432    accumulate_samples,
433):
434    # get the corresponding random forest from the last stage
435    # and predict with it
436    last_forest = forests[rf_id - forests_per_stage]
437    pred = last_forest.predict_proba(features)
438
439    score = score_function(pred, labels)
440    assert len(score) == len(features)
441
442    # get training samples based on the label-prediction diff
443    samples = []
444    nc = len(np.unique(labels))
445    # sample in a class balanced way
446    n_samples = int(sample_fraction_per_stage * len(features))
447    n_samples_class = n_samples // nc
448    for class_id in range(nc):
449        class_indices = np.where(labels == class_id)[0]
450        this_samples = class_indices[np.argsort(score[class_indices])[::-1][:n_samples_class]]
451        samples.append(this_samples)
452    samples = np.concatenate(samples)
453
454    # get the features and labels, add from previous rf if specified
455    features, labels = features[samples], labels[samples]
456    if accumulate_samples:
457        features = np.concatenate([last_forest.train_features, features], axis=0)
458        labels = np.concatenate([last_forest.train_labels, labels], axis=0)
459
460    return features, labels
461
462
463def worst_points(
464    features, labels, rf_id,
465    forests, forests_per_stage,
466    sample_fraction_per_stage,
467    accumulate_samples=True,
468    **kwargs,
469):
470    """@private
471    """
472    def score(pred, labels):
473        # labels to one-hot encoding
474        unique, inverse = np.unique(labels, return_inverse=True)
475        onehot = np.eye(unique.shape[0])[inverse]
476        # compute the difference between labels and prediction
477        return np.abs(onehot - pred).sum(axis=1)
478
479    return _score_based_points(
480        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
481    )
482
483
484def uncertain_points(
485    features, labels, rf_id,
486    forests, forests_per_stage,
487    sample_fraction_per_stage,
488    accumulate_samples=True,
489    **kwargs,
490):
491    """@private
492    """
493    def score(pred, labels):
494        assert pred.ndim == 2
495        channel_sorted = np.sort(pred, axis=1)
496        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
497        return uncertainty
498
499    return _score_based_points(
500        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
501    )
502
503
504def uncertain_worst_points(
505    features, labels, rf_id,
506    forests, forests_per_stage,
507    sample_fraction_per_stage,
508    accumulate_samples=True,
509    alpha=0.5,
510    **kwargs,
511):
512    """@private
513    """
514    def score(pred, labels):
515        assert pred.ndim == 2
516
517        # labels to one-hot encoding
518        unique, inverse = np.unique(labels, return_inverse=True)
519        onehot = np.eye(unique.shape[0])[inverse]
520        # compute the difference between labels and prediction
521        diff = np.abs(onehot - pred).sum(axis=1)
522
523        channel_sorted = np.sort(pred, axis=1)
524        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
525        return alpha * diff + (1.0 - alpha) * uncertainty
526
527    return _score_based_points(
528        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
529    )
530
531
532def random_points(
533    features, labels, rf_id,
534    forests, forests_per_stage,
535    sample_fraction_per_stage,
536    accumulate_samples=True,
537    **kwargs,
538):
539    """@private
540    """
541    samples = []
542    nc = len(np.unique(labels))
543    # sample in a class balanced way
544    n_samples = int(sample_fraction_per_stage * len(features))
545    n_samples_class = n_samples // nc
546    for class_id in range(nc):
547        class_indices = np.where(labels == class_id)[0]
548        this_samples = np.random.choice(
549            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
550        )
551        samples.append(this_samples)
552    samples = np.concatenate(samples)
553    features, labels = features[samples], labels[samples]
554
555    if accumulate_samples and rf_id >= forests_per_stage:
556        last_forest = forests[rf_id - forests_per_stage]
557        features = np.concatenate([last_forest.train_features, features], axis=0)
558        labels = np.concatenate([last_forest.train_labels, labels], axis=0)
559
560    return features, labels
561
562
563def worst_tiles(
564    features, labels, rf_id,
565    forests, forests_per_stage,
566    sample_fraction_per_stage,
567    img_shape,
568    mask,
569    tile_shape=[25, 25],
570    smoothing_sigma=None,
571    accumulate_samples=True,
572    **kwargs,
573):
574    """@private
575    """
576    # check inputs
577    ndim = len(img_shape)
578    assert ndim in [2, 3], img_shape
579    assert len(tile_shape) == ndim, tile_shape
580
581    # get the corresponding random forest from the last stage
582    # and predict with it
583    last_forest = forests[rf_id - forests_per_stage]
584    pred = last_forest.predict_proba(features)
585
586    # labels to one-hot encoding
587    unique, inverse = np.unique(labels, return_inverse=True)
588    onehot = np.eye(unique.shape[0])[inverse]
589
590    # compute the difference between labels and prediction
591    diff = np.abs(onehot - pred)
592    assert len(diff) == len(features)
593
594    # reshape diff to image shape
595    # we also need to take the mask into account here, because if any masking
596    # was applied we can't directly reshape the diff to the image shape
597    if mask.sum() != mask.size:
598        # get the diff image
599        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
600        diff_img[mask] = diff
601        # inflate the features
602        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
603        full_features[mask.ravel()] = features
604        features = full_features
605        # inflate the labels (with -1 so this will not be sampled)
606        full_labels = np.full(mask.size, -1, dtype="int8")
607        full_labels[mask.ravel()] = labels
608        labels = full_labels
609    else:
610        diff_img = diff.reshape(img_shape + (-1,))
611
612    # get the number of classes (not counting ignore label)
613    class_ids = np.unique(labels)
614    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)
615
616    # sample in a class balanced way
617    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
618    samples = []
619    for class_id in range(nc):
620        # smooth either with gaussian or 1-kernel
621        if smoothing_sigma:
622            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
623        else:
624            kernel = np.ones(tile_shape)
625            diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode="constant")
626
627        # get training samples based on tiles around maxima of the label-prediction diff
628        # do this in a class-specific way to ensure that each class is sampled
629        # get maxima of the label-prediction diff (they seem to be sorted already)
630        max_centers = peak_local_max(
631            diff_img_smooth,
632            min_distance=max(tile_shape),
633            exclude_border=tuple([s // 2 for s in tile_shape])
634        )
635
636        # get indices of tiles around maxima
637        tiles = []
638        for center in max_centers:
639            tile_slice = tuple(
640                slice(
641                    center[d]-tile_shape[d]//2,
642                    center[d]+tile_shape[d]//2 + 1,
643                    None
644                ) for d in range(ndim)
645            )
646            grid = np.mgrid[tile_slice]
647            samples_in_tile = grid.reshape(ndim, -1)
648            samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape)
649            tiles.append(samples_in_tile)
650
651        # this (very rarely) fails due to empty tile list. Since we usually
652        # accumulate the features this doesn't hurt much and we can continue
653        try:
654            tiles = np.concatenate(tiles)
655            # take samples that belong to the current class
656            this_samples = tiles[labels[tiles] == class_id][:n_samples_class]
657            samples.append(this_samples)
658        except ValueError:
659            pass
660
661    try:
662        samples = np.concatenate(samples)
663        features, labels = features[samples], labels[samples]
664
665        # get the features and labels, add from previous rf if specified
666        if accumulate_samples:
667            features = np.concatenate([last_forest.train_features, features], axis=0)
668            labels = np.concatenate([last_forest.train_labels, labels], axis=0)
669    except ValueError:
670        features, labels = last_forest.train_features, last_forest.train_labels
671        warnings.warn(
672            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
673        )
674
675    return features, labels
676
677
678def balanced_dense_accumulate(
679    features, labels, rf_id,
680    forests, forests_per_stage,
681    sample_fraction_per_stage,
682    accumulate_samples=True,
683    **kwargs,
684):
685    """@private
686    """
687    samples = []
688    nc = len(np.unique(labels))
689    # sample in a class balanced way
690    # take all pixels from minority class
691    # and choose same amount from other classes randomly
692    n_samples_class = np.unique(labels, return_counts=True)[1].min()
693    for class_id in range(nc):
694        class_indices = np.where(labels == class_id)[0]
695        this_samples = np.random.choice(
696            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
697        )
698        samples.append(this_samples)
699    samples = np.concatenate(samples)
700    features, labels = features[samples], labels[samples]
701
702    # accumulate
703    if accumulate_samples and rf_id >= forests_per_stage:
704        last_forest = forests[rf_id - forests_per_stage]
705        features = np.concatenate([last_forest.train_features, features], axis=0)
706        labels = np.concatenate([last_forest.train_labels, labels], axis=0)
707
708    return features, labels
709
710
711SAMPLING_STRATEGIES = {
712    "random_points": random_points,
713    "uncertain_points": uncertain_points,
714    "uncertain_worst_points": uncertain_worst_points,
715    "worst_points": worst_points,
716    "worst_tiles": worst_tiles,
717    "balanced_dense_accumulate": balanced_dense_accumulate,
718}
719"""@private
720"""
721
722
723def prepare_shallow2deep_advanced(
724    raw_paths: Union[str, Sequence[str]],
725    raw_key: Optional[str],
726    label_paths: Union[str, Sequence[str]],
727    label_key: Optional[str],
728    patch_shape_min: Tuple[int, ...],
729    patch_shape_max: Tuple[int, ...],
730    n_forests: int,
731    n_threads: int,
732    output_folder: str,
733    ndim: int,
734    forests_per_stage: int,
735    sample_fraction_per_stage: float,
736    sampling_strategy: Union[str, Callable] = "worst_points",
737    sampling_kwargs: Dict = {},
738    raw_transform: Optional[Callable] = None,
739    label_transform: Optional[Callable] = None,
740    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
741    is_seg_dataset: Optional[bool] = None,
742    balance_labels: bool = True,
743    filter_config: Optional[Dict] = None,
744    sampler: Optional[Callable] = None,
745    **rf_kwargs,
746) -> None:
747    """Prepare shallow2deep enhancer training by pre-training random forests.
748
749    This function implements an advanced training procedure compared to `prepare_shallow2deep`.
750    The 'sampling_strategy' argument determines an advanced sampling strategy,
751    which selects the samples to use for training the random forests.
752    The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests
753    are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage.
754    The random forests in stage 0 are trained on randomly sampled, class-balanced labels.
755    For the other stages 'sampling_strategy' determines the strategy; it has to be a function with the signature
756    '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage)',
757    and must return the sampled features and labels. See the 'worst_points' function for an example.
758    Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
759    - "random_poinst": Select random points.
760    - "uncertain_points": Select points with the highest uncertainty.
761    - "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
762    - "worst_points": Select the points with the worst accuracies.
763    - "worst_tiles": Selectt the tiles with the worst accuracies.
764    - "balanced_dense_accumulate": Balanced dense accumulation.
765
766    Args:
767        raw_paths: The file paths to the raw data. May also be a single file.
768        raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
769        label_paths: The file paths to the label data. May also be a single file.
770        label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
771        patch_shape_min: The minimal patch shape loaded for training a random forest.
772        patch_shape_max: The maximal patch shape loaded for training a random forest.
773        n_forests: The number of random forests to train.
774        n_threads: The number of threads for parallelizing the training.
775        output_folder: The folder for saving the random forests.
776        ndim: The dimensionality of the data.
777        forests_per_stage: The number of forests to train per stage.
778        sample_fraction_per_stage: The fraction of samples to use per stage.
779        sampling_strategy: The sampling strategy.
780        sampling_kwargs: The keyword arguments for the sampling strategy.
781        raw_transform: A transform to apply to the raw data before computing features on it.
782        label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
783        rois: Regions of interest for the training data.
784        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
785            If None, this will be determined from the data.
786        balance_labels: Whether to balance the training labels for the random forest.
787        filter_config: The configuration for the image filters that are used to compute features for the random forest.
788        sampler: A sampler to reject samples from training.
789        rf_kwargs: Keyword arguments for creating the random forest.
790    """
791    os.makedirs(output_folder, exist_ok=True)
792    ds, filters_and_sigmas = _prepare_shallow2deep(
793        raw_paths, raw_key, label_paths, label_key,
794        patch_shape_min, patch_shape_max, n_forests, ndim,
795        raw_transform, label_transform, rois, is_seg_dataset,
796        filter_config, sampler,
797    )
798    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)
799
800    forests = []
801    n_stages = n_forests // forests_per_stage if n_forests % forests_per_stage == 0 else\
802        n_forests // forests_per_stage + 1
803
804    if isinstance(sampling_strategy, str):
805        assert sampling_strategy in SAMPLING_STRATEGIES, \
806            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
807        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
808    assert callable(sampling_strategy)
809
810    with tqdm(total=n_forests) as pbar:
811
812        def _train_rf(rf_id):
813            # sample random patch with dataset
814            raw, labels = ds[rf_id]
815
816            # cast to numpy and remove channel axis
817            # need to update this to support multi-channel input data and/or multi class prediction
818            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
819            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
820
821            # monkey patch original shape to sampling_kwargs
822            # deepcopy needed due to multithreading
823            current_kwargs = copy.deepcopy(sampling_kwargs)
824            current_kwargs["img_shape"] = raw.shape
825
826            # only balance samples for the first (densely trained) rfs
827            features, labels, mask = _get_features_and_labels(
828                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
829            )
830            if forests:  # we have forests: apply the sampling strategy
831                features, labels = sampling_strategy(
832                    features, labels, rf_id,
833                    forests, forests_per_stage,
834                    sample_fraction_per_stage,
835                    mask=mask,
836                    **current_kwargs,
837                )
838            else:  # sample randomly
839                features, labels = random_points(
840                    features, labels, rf_id,
841                    forests, forests_per_stage,
842                    sample_fraction_per_stage,
843                )
844
845            # fit the random forest
846            assert len(features) == len(labels)
847            rf = RandomForestClassifier(**rf_kwargs)
848            rf.fit(features, labels)
849            # monkey patch these so that we know the feature config and dimensionality
850            rf.feature_ndim = ndim
851            rf.feature_config = serialized_feature_config
852
853            # save the random forest, update pbar, return it
854            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
855            with open(out_path, "wb") as f:
856                pickle.dump(rf, f)
857
858            # monkey patch the training data and labels so we can re-use it in later stages
859            rf.train_features = features
860            rf.train_labels = labels
861
862            pbar.update(1)
863            return rf
864
865        for stage in range(n_stages):
866            pbar.set_description(f"Train RFs for stage {stage}")
867            with futures.ThreadPoolExecutor(n_threads) as tp:
868                this_forests = list(tp.map(
869                    _train_rf, range(forests_per_stage * stage, forests_per_stage * (stage + 1))
870                ))
871                forests.extend(this_forests)
def prepare_shallow2deep( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], patch_shape_min: Tuple[int, ...], patch_shape_max: Tuple[int, ...], n_forests: int, n_threads: int, output_folder: str, ndim: int, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, is_seg_dataset: Optional[bool] = None, balance_labels: bool = True, filter_config: Optional[Dict] = None, sampler: Optional[Callable] = None, **rf_kwargs) -> None:

Prepare shallow2deep enhancer training by pre-training random forests.

Arguments:
  • raw_paths: The file paths to the raw data. May also be a single file.
  • raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
  • label_paths: The file paths to the label data. May also be a single file.
  • label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
  • patch_shape_min: The minimal patch shape loaded for training a random forest.
  • patch_shape_max: The maximal patch shape loaded for training a random forest.
  • n_forests: The number of random forests to train.
  • n_threads: The number of threads for parallelizing the training.
  • output_folder: The folder for saving the random forests.
  • ndim: The dimensionality of the data.
  • raw_transform: A transform to apply to the raw data before computing features on it.
  • label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
  • rois: Regions of interest for the training data.
  • is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
  • balance_labels: Whether to balance the training labels for the random forest.
  • filter_config: The configuration for the image filters that are used to compute features for the random forest.
  • sampler: A sampler to reject samples from training.
  • rf_kwargs: Keyword arguments for creating the random forest.
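
A minimal usage sketch (not part of the library source above): the file paths, dataset keys, patch shapes and forest parameters are placeholders, and the import assumes that prepare_shallow2deep is re-exported by the torch_em.shallow2deep package; otherwise import it from this module directly. Keyword arguments not consumed by the function itself (here n_estimators and max_depth) are forwarded to sklearn's RandomForestClassifier, and each trained forest is pickled to <output_folder>/rf_XXXX.pkl.

    from torch_em.shallow2deep import prepare_shallow2deep

    prepare_shallow2deep(
        raw_paths="/path/to/train_data.h5", raw_key="raw",          # placeholder path and key
        label_paths="/path/to/train_data.h5", label_key="labels",   # placeholder path and key
        patch_shape_min=(32, 256, 256), patch_shape_max=(64, 512, 512),
        n_forests=100, n_threads=8,
        output_folder="./rfs", ndim=3,
        filter_config=None,              # None selects the default filter bank
        n_estimators=200, max_depth=10,  # forwarded to RandomForestClassifier
    )
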
def prepare_shallow2deep_advanced( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], patch_shape_min: Tuple[int, ...], patch_shape_max: Tuple[int, ...], n_forests: int, n_threads: int, output_folder: str, ndim: int, forests_per_stage: int, sample_fraction_per_stage: float, sampling_strategy: Union[str, Callable] = 'worst_points', sampling_kwargs: Dict = {}, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, is_seg_dataset: Optional[bool] = None, balance_labels: bool = True, filter_config: Optional[Dict] = None, sampler: Optional[Callable] = None, **rf_kwargs) -> None:

Prepare shallow2deep enhancer training by pre-training random forests.

This function implements an advanced training procedure compared to prepare_shallow2deep. The 'sampling_strategy' argument determines an advanced sampling strategy, which selects the samples to use for training the random forests. The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage. The random forests in stage 0 are trained on randomly sampled, class-balanced labels. For the other stages 'sampling_strategy' determines the strategy; it has to be a function with the signature '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage)', and must return the sampled features and labels. See the 'worst_points' function for an example. Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:

  • "random_poinst": Select random points.
  • "uncertain_points": Select points with the highest uncertainty.
  • "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
  • "worst_points": Select the points with the worst accuracies.
  • "worst_tiles": Selectt the tiles with the worst accuracies.
  • "balanced_dense_accumulate": Balanced dense accumulation.
Arguments:
  • raw_paths: The file paths to the raw data. May also be a single file.
  • raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
  • label_paths: The file paths to the label data. May also be a single file.
  • label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
  • patch_shape_min: The minimal patch shape loaded for training a random forest.
  • patch_shape_max: The maximal patch shape loaded for training a random forest.
  • n_forests: The number of random forests to train.
  • n_threads: The number of threads for parallelizing the training.
  • output_folder: The folder for saving the random forests.
  • ndim: The dimensionality of the data.
  • forests_per_stage: The number of forests to train per stage.
  • sample_fraction_per_stage: The fraction of samples to use per stage.
  • sampling_strategy: The sampling strategy.
  • sampling_kwargs: The keyword arguments for the sampling strategy.
  • raw_transform: A transform to apply to the raw data before computing features on it.
  • label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
  • rois: Regions of interest for the training data.
  • is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
  • balance_labels: Whether to balance the training labels for the random forest.
  • filter_config: The configuration for the image filters that are used to compute features for the random forest.
  • sampler: A sampler to reject samples from training.
  • rf_kwargs: Keyword arguments for creating the random forest.
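
A usage sketch with a custom sampling strategy; everything below is illustrative (the strategy name, paths and parameter values are placeholders, and the import again assumes that the function is re-exported by torch_em.shallow2deep). A custom strategy is called with the positional arguments listed above plus mask, img_shape and any entries of sampling_kwargs as keyword arguments, so it should accept **kwargs. Like the built-in strategies, the sketch assumes that the class labels of the current patch match the classes seen by the forest from the previous stage.

    import numpy as np
    from torch_em.shallow2deep import prepare_shallow2deep_advanced

    def hardest_points(features, labels, rf_id, forests, forests_per_stage,
                       sample_fraction_per_stage, **kwargs):
        # predict with the corresponding forest from the previous stage
        last_forest = forests[rf_id - forests_per_stage]
        pred = last_forest.predict_proba(features)
        # score each point by the probability mass assigned to wrong classes
        score = 1.0 - pred[np.arange(len(labels)), labels]
        # keep the points the previous forest got most wrong
        n_samples = int(sample_fraction_per_stage * len(features))
        samples = np.argsort(score)[::-1][:n_samples]
        return features[samples], labels[samples]

    prepare_shallow2deep_advanced(
        raw_paths="/path/to/train_data.h5", raw_key="raw",          # placeholder path and key
        label_paths="/path/to/train_data.h5", label_key="labels",   # placeholder path and key
        patch_shape_min=(32, 256, 256), patch_shape_max=(64, 512, 512),
        n_forests=150, n_threads=8, output_folder="./rfs", ndim=3,
        forests_per_stage=25, sample_fraction_per_stage=0.10,
        sampling_strategy=hardest_points,   # or one of the names listed above
        n_estimators=200,                   # forwarded to RandomForestClassifier
    )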