
  1import os
  2import copy
  3import pickle
  4import warnings
  5from concurrent import futures
  6from glob import glob
  7from functools import partial
  8from typing import Callable, Dict, Optional, Sequence, Tuple, Union
 10import numpy as np
 11import torch_em
 12from scipy.ndimage import gaussian_filter, convolve
 13from skimage.feature import peak_local_max
 14from sklearn.ensemble import RandomForestClassifier
 15from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
 16from tqdm import tqdm
 18import vigra
 20    import fastfilters as filter_impl
 21except ImportError:
 22    import vigra.filters as filter_impl
 25class RFSegmentationDataset(
 26    """@private
 27    """
 28    _patch_shape_min = None
 29    _patch_shape_max = None
 31    @property
 32    def patch_shape_min(self):
 33        return self._patch_shape_min
 35    @patch_shape_min.setter
 36    def patch_shape_min(self, value):
 37        self._patch_shape_min = value
 39    @property
 40    def patch_shape_max(self):
 41        return self._patch_shape_max
 43    @patch_shape_max.setter
 44    def patch_shape_max(self, value):
 45        self._patch_shape_max = value
 47    def _sample_bounding_box(self):
 48        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 49        sample_shape = [
 50            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 51            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 52        ]
 53        bb_start = [
 54            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
 55            for sh, psh in zip(self.shape, sample_shape)
 56        ]
 57        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, sample_shape))
 60class RFImageCollectionDataset(
 61    """@private
 62    """
 63    _patch_shape_min = None
 64    _patch_shape_max = None
 66    @property
 67    def patch_shape_min(self):
 68        return self._patch_shape_min
 70    @patch_shape_min.setter
 71    def patch_shape_min(self, value):
 72        self._patch_shape_min = value
 74    @property
 75    def patch_shape_max(self):
 76        return self._patch_shape_max
 78    @patch_shape_max.setter
 79    def patch_shape_max(self, value):
 80        self._patch_shape_max = value
 82    def _sample_bounding_box(self, shape):
 83        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
 84            raise NotImplementedError("Image padding is not supported yet.")
 85        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 86        patch_shape = [
 87            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 88            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 89        ]
 90        bb_start = [
 91            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
 92        ]
 93        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))
 96def _load_rf_segmentation_dataset(
 97    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
 99    rois = kwargs.pop("rois", None)
100    sampler = kwargs.pop("sampler", None)
101    sampler = sampler if sampler else
102    if isinstance(raw_paths, str):
103        if rois is not None:
104            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
105        ds = RFSegmentationDataset(
106            raw_paths, raw_key, label_paths, label_key, roi=rois, patch_shape=patch_shape_min, sampler=sampler, **kwargs
107        )
108        ds.patch_shape_min = patch_shape_min
109        ds.patch_shape_max = patch_shape_max
110    else:
111        assert len(raw_paths) > 0
112        if rois is not None:
113            assert len(rois) == len(label_paths)
114            assert all(isinstance(roi, tuple) for roi in rois)
115        n_samples = kwargs.pop("n_samples", None)
117        samples_per_ds = (
118            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
119        )
120        ds = []
121        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
122            roi = None if rois is None else rois[i]
123            dset = RFSegmentationDataset(
124                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i],
125                patch_shape=patch_shape_min, sampler=sampler, **kwargs
126            )
127            dset.patch_shape_min = patch_shape_min
128            dset.patch_shape_max = patch_shape_max
129            ds.append(dset)
130        ds =*ds)
131    return ds
134def _load_rf_image_collection_dataset(
135    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, roi, **kwargs
137    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
138        rpath = glob(os.path.join(rpath, rkey))
139        rpath.sort()
140        if len(rpath) == 0:
141            raise ValueError(f"Could not find any images for pattern {os.path.join(rpath, rkey)}")
142        lpath = glob(os.path.join(lpath, lkey))
143        lpath.sort()
144        if len(rpath) != len(lpath):
145            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")
147        if this_roi is not None:
148            rpath, lpath = rpath[roi], lpath[roi]
150        return rpath, lpath
152    def _check_patch(patch_shape):
153        if len(patch_shape) == 3:
154            if patch_shape[0] != 1:
155                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
156            patch_shape = patch_shape[1:]
157        assert len(patch_shape) == 2
158        return patch_shape
160    patch_shape_min = _check_patch(patch_shape_min)
161    patch_shape_max = _check_patch(patch_shape_max)
163    if isinstance(raw_paths, str):
164        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
165        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
166        ds.patch_shape_min = patch_shape_min
167        ds.patch_shape_max = patch_shape_max
168    elif raw_key is None:
169        assert label_key is None
170        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
171        assert len(raw_paths) == len(label_paths)
172        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
173        ds.patch_shape_min = patch_shape_min
174        ds.patch_shape_max = patch_shape_max
175    else:
176        ds = []
177        n_samples = kwargs.pop("n_samples", None)
178        samples_per_ds = (
179            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
180        )
181        if roi is None:
182            roi = len(raw_paths) * [None]
183        assert len(roi) == len(raw_paths)
184        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
185            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
186            dset = RFImageCollectionDataset(
187                rpath, lpath, patch_shape=patch_shape_min, n_samples=samples_per_ds[i], **kwargs
188            )
189            dset.patch_shape_min = patch_shape_min
190            dset.patch_shape_max = patch_shape_max
191            ds.append(dset)
192        ds =*ds)
193    return ds
196def _get_filters(ndim, filters_and_sigmas):
197    # subset of ilastik default features
198    if filters_and_sigmas is None:
199        filters = [filter_impl.gaussianSmoothing,
200                   filter_impl.laplacianOfGaussian,
201                   filter_impl.gaussianGradientMagnitude,
202                   filter_impl.hessianOfGaussianEigenvalues,
203                   filter_impl.structureTensorEigenvalues]
204        sigmas = [0.7, 1.6, 3.5, 5.0]
205        filters_and_sigmas = [
206            (filt, sigma) if i != len(filters) - 1 else (partial(filt, outerScale=0.5*sigma), sigma)
207            for i, filt in enumerate(filters) for sigma in sigmas
208        ]
209    # validate the filter config
210    assert isinstance(filters_and_sigmas, (list, tuple))
211    for filt_and_sig in filters_and_sigmas:
212        filt, sig = filt_and_sig
213        assert callable(filt) or (isinstance(filt, str) and hasattr(filter_impl, filt))
214        assert isinstance(sig, (float, tuple))
215        if isinstance(sig, tuple):
216            assert ndim is not None and len(sig) == ndim
217            assert all(isinstance(sigg, float) for sigg in sig)
218    return filters_and_sigmas
221def _calculate_response(raw, filter_, sigma):
222    if callable(filter_):
223        return filter_(raw, sigma)
225    # filter_ is still string, convert it to function
226    # fastfilters does not support passing sigma as tuple
227    func = getattr(vigra.filters, filter_) if isinstance(sigma, tuple) else getattr(filter_impl, filter_)
229    # special case since additional argument outerScale
230    # is needed for structureTensorEigenvalues functions
231    if filter_ == "structureTensorEigenvalues":
232        outerScale = tuple([s*2 for s in sigma]) if isinstance(sigma, tuple) else 2*sigma
233        return func(raw, sigma, outerScale=outerScale)
235    return func(raw, sigma)
def _apply_filters(raw, filters_and_sigmas):
    """Compute all filter responses for `raw` and return them as array with one column per feature."""
    columns = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim <= raw.ndim:
            columns.append(response.flatten())
            continue
        # the response has a trailing channel axis: every channel becomes its own feature
        n_channels = response.shape[-1]
        columns.extend(response[..., channel].flatten() for channel in range(n_channels))
    return np.concatenate([column[:, None] for column in columns], axis=1)
def _apply_filters_with_mask(raw, filters_and_sigmas, mask):
    """Compute all filter responses for `raw` and return the values selected by `mask`.

    The result has one row per selected position and one column per feature.
    """
    columns = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim <= raw.ndim:
            columns.append(response[mask])
            continue
        # the response has a trailing channel axis: every channel becomes its own feature
        n_channels = response.shape[-1]
        columns.extend(response[..., channel][mask] for channel in range(n_channels))
    return np.concatenate([column[:, None] for column in columns], axis=1)
264def _balance_labels(labels, mask):
265    class_ids, label_counts = np.unique(labels[mask], return_counts=True)
266    n_classes = len(class_ids)
267    assert class_ids.tolist() == list(range(n_classes))
269    min_class = class_ids[np.argmin(label_counts)]
270    n_labels = label_counts[min_class]
272    for class_id in class_ids:
273        if class_id == min_class:
274            continue
275        n_discard = label_counts[class_id] - n_labels
276        # sample from the current class
277        # shuffle the positions and only keep up to n_labels in the mask
278        label_pos = np.where(labels == class_id)
279        discard_ids = np.arange(len(label_pos[0]))
280        np.random.shuffle(discard_ids)
281        discard_ids = discard_ids[:n_discard]
282        discard_mask = tuple(pos[discard_ids] for pos in label_pos)
283        mask[discard_mask] = False
285    assert mask.sum() == n_classes * n_labels
286    return mask
def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False):
    """Compute the features and the matching flat labels for all labeled positions.

    Positions with label -1 are excluded. If `balance_labels` is set, the selection is
    additionally balanced across classes. The selection mask is returned as third value
    if `return_mask` is set.
    """
    assert labels.shape == raw.shape
    # by default everything that has label -1 is excluded
    selection = labels != -1
    if balance_labels:
        selection = _balance_labels(labels, selection)

    targets = labels[selection]
    assert targets.ndim == 1

    feats = _apply_filters_with_mask(raw, filters_and_sigmas, selection)
    assert feats.ndim == 2
    assert len(feats) == len(targets)

    return (feats, targets, selection) if return_mask else (feats, targets)
307def _prepare_shallow2deep(
308    raw_paths,
309    raw_key,
310    label_paths,
311    label_key,
312    patch_shape_min,
313    patch_shape_max,
314    n_forests,
315    ndim,
316    raw_transform,
317    label_transform,
318    rois,
319    is_seg_dataset,
320    filter_config,
321    sampler,
323    assert len(patch_shape_min) == len(patch_shape_max)
324    assert all(maxs >= mins for maxs, mins in zip(patch_shape_max, patch_shape_min))
325    check_paths(raw_paths, label_paths)
327    # get the correct dataset
328    if is_seg_dataset is None:
329        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
330    if is_seg_dataset:
331        ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
332                                           patch_shape_min, patch_shape_max,
333                                           raw_transform=raw_transform, label_transform=label_transform,
334                                           rois=rois, n_samples=n_forests, sampler=sampler)
335    else:
336        ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
337                                               patch_shape_min, patch_shape_max, roi=rois,
338                                               raw_transform=raw_transform, label_transform=label_transform,
339                                               n_samples=n_forests)
341    assert len(ds) == n_forests, f"{len(ds), {n_forests}}"
342    filters_and_sigmas = _get_filters(ndim, filter_config)
343    return ds, filters_and_sigmas
346def _serialize_feature_config(filters_and_sigmas):
347    feature_config = [
348        (filt if isinstance(filt, str) else (filt.func.__name__ if isinstance(filt, partial) else filt.__name__), sigma)
349        for filt, sigma in filters_and_sigmas
350    ]
351    return feature_config
def prepare_shallow2deep(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[
        Sequence[Tuple[Union[str, Callable], Union[float, Tuple[float, ...]]]]
    ] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    Each forest is trained on one random patch and saved as 'rf_XXXX.pkl' in the output folder.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
        rois: Region of interests for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
            A list or tuple of (filter, sigma) pairs, where the filter is a callable or the name of a filter
            function and sigma is a float or a tuple with one float per dimension.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    def _train_rf(rf_id):
        # Sample random patch with dataset.
        raw, labels = ds[rf_id]
        # Cast to numpy and remove channel axis.
        # Need to update this to support multi-channel input data and/or multi class prediction.
        raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        rf = RandomForestClassifier(**rf_kwargs)
        rf.fit(features, labels)
        # Monkey patch these so that we know the feature config and dimensionality.
        rf.feature_ndim = ndim
        rf.feature_config = serialized_feature_config
        out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(rf, f)

    with futures.ThreadPoolExecutor(n_threads) as tp:
        # Consuming the iterator via list waits for all forests and re-raises errors from the workers.
        list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests))
427def _score_based_points(
428    score_function,
429    features, labels, rf_id,
430    forests, forests_per_stage,
431    sample_fraction_per_stage,
432    accumulate_samples,
434    # get the corresponding random forest from the last stage
435    # and predict with it
436    last_forest = forests[rf_id - forests_per_stage]
437    pred = last_forest.predict_proba(features)
439    score = score_function(pred, labels)
440    assert len(score) == len(features)
442    # get training samples based on the label-prediction diff
443    samples = []
444    nc = len(np.unique(labels))
445    # sample in a class balanced way
446    n_samples = int(sample_fraction_per_stage * len(features))
447    n_samples_class = n_samples // nc
448    for class_id in range(nc):
449        class_indices = np.where(labels == class_id)[0]
450        this_samples = class_indices[np.argsort(score[class_indices])[::-1][:n_samples_class]]
451        samples.append(this_samples)
452    samples = np.concatenate(samples)
454    # get the features and labels, add from previous rf if specified
455    features, labels = features[samples], labels[samples]
456    if accumulate_samples:
457        features = np.concatenate([last_forest.train_features, features], axis=0)
458        labels = np.concatenate([last_forest.train_labels, labels], axis=0)
460    return features, labels
def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private

    Sampling strategy that selects the points where the prediction of the
    previous stage's forest differs most from the labels.
    """
    def score(pred, labels):
        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        return np.abs(onehot - pred).sum(axis=1)

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private

    Sampling strategy that selects the points for which the prediction of the
    previous stage's forest is most uncertain.

    The uncertainty is measured via the margin between the two most probable classes,
    so the prediction needs at least two classes.
    """
    def score(pred, labels):
        assert pred.ndim == 2
        channel_sorted = np.sort(pred, axis=1)
        # Margin between the most and second most probable class: a small margin is an uncertain prediction.
        margin = channel_sorted[:, -1] - channel_sorted[:, -2]
        # The samples with the HIGHEST score are selected, so the margin has to be inverted.
        # Returning the margin itself would select the most confident points instead.
        return 1.0 - margin

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """@private

    Sampling strategy that combines the prediction error and the prediction uncertainty
    of the previous stage's forest.

    The score is `alpha * error + (1 - alpha) * uncertainty`, where the uncertainty is derived from
    the margin between the two most probable classes, so the prediction needs at least two classes.
    """
    def score(pred, labels):
        assert pred.ndim == 2

        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        diff = np.abs(onehot - pred).sum(axis=1)

        channel_sorted = np.sort(pred, axis=1)
        # Margin between the most and second most probable class: a small margin is an uncertain prediction.
        margin = channel_sorted[:, -1] - channel_sorted[:, -2]
        # The samples with the HIGHEST score are selected, so the margin is inverted to reward uncertainty.
        uncertainty = 1.0 - margin
        return alpha * diff + (1.0 - alpha) * uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private

    Sampling strategy that selects random points in a class balanced way.

    The classes are expected to be 0, 1, ..., n - 1. Samples are drawn with replacement
    only if a class has fewer samples than requested per class.
    """
    nc = len(np.unique(labels))
    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // nc
    samples = []
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # Forests of the first stage have no predecessor, so nothing is accumulated for them.
    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=(25, 25),
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """@private

    Sampling strategy that selects samples from tiles around the local maxima of the
    difference between the labels and the prediction of the previous stage's forest.

    Args:
        features: The features, with one row per sample.
        labels: The labels of the samples.
        rf_id: The id of the forest that is trained with the selected samples.
        forests: The forests trained so far.
        forests_per_stage: The number of forests per stage.
        sample_fraction_per_stage: The fraction of samples to select.
        img_shape: The shape of the image the samples were extracted from. Must be a 2d or 3d tuple.
        mask: The boolean mask that was used to select the samples from the image.
        tile_shape: The shape of the tiles, one entry per image dimension.
        smoothing_sigma: If given, the difference image is smoothed with a gaussian of this sigma
            instead of being summed over the tile shape.
        accumulate_samples: Whether to prepend the training data of the previous stage's forest.

    Returns:
        The selected features and labels. If nothing could be sampled, the training data of the
        previous stage's forest is returned and a warning is issued.
    """
    # check inputs
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape

    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # compute the difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # reshape diff to image shape
    # we need to also take into account the mask here, and if we apply any masking
    # because we can't directly reshape if we have it
    if mask.sum() != mask.size:
        # get the diff image
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        # inflate the features
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # inflate the labels (with -1 so this will not be sampled)
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # get the number of classes (not counting ignore label)
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # sample in a class balanced way
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    half_tile = [ts // 2 for ts in tile_shape]
    samples = []
    for class_id in range(nc):
        # smooth either with gaussian or 1-kernel
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
        else:
            kernel = np.ones(tile_shape)
            diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode="constant")

        # get training samples based on tiles around maxima of the label-prediction diff
        # do this in a class-specific way to ensure that each class is sampled
        # get maxima of the label-prediction diff (they seem to be sorted already)
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=tuple(half_tile)
        )

        # get the flat indices of the tiles around the maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(center[d] - half_tile[d], center[d] + half_tile[d] + 1) for d in range(ndim)
            )
            grid = np.mgrid[tile_slice]
            samples_in_tile = grid.reshape(ndim, -1)
            samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape)
            tiles.append(samples_in_tile)

        # (very rarely) no maxima are found for a class. Since we usually
        # accumulate the features this doesn't hurt much and we can continue
        if not tiles:
            continue
        tiles = np.concatenate(tiles)
        # take samples that belong to the current class
        samples.append(tiles[labels[tiles] == class_id][:n_samples_class])

    if not samples:
        # nothing was sampled for any class: fall back to the training data of the previous stage
        warnings.warn(
            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
        )
        return last_forest.train_features, last_forest.train_labels

    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # get the features and labels, add from previous rf if specified
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private

    Sampling strategy that takes all samples of the smallest class and randomly
    chooses the same number of samples from every other class.

    The classes are expected to be 0, 1, ..., n - 1.
    `sample_fraction_per_stage` is not used by this strategy.
    """
    # sample in a class balanced way
    # take all pixels from minority class
    # and choose same amount from other classes randomly
    class_ids, class_counts = np.unique(labels, return_counts=True)
    nc = len(class_ids)
    n_samples_class = class_counts.min()
    samples = []
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # accumulate; forests of the first stage have no predecessor
    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
# Maps the names of the pre-defined sampling strategies to the functions implementing them.
SAMPLING_STRATEGIES = {
    "random_points": random_points,
    "uncertain_points": uncertain_points,
    "uncertain_worst_points": uncertain_worst_points,
    "worst_points": worst_points,
    "worst_tiles": worst_tiles,
    "balanced_dense_accumulate": balanced_dense_accumulate,
}
723def prepare_shallow2deep_advanced(
724    raw_paths: Union[str, Sequence[str]],
725    raw_key: Optional[str],
726    label_paths: Union[str, Sequence[str]],
727    label_key: Optional[str],
728    patch_shape_min: Tuple[int, ...],
729    patch_shape_max: Tuple[int, ...],
730    n_forests: int,
731    n_threads: int,
732    output_folder: str,
733    ndim: int,
734    forests_per_stage: int,
735    sample_fraction_per_stage: float,
736    sampling_strategy: Union[str, Callable] = "worst_points",
737    sampling_kwargs: Dict = {},
738    raw_transform: Optional[Callable] = None,
739    label_transform: Optional[Callable] = None,
740    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
741    is_seg_dataset: Optional[bool] = None,
742    balance_labels: bool = True,
743    filter_config: Optional[Dict] = None,
744    sampler: Optional[Callable] = None,
745    **rf_kwargs,
746) -> None:
747    """Prepare shallow2deep enhancer training by pre-training random forests.
749    This function implements an advanced training procedure compared to `prepare_shallow2deep`.
750    The 'sampling_strategy' argument determines an advnaced sampling strategies,
751    which selects the samples to use for training the random forests.
752    The random forest training operates in stages, the parameter 'forests_per_stage' determines how many forests
753    are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage.
754    The random forests in stage 0 are trained from random balanced labels.
755    For the other stages 'sampling_strategy' determines the strategy; it has to be a function with signature
756    '(features, labels, forests, rf_id, forests_per_stage, sample_fraction_per_stage)',
757    and return the sampled features and labels. See for example the 'worst_points' function.
758    Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
759    - "random_poinst": Select random points.
760    - "uncertain_points": Select points with the highest uncertainty.
761    - "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
762    - "worst_points": Select the points with the worst accuracies.
763    - "worst_tiles": Selectt the tiles with the worst accuracies.
764    - "balanced_dense_accumulate": Balanced dense accumulation.
766    Args:
767        raw_paths: The file paths to the raw data. May also be a single file.
768        raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
769        label_paths: The file paths to the lable data. May also be a single file.
770        label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
771        patch_shape_min: The minimal patch shape loaded for training a random forest.
772        patch_shape_max: The maximal patch shape loaded for training a random forest.
773        n_forests: The number of random forests to train.
774        n_threads: The number of threads for parallelizing the training.
775        output_folder: The folder for saving the random forests.
776        ndim: The dimensionality of the data.
777        forests_per_stage: The number of forests to train per stage.
778        sample_fraction_per_stage: The fraction of samples to use per stage.
779        sampling_strategy: The sampling strategy.
780        sampling_kwargs: The keyword arguments for the sampling strategy.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
        rois: Regions of interest for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
786        balance_labels: Whether to balance the training labels for the random forest.
787        filter_config: The configuration for the image filters that are used to compute features for the random forest.
788        sampler: A sampler to reject samples from training.
789        rf_kwargs: Keyword arguments for creating the random forest.
790    """
791    os.makedirs(output_folder, exist_ok=True)
792    ds, filters_and_sigmas = _prepare_shallow2deep(
793        raw_paths, raw_key, label_paths, label_key,
794        patch_shape_min, patch_shape_max, n_forests, ndim,
795        raw_transform, label_transform, rois, is_seg_dataset,
796        filter_config, sampler,
797    )
798    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)
800    forests = []
801    n_stages = n_forests // forests_per_stage if n_forests % forests_per_stage == 0 else\
802        n_forests // forests_per_stage + 1
804    if isinstance(sampling_strategy, str):
805        assert sampling_strategy in SAMPLING_STRATEGIES, \
806            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
807        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
808    assert callable(sampling_strategy)
810    with tqdm(total=n_forests) as pbar:
812        def _train_rf(rf_id):
813            # sample random patch with dataset
814            raw, labels = ds[rf_id]
816            # cast to numpy and remove channel axis
817            # need to update this to support multi-channel input data and/or multi class prediction
818            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
819            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
821            # monkey patch original shape to sampling_kwargs
822            # deepcopy needed due to multithreading
823            current_kwargs = copy.deepcopy(sampling_kwargs)
824            current_kwargs["img_shape"] = raw.shape
826            # only balance samples for the first (densely trained) rfs
827            features, labels, mask = _get_features_and_labels(
828                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
829            )
830            if forests:  # we have forests: apply the sampling strategy
831                features, labels = sampling_strategy(
832                    features, labels, rf_id,
833                    forests, forests_per_stage,
834                    sample_fraction_per_stage,
835                    mask=mask,
836                    **current_kwargs,
837                )
838            else:  # sample randomly
839                features, labels = random_points(
840                    features, labels, rf_id,
841                    forests, forests_per_stage,
842                    sample_fraction_per_stage,
843                )
845            # fit the random forest
846            assert len(features) == len(labels)
847            rf = RandomForestClassifier(**rf_kwargs)
848  , labels)
849            # monkey patch these so that we know the feature config and dimensionality
850            rf.feature_ndim = ndim
851            rf.feature_config = serialized_feature_config
853            # save the random forest, update pbar, return it
854            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
855            with open(out_path, "wb") as f:
856                pickle.dump(rf, f)
858            # monkey patch the training data and labels so we can re-use it in later stages
859            rf.train_features = features
860            rf.train_labels = labels
862            pbar.update(1)
863            return rf
865        for stage in range(n_stages):
866            pbar.set_description(f"Train RFs for stage {stage}")
867            with futures.ThreadPoolExecutor(n_threads) as tp:
868                this_forests = list(
869                    _train_rf, range(forests_per_stage * stage, forests_per_stage * (stage + 1))
870                ))
871                forests.extend(this_forests)
