torch_em.shallow2deep.prepare_shallow2deep

  1import os
  2import copy
  3import pickle
  4import warnings
  5from concurrent import futures
  6from glob import glob
  7from functools import partial
  8from typing import Callable, Dict, Optional, Sequence, Tuple, Union
  9
 10import numpy as np
 11import torch_em
 12from scipy.ndimage import gaussian_filter, convolve
 13from skimage.feature import peak_local_max
 14from sklearn.ensemble import RandomForestClassifier
 15from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
 16from tqdm import tqdm
 17
 18import bioimage_cpp as bic
 19
 20# Map legacy vigra/fastfilters CamelCase filter names to bioimage_cpp.filters functions.
 21_FILTER_NAMES = {
 22    "gaussianSmoothing": "gaussian_smoothing",
 23    "laplacianOfGaussian": "laplacian_of_gaussian",
 24    "gaussianGradientMagnitude": "gaussian_gradient_magnitude",
 25    "hessianOfGaussianEigenvalues": "hessian_of_gaussian_eigenvalues",
 26    "structureTensorEigenvalues": "structure_tensor_eigenvalues",
 27}
 28
 29
 30def _resolve_filter(name):
 31    return getattr(bic.filters, _FILTER_NAMES.get(name, name))
 32
 33
 34class RFSegmentationDataset(torch_em.data.SegmentationDataset):
 35    """@private
 36    """
 37    _patch_shape_min = None
 38    _patch_shape_max = None
 39
 40    @property
 41    def patch_shape_min(self):
 42        return self._patch_shape_min
 43
 44    @patch_shape_min.setter
 45    def patch_shape_min(self, value):
 46        self._patch_shape_min = value
 47
 48    @property
 49    def patch_shape_max(self):
 50        return self._patch_shape_max
 51
 52    @patch_shape_max.setter
 53    def patch_shape_max(self, value):
 54        self._patch_shape_max = value
 55
 56    def _sample_bounding_box(self):
 57        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 58        sample_shape = [
 59            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 60            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 61        ]
 62        bb_start = [
 63            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
 64            for sh, psh in zip(self.shape, sample_shape)
 65        ]
 66        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, sample_shape))
 67
 68
 69class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
 70    """@private
 71    """
 72    _patch_shape_min = None
 73    _patch_shape_max = None
 74
 75    @property
 76    def patch_shape_min(self):
 77        return self._patch_shape_min
 78
 79    @patch_shape_min.setter
 80    def patch_shape_min(self, value):
 81        self._patch_shape_min = value
 82
 83    @property
 84    def patch_shape_max(self):
 85        return self._patch_shape_max
 86
 87    @patch_shape_max.setter
 88    def patch_shape_max(self, value):
 89        self._patch_shape_max = value
 90
 91    def _sample_bounding_box(self, shape):
 92        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
 93            raise NotImplementedError("Image padding is not supported yet.")
 94        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 95        patch_shape = [
 96            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 97            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 98        ]
 99        bb_start = [
100            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
101        ]
102        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))
103
104
def _load_rf_segmentation_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
):
    """Create segmentation dataset(s) that sample patches of variable shape for random forest training.

    Args:
        raw_paths: A single raw data path or a sequence of paths.
        raw_key: The internal dataset name of the raw data.
        label_paths: A single label data path or a sequence of paths, parallel to `raw_paths`.
        label_key: The internal dataset name of the label data.
        patch_shape_min: The minimal patch shape; also passed as `patch_shape` to the dataset.
        patch_shape_max: The maximal patch shape.
        kwargs: Further dataset arguments. `rois` (one tuple of slices for a single path, or one
            tuple of slices per path), `sampler` and `n_samples` are handled here; everything else
            is passed on to `RFSegmentationDataset`.

    Returns:
        An `RFSegmentationDataset`, or a `torch_em.data.ConcatDataset` of them for multiple paths.
    """
    rois = kwargs.pop("rois", None)
    sampler = kwargs.pop("sampler", None)
    # Default sampler: torch_em's MinForegroundSampler with min_fraction=0.01.
    if sampler is None:
        sampler = torch_em.data.MinForegroundSampler(min_fraction=0.01)

    if isinstance(raw_paths, str):
        if rois is not None:
            # A single roi with one slice per spatial axis; both 2d and 3d data are supported.
            assert len(rois) in (2, 3) and all(isinstance(roi, slice) for roi in rois)
        ds = RFSegmentationDataset(
            raw_paths, raw_key, label_paths, label_key, roi=rois, patch_shape=patch_shape_min, sampler=sampler, **kwargs
        )
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
        return ds

    assert len(raw_paths) > 0
    if rois is not None:
        assert len(rois) == len(label_paths)
        assert all(isinstance(roi, tuple) for roi in rois)
    n_samples = kwargs.pop("n_samples", None)

    # Distribute the requested number of samples over the individual datasets.
    samples_per_ds = (
        [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
    )
    ds = []
    for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
        roi = None if rois is None else rois[i]
        dset = RFSegmentationDataset(
            raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i],
            patch_shape=patch_shape_min, sampler=sampler, **kwargs
        )
        dset.patch_shape_min = patch_shape_min
        dset.patch_shape_max = patch_shape_max
        ds.append(dset)
    return torch_em.data.ConcatDataset(*ds)
141
142
143def _load_rf_image_collection_dataset(
144    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, roi, **kwargs
145):
146    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
147        rpath = glob(os.path.join(rpath, rkey))
148        rpath.sort()
149        if len(rpath) == 0:
150            raise ValueError(f"Could not find any images for pattern {os.path.join(rpath, rkey)}")
151        lpath = glob(os.path.join(lpath, lkey))
152        lpath.sort()
153        if len(rpath) != len(lpath):
154            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")
155
156        if this_roi is not None:
157            rpath, lpath = rpath[roi], lpath[roi]
158
159        return rpath, lpath
160
161    def _check_patch(patch_shape):
162        if len(patch_shape) == 3:
163            if patch_shape[0] != 1:
164                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
165            patch_shape = patch_shape[1:]
166        assert len(patch_shape) == 2
167        return patch_shape
168
169    patch_shape_min = _check_patch(patch_shape_min)
170    patch_shape_max = _check_patch(patch_shape_max)
171
172    if isinstance(raw_paths, str):
173        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
174        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
175        ds.patch_shape_min = patch_shape_min
176        ds.patch_shape_max = patch_shape_max
177    elif raw_key is None:
178        assert label_key is None
179        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
180        assert len(raw_paths) == len(label_paths)
181        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
182        ds.patch_shape_min = patch_shape_min
183        ds.patch_shape_max = patch_shape_max
184    else:
185        ds = []
186        n_samples = kwargs.pop("n_samples", None)
187        samples_per_ds = (
188            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
189        )
190        if roi is None:
191            roi = len(raw_paths) * [None]
192        assert len(roi) == len(raw_paths)
193        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
194            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
195            dset = RFImageCollectionDataset(
196                rpath, lpath, patch_shape=patch_shape_min, n_samples=samples_per_ds[i], **kwargs
197            )
198            dset.patch_shape_min = patch_shape_min
199            dset.patch_shape_max = patch_shape_max
200            ds.append(dset)
201        ds = torch_em.data.ConcatDataset(*ds)
202    return ds
203
204
205def _get_filters(ndim, filters_and_sigmas):
206    # subset of ilastik default features
207    if filters_and_sigmas is None:
208        filters = [bic.filters.gaussian_smoothing,
209                   bic.filters.laplacian_of_gaussian,
210                   bic.filters.gaussian_gradient_magnitude,
211                   bic.filters.hessian_of_gaussian_eigenvalues,
212                   bic.filters.structure_tensor_eigenvalues]
213        sigmas = [0.7, 1.6, 3.5, 5.0]
214        filters_and_sigmas = [
215            (filt, sigma) if i != len(filters) - 1 else (partial(filt, outer_sigma=0.5*sigma), sigma)
216            for i, filt in enumerate(filters) for sigma in sigmas
217        ]
218    # validate the filter config
219    assert isinstance(filters_and_sigmas, (list, tuple))
220    for filt_and_sig in filters_and_sigmas:
221        filt, sig = filt_and_sig
222        assert callable(filt) or (isinstance(filt, str) and hasattr(bic.filters, _FILTER_NAMES.get(filt, filt)))
223        assert isinstance(sig, (float, tuple))
224        if isinstance(sig, tuple):
225            assert ndim is not None and len(sig) == ndim
226            assert all(isinstance(sigg, float) for sigg in sig)
227    return filters_and_sigmas
228
229
230def _calculate_response(raw, filter_, sigma):
231    if callable(filter_):
232        return filter_(raw, sigma)
233
234    # filter_ is still a string, convert it to a bioimage_cpp.filters function.
235    # bioimage_cpp.filters supports passing sigma as a scalar or as a per-axis tuple.
236    func = _resolve_filter(filter_)
237
238    # special case since the additional argument outer_sigma
239    # is needed for the structure tensor eigenvalues filter
240    if filter_ in ("structureTensorEigenvalues", "structure_tensor_eigenvalues"):
241        outer_sigma = tuple(s*2 for s in sigma) if isinstance(sigma, tuple) else 2*sigma
242        return func(raw, sigma, outer_sigma=outer_sigma)
243
244    return func(raw, sigma)
245
246
def _apply_filters(raw, filters_and_sigmas):
    """Compute the features for all pixels of `raw`.

    Returns:
        A 2d array with one row per pixel (flattened order) and one column per feature channel.
        Filters whose response has an extra trailing axis contribute one column per channel.
    """
    columns = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim > raw.ndim:
            channels = [response[..., c] for c in range(response.shape[-1])]
        else:
            channels = [response]
        columns.extend(channel.flatten()[:, None] for channel in channels)
    return np.concatenate(columns, axis=1)
258
259
def _apply_filters_with_mask(raw, filters_and_sigmas, mask):
    """Compute the features only for the pixels selected by the boolean `mask`.

    Returns:
        A 2d array with one row per selected pixel (boolean-mask indexing order) and one
        column per feature channel. Filters whose response has an extra trailing axis
        contribute one column per channel.
    """
    columns = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim <= raw.ndim:
            columns.append(response[mask][:, None])
            continue
        columns.extend(response[..., c][mask][:, None] for c in range(response.shape[-1]))
    return np.concatenate(columns, axis=1)
271
272
273def _balance_labels(labels, mask):
274    class_ids, label_counts = np.unique(labels[mask], return_counts=True)
275    n_classes = len(class_ids)
276    assert class_ids.tolist() == list(range(n_classes))
277
278    min_class = class_ids[np.argmin(label_counts)]
279    n_labels = label_counts[min_class]
280
281    for class_id in class_ids:
282        if class_id == min_class:
283            continue
284        n_discard = label_counts[class_id] - n_labels
285        # sample from the current class
286        # shuffle the positions and only keep up to n_labels in the mask
287        label_pos = np.where(labels == class_id)
288        discard_ids = np.arange(len(label_pos[0]))
289        np.random.shuffle(discard_ids)
290        discard_ids = discard_ids[:n_discard]
291        discard_mask = tuple(pos[discard_ids] for pos in label_pos)
292        mask[discard_mask] = False
293
294    assert mask.sum() == n_classes * n_labels
295    return mask
296
297
def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False):
    """Compute the features and the matching labels for all labeled pixels.

    Args:
        raw: The input image.
        labels: Label array with the same shape as `raw`; pixels labeled -1 are ignored.
        filters_and_sigmas: The (filter, sigma) pairs used to compute the features.
        balance_labels: Whether to subsample the labeled pixels to equal class counts.
        return_mask: Whether to also return the boolean mask of the selected pixels.

    Returns:
        `(features, labels)` or `(features, labels, mask)`, where features is 2d
        (pixels x feature channels) and labels is 1d.
    """
    assert labels.shape == raw.shape
    # Every pixel that does not carry the ignore label -1 is a training candidate.
    selection = labels != -1
    if balance_labels:
        selection = _balance_labels(labels, selection)

    selected_labels = labels[selection]
    assert selected_labels.ndim == 1
    features = _apply_filters_with_mask(raw, filters_and_sigmas, selection)
    assert features.ndim == 2
    assert len(features) == len(selected_labels)

    if not return_mask:
        return features, selected_labels
    return features, selected_labels, selection
314
315
def _prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    ndim,
    raw_transform,
    label_transform,
    rois,
    is_seg_dataset,
    filter_config,
    sampler,
):
    """Validate the inputs, build the patch dataset and resolve the filter configuration.

    Returns:
        The dataset (with exactly `n_forests` samples, one per random forest) and the list of
        (filter, sigma) pairs returned by `_get_filters`.
    """
    assert len(patch_shape_min) == len(patch_shape_max)
    assert all(maxs >= mins for maxs, mins in zip(patch_shape_max, patch_shape_min))
    check_paths(raw_paths, label_paths)

    # Get the correct dataset type, inferring it from the data if not specified.
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
    if is_seg_dataset:
        ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
                                           patch_shape_min, patch_shape_max,
                                           raw_transform=raw_transform, label_transform=label_transform,
                                           rois=rois, n_samples=n_forests, sampler=sampler)
    else:
        # NOTE(review): `sampler` is not forwarded here, so it has no effect for
        # image collection datasets; confirm this is intended.
        ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
                                               patch_shape_min, patch_shape_max, roi=rois,
                                               raw_transform=raw_transform, label_transform=label_transform,
                                               n_samples=n_forests)

    assert len(ds) == n_forests, f"Expected a dataset with {n_forests} samples, got {len(ds)}"
    filters_and_sigmas = _get_filters(ndim, filter_config)
    return ds, filters_and_sigmas
353
354
355def _serialize_feature_config(filters_and_sigmas):
356    feature_config = [
357        (filt if isinstance(filt, str) else (filt.func.__name__ if isinstance(filt, partial) else filt.__name__), sigma)
358        for filt, sigma in filters_and_sigmas
359    ]
360    return feature_config
361
362
def prepare_shallow2deep(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Sequence[Tuple[Union[str, Callable], Union[float, Tuple[float, ...]]]]] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    Each forest is trained on one randomly sampled patch and pickled to
    `output_folder/rf_XXXX.pkl`. The attributes `feature_ndim` and `feature_config`
    are set on each pickled forest.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular images such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular images such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest.
        rois: Regions of interest for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The image filters used to compute the random forest features, given as a list
            of (filter, sigma) pairs. `filter` is a callable or the name of a bioimage_cpp.filters function
            (legacy CamelCase names are accepted). `sigma` is a float or a tuple of `ndim` floats.
            If None, a subset of the ilastik default features is used.
        sampler: A sampler to reject samples from training. It is only used for segmentation datasets;
            if None, a `MinForegroundSampler` with `min_fraction=0.01` is used for them.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    def _train_rf(rf_id):
        # Sample random patch with dataset.
        raw, labels = ds[rf_id]
        # Cast to numpy and remove channel axis.
        # Need to update this to support multi-channel input data and/or multi class prediction.
        raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        rf = RandomForestClassifier(**rf_kwargs)
        rf.fit(features, labels)
        # Monkey patch these so that we know the feature config and dimensionality.
        rf.feature_ndim = ndim
        rf.feature_config = serialized_feature_config
        out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(rf, f)

    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests))
434
435
436def _score_based_points(
437    score_function,
438    features, labels, rf_id,
439    forests, forests_per_stage,
440    sample_fraction_per_stage,
441    accumulate_samples,
442):
443    # get the corresponding random forest from the last stage
444    # and predict with it
445    last_forest = forests[rf_id - forests_per_stage]
446    pred = last_forest.predict_proba(features)
447
448    score = score_function(pred, labels)
449    assert len(score) == len(features)
450
451    # get training samples based on the label-prediction diff
452    samples = []
453    nc = len(np.unique(labels))
454    # sample in a class balanced way
455    n_samples = int(sample_fraction_per_stage * len(features))
456    n_samples_class = n_samples // nc
457    for class_id in range(nc):
458        class_indices = np.where(labels == class_id)[0]
459        this_samples = class_indices[np.argsort(score[class_indices])[::-1][:n_samples_class]]
460        samples.append(this_samples)
461    samples = np.concatenate(samples)
462
463    # get the features and labels, add from previous rf if specified
464    features, labels = features[samples], labels[samples]
465    if accumulate_samples:
466        features = np.concatenate([last_forest.train_features, features], axis=0)
467        labels = np.concatenate([last_forest.train_labels, labels], axis=0)
468
469    return features, labels
470
471
def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    # Rank samples by the L1 distance between the previous-stage prediction and the
    # one-hot encoded label, i.e. prefer the worst predicted samples.
    def _prediction_error(pred, labels):
        classes, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(len(classes))[inverse]
        return np.abs(onehot - pred).sum(axis=1)

    return _score_based_points(
        _prediction_error,
        features, labels, rf_id,
        forests, forests_per_stage,
        sample_fraction_per_stage,
        accumulate_samples,
    )
491
492
def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    # Select the samples for which the previous-stage forest is most uncertain.
    def score(pred, labels):
        assert pred.ndim == 2
        channel_sorted = np.sort(pred, axis=1)
        # Margin between the two most probable classes: large = confident, small = uncertain.
        margin = channel_sorted[:, -1] - channel_sorted[:, -2]
        # _score_based_points keeps the highest scores, so return 1 - margin.
        # Returning the margin itself would select the most *certain* samples.
        uncertainty = 1.0 - margin
        return uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
511
512
def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """@private
    """
    # Select samples by a weighted combination of prediction error (weight alpha)
    # and prediction uncertainty (weight 1 - alpha) of the previous-stage forest.
    def score(pred, labels):
        assert pred.ndim == 2

        # Labels to one-hot encoding.
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # L1 difference between labels and prediction: large for wrong predictions.
        diff = np.abs(onehot - pred).sum(axis=1)

        # Uncertainty as 1 - margin between the two most probable classes:
        # large when the forest cannot decide between them.
        channel_sorted = np.sort(pred, axis=1)
        uncertainty = 1.0 - (channel_sorted[:, -1] - channel_sorted[:, -2])
        return alpha * diff + (1.0 - alpha) * uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
539
540
def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    # Draw the same number of random samples for every class (classes 0 ... n_classes - 1);
    # a class is sampled with replacement only if it has too few samples.
    n_classes = np.unique(labels).size
    per_class = int(sample_fraction_per_stage * len(features)) // n_classes
    chosen = np.concatenate([
        np.random.choice(candidates, size=per_class, replace=len(candidates) < per_class)
        for candidates in (np.where(labels == class_id)[0] for class_id in range(n_classes))
    ])
    features = features[chosen]
    labels = labels[chosen]

    # Prepend the training data of the previous-stage forest, if there is one.
    if not accumulate_samples or rf_id < forests_per_stage:
        return features, labels
    previous_forest = forests[rf_id - forests_per_stage]
    return (
        np.concatenate([previous_forest.train_features, features], axis=0),
        np.concatenate([previous_forest.train_labels, labels], axis=0),
    )
570
571
def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=(25, 25),
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    # Select training samples from tiles centred on local maxima of the (smoothed) per-class
    # difference between the labels and the prediction of the previous-stage forest.
    #
    # img_shape: spatial shape (tuple) of the image the features were computed for (2d or 3d).
    # mask: boolean array of shape img_shape marking the pixels that `features` correspond to.
    # tile_shape: shape of the tiles around each maximum; must have len(img_shape) entries.
    # smoothing_sigma: if set, smooth the difference with a gaussian instead of a box filter.

    # Check inputs.
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape

    # Get the corresponding random forest from the last stage and predict with it.
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # Labels to one-hot encoding.
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # Compute the difference between labels and prediction.
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # Reshape diff to image shape. If only part of the image is covered by the mask
    # we cannot reshape directly and inflate features / labels to the full image instead.
    if mask.sum() != mask.size:
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # Inflate the labels with -1, so that these pixels will not be sampled.
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # Get the number of classes (not counting the ignore label).
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # Sample in a class balanced way.
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    # These only depend on tile_shape, so compute them once for all classes.
    box_kernel = None if smoothing_sigma else np.ones(tile_shape)
    border = tuple(s // 2 for s in tile_shape)
    samples = []
    for class_id in range(nc):
        # Smooth either with a gaussian or with a box kernel of the tile size.
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
        else:
            diff_img_smooth = convolve(diff_img[..., class_id], box_kernel, mode="constant")

        # Get maxima of the label-prediction diff; excluding the border keeps tiles inside the image.
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=border,
        )

        # Get the flat pixel indices of the tiles around the maxima.
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(center[d] - tile_shape[d] // 2, center[d] + tile_shape[d] // 2 + 1)
                for d in range(ndim)
            )
            grid = np.mgrid[tile_slice]
            tiles.append(np.ravel_multi_index(grid.reshape(ndim, -1), img_shape))

        # (Very rarely) no maxima are found; since we usually accumulate
        # the features this doesn't hurt much and we can continue.
        if not tiles:
            continue
        tiles = np.concatenate(tiles)
        # Take the samples that belong to the current class.
        samples.append(tiles[labels[tiles] == class_id][:n_samples_class])

    if not samples:
        warnings.warn(
            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
        )
        return last_forest.train_features, last_forest.train_labels

    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # Add the features and labels from the previous rf if specified.
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
685
686
def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    # Keep every sample of the minority class and draw the same number of random
    # samples from each other class (classes 0 ... n_classes - 1).
    class_values, class_counts = np.unique(labels, return_counts=True)
    n_classes = len(class_values)
    per_class = class_counts.min()

    chosen = []
    for class_id in range(n_classes):
        candidates = np.where(labels == class_id)[0]
        chosen.append(np.random.choice(candidates, size=per_class, replace=len(candidates) < per_class))
    chosen = np.concatenate(chosen)
    features = features[chosen]
    labels = labels[chosen]

    # Accumulate the training data of the previous-stage forest, if there is one.
    if not accumulate_samples or rf_id < forests_per_stage:
        return features, labels
    previous_forest = forests[rf_id - forests_per_stage]
    return (
        np.concatenate([previous_forest.train_features, features], axis=0),
        np.concatenate([previous_forest.train_labels, labels], axis=0),
    )
718
719
# Registry of the pre-defined sampling strategies, selectable by name.
SAMPLING_STRATEGIES = dict(
    random_points=random_points,
    uncertain_points=uncertain_points,
    uncertain_worst_points=uncertain_worst_points,
    worst_points=worst_points,
    worst_tiles=worst_tiles,
    balanced_dense_accumulate=balanced_dense_accumulate,
)
"""@private
"""
730
731
def prepare_shallow2deep_advanced(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    forests_per_stage: int,
    sample_fraction_per_stage: float,
    sampling_strategy: Union[str, Callable] = "worst_points",
    sampling_kwargs: Optional[Dict] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    This function implements an advanced training procedure compared to `prepare_shallow2deep`.
    The 'sampling_strategy' argument determines an advanced sampling strategy,
    which selects the samples used for training the random forests.
    The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage.
    If 'n_forests' is not a multiple of 'forests_per_stage', the last stage trains fewer forests,
    so that exactly 'n_forests' forests are trained in total.
    The random forests in stage 0 are trained on points selected by the 'random_points' strategy.
    For the other stages 'sampling_strategy' determines the strategy; it has to be a function with signature
    '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)',
    and return the sampled features and labels. The keyword arguments it receives are 'mask',
    'img_shape' and the entries of 'sampling_kwargs'. See for example the 'worst_points' function.
    Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
    - "random_points": Select random points.
    - "uncertain_points": Select points with the highest uncertainty.
    - "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
    - "worst_points": Select the points with the worst accuracies.
    - "worst_tiles": Select the tiles with the worst accuracies.
    - "balanced_dense_accumulate": Balanced dense accumulation.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        forests_per_stage: The number of forests to train per stage. Must be positive.
        sample_fraction_per_stage: The fraction of samples to use per stage.
        sampling_strategy: The sampling strategy, either a callable or the name of a pre-defined strategy.
        sampling_kwargs: Extra keyword arguments for the sampling strategy. Each forest gets its own deep copy.
            If None, no extra keyword arguments are passed (besides 'mask' and 'img_shape').
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
        rois: Region of interests for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Not used by this function. Features and labels are always extracted without
            balancing. The argument is kept for interface compatibility with `prepare_shallow2deep`.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.

    Raises:
        ValueError: If 'forests_per_stage' is not positive.
    """
    # Validate the cheap arguments before building the (potentially expensive) dataset.
    if forests_per_stage <= 0:
        raise ValueError(f"forests_per_stage must be positive, got {forests_per_stage}")
    # Avoid a shared mutable default argument.
    if sampling_kwargs is None:
        sampling_kwargs = {}

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES, \
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    # Forests from all completed stages. They are only extended between stages,
    # so the worker threads of one stage all see the same list.
    forests = []
    # Ceiling division: the last stage may be partial.
    n_stages = (n_forests + forests_per_stage - 1) // forests_per_stage

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # Sample a random patch with the dataset.
            raw, labels = ds[rf_id]

            # Cast to numpy and remove the channel axis.
            # Need to update this to support multi-channel input data and/or multi class prediction.
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # Add the patch shape to the sampling kwargs.
            # Deepcopy so that concurrently running threads do not share (and mutate) one dict.
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # Features are extracted without label balancing; the sampling below selects the training points.
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            if forests:  # Forests from earlier stages exist: apply the sampling strategy.
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # First stage: sample randomly.
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # Fit the random forest.
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # Monkey patch these so that we know the feature config and dimensionality.
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # Save the random forest.
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # Attach the training data only after pickling, so it is not written to disk,
            # but can be re-used by the sampling strategies in later stages.
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            # Clamp the last stage so that exactly n_forests forests are trained.
            first_id = forests_per_stage * stage
            stop_id = min(forests_per_stage * (stage + 1), n_forests)
            with futures.ThreadPoolExecutor(n_threads) as tp:
                this_forests = list(tp.map(_train_rf, range(first_id, stop_id)))
            forests.extend(this_forests)
def prepare_shallow2deep( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], patch_shape_min: Tuple[int, ...], patch_shape_max: Tuple[int, ...], n_forests: int, n_threads: int, output_folder: str, ndim: int, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, is_seg_dataset: Optional[bool] = None, balance_labels: bool = True, filter_config: Optional[Dict] = None, sampler: Optional[Callable] = None, **rf_kwargs) -> None:
def prepare_shallow2deep(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    For each of the 'n_forests' forests, a patch is drawn from the dataset, features are computed
    for it, and a random forest is fitted and pickled to 'output_folder' as 'rf_XXXX.pkl'.
    The forests are trained in parallel with a thread pool.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
        rois: Region of interests for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    os.makedirs(output_folder, exist_ok=True)
    dataset, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    feature_config = _serialize_feature_config(filters_and_sigmas)

    def _fit_forest(forest_index):
        # Draw a random patch; convert to numpy and drop the channel axis.
        # Multi-channel inputs and multi-class predictions are not supported yet.
        raw_tensor, label_tensor = dataset[forest_index]
        raw = raw_tensor.numpy().squeeze()
        labels = label_tensor.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

        features, targets = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        forest = RandomForestClassifier(**rf_kwargs)
        forest.fit(features, targets)
        # Attach the feature dimensionality and configuration, so the pickled forest is self-describing.
        forest.feature_ndim = ndim
        forest.feature_config = feature_config

        save_path = os.path.join(output_folder, f"rf_{forest_index:04d}.pkl")
        with open(save_path, "wb") as handle:
            pickle.dump(forest, handle)

    with futures.ThreadPoolExecutor(n_threads) as pool:
        # Consume the lazy map so that all forests are trained and worker exceptions propagate.
        progress = tqdm(pool.map(_fit_forest, range(n_forests)), desc="Train RFs", total=n_forests)
        for _ in progress:
            pass

Prepare shallow2deep enhancer training by pre-training random forests.

Arguments:
  • raw_paths: The file paths to the raw data. May also be a single file.
  • raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
  • label_paths: The file paths to the label data. May also be a single file.
  • label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
  • patch_shape_min: The minimal patch shape loaded for training a random forest.
  • patch_shape_max: The maximal patch shape loaded for training a random forest.
  • n_forests: The number of random forests to train.
  • n_threads: The number of threads for parallelizing the training.
  • output_folder: The folder for saving the random forests.
  • ndim: The dimensionality of the data.
  • raw_transform: A transform to apply to the raw data before computing features on it.
  • label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
  • rois: Region of interests for the training data.
  • is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
  • balance_labels: Whether to balance the training labels for the random forest.
  • filter_config: The configuration for the image filters that are used to compute features for the random forest.
  • sampler: A sampler to reject samples from training.
  • rf_kwargs: Keyword arguments for creating the random forest.
def prepare_shallow2deep_advanced( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], label_paths: Union[str, Sequence[str]], label_key: Optional[str], patch_shape_min: Tuple[int, ...], patch_shape_max: Tuple[int, ...], n_forests: int, n_threads: int, output_folder: str, ndim: int, forests_per_stage: int, sample_fraction_per_stage: float, sampling_strategy: Union[str, Callable] = 'worst_points', sampling_kwargs: Dict = {}, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, is_seg_dataset: Optional[bool] = None, balance_labels: bool = True, filter_config: Optional[Dict] = None, sampler: Optional[Callable] = None, **rf_kwargs) -> None:
def prepare_shallow2deep_advanced(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    forests_per_stage: int,
    sample_fraction_per_stage: float,
    sampling_strategy: Union[str, Callable] = "worst_points",
    sampling_kwargs: Optional[Dict] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    This function implements an advanced training procedure compared to `prepare_shallow2deep`.
    The 'sampling_strategy' argument determines an advanced sampling strategy,
    which selects the samples used for training the random forests.
    The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage.
    If 'n_forests' is not a multiple of 'forests_per_stage', the last stage trains fewer forests,
    so that exactly 'n_forests' forests are trained in total.
    The random forests in stage 0 are trained on points selected by the 'random_points' strategy.
    For the other stages 'sampling_strategy' determines the strategy; it has to be a function with signature
    '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)',
    and return the sampled features and labels. The keyword arguments it receives are 'mask',
    'img_shape' and the entries of 'sampling_kwargs'. See for example the 'worst_points' function.
    Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
    - "random_points": Select random points.
    - "uncertain_points": Select points with the highest uncertainty.
    - "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
    - "worst_points": Select the points with the worst accuracies.
    - "worst_tiles": Select the tiles with the worst accuracies.
    - "balanced_dense_accumulate": Balanced dense accumulation.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        forests_per_stage: The number of forests to train per stage. Must be positive.
        sample_fraction_per_stage: The fraction of samples to use per stage.
        sampling_strategy: The sampling strategy, either a callable or the name of a pre-defined strategy.
        sampling_kwargs: Extra keyword arguments for the sampling strategy. Each forest gets its own deep copy.
            If None, no extra keyword arguments are passed (besides 'mask' and 'img_shape').
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
        rois: Region of interests for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Not used by this function. Features and labels are always extracted without
            balancing. The argument is kept for interface compatibility with `prepare_shallow2deep`.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.

    Raises:
        ValueError: If 'forests_per_stage' is not positive.
    """
    # Validate the cheap arguments before building the (potentially expensive) dataset.
    if forests_per_stage <= 0:
        raise ValueError(f"forests_per_stage must be positive, got {forests_per_stage}")
    # Avoid a shared mutable default argument.
    if sampling_kwargs is None:
        sampling_kwargs = {}

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES, \
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    # Forests from all completed stages. They are only extended between stages,
    # so the worker threads of one stage all see the same list.
    forests = []
    # Ceiling division: the last stage may be partial.
    n_stages = (n_forests + forests_per_stage - 1) // forests_per_stage

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # Sample a random patch with the dataset.
            raw, labels = ds[rf_id]

            # Cast to numpy and remove the channel axis.
            # Need to update this to support multi-channel input data and/or multi class prediction.
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # Add the patch shape to the sampling kwargs.
            # Deepcopy so that concurrently running threads do not share (and mutate) one dict.
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # Features are extracted without label balancing; the sampling below selects the training points.
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            if forests:  # Forests from earlier stages exist: apply the sampling strategy.
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # First stage: sample randomly.
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # Fit the random forest.
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # Monkey patch these so that we know the feature config and dimensionality.
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # Save the random forest.
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # Attach the training data only after pickling, so it is not written to disk,
            # but can be re-used by the sampling strategies in later stages.
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            # Clamp the last stage so that exactly n_forests forests are trained.
            first_id = forests_per_stage * stage
            stop_id = min(forests_per_stage * (stage + 1), n_forests)
            with futures.ThreadPoolExecutor(n_threads) as tp:
                this_forests = list(tp.map(_train_rf, range(first_id, stop_id)))
            forests.extend(this_forests)

Prepare shallow2deep enhancer training by pre-training random forests.

This function implements an advanced training procedure compared to prepare_shallow2deep. The 'sampling_strategy' argument determines an advanced sampling strategy, which selects the samples used for training the random forests. The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage. The random forests in stage 0 are trained on points selected by the 'random_points' strategy. For the other stages 'sampling_strategy' determines the strategy; it has to be a function with signature '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)', and return the sampled features and labels. The keyword arguments it receives are 'mask', 'img_shape' and the entries of 'sampling_kwargs'. See for example the 'worst_points' function. Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:

  • "random_points": Select random points.
  • "uncertain_points": Select points with the highest uncertainty.
  • "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
  • "worst_points": Select the points with the worst accuracies.
  • "worst_tiles": Select the tiles with the worst accuracies.
  • "balanced_dense_accumulate": Balanced dense accumulation.
Arguments:
  • raw_paths: The file paths to the raw data. May also be a single file.
  • raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
  • label_paths: The file paths to the label data. May also be a single file.
  • label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
  • patch_shape_min: The minimal patch shape loaded for training a random forest.
  • patch_shape_max: The maximal patch shape loaded for training a random forest.
  • n_forests: The number of random forests to train.
  • n_threads: The number of threads for parallelizing the training.
  • output_folder: The folder for saving the random forests.
  • ndim: The dimensionality of the data.
  • forests_per_stage: The number of forests to train per stage.
  • sample_fraction_per_stage: The fraction of samples to use per stage.
  • sampling_strategy: The sampling strategy.
  • sampling_kwargs: The keyword arguments for the sampling strategy.
  • raw_transform: A transform to apply to the raw data before computing features on it.
  • label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
  • rois: Region of interests for the training data.
  • is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
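  • balance_labels: Whether to balance the training labels for the random forest. Note: this function always extracts features and labels without balancing, so this argument currently has no effect here.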
  • filter_config: The configuration for the image filters that are used to compute features for the random forest.
  • sampler: A sampler to reject samples from training.
  • rf_kwargs: Keyword arguments for creating the random forest.