torch_em.shallow2deep.prepare_shallow2deep

  1import os
  2import copy
  3import pickle
  4import warnings
  5from concurrent import futures
  6from glob import glob
  7from functools import partial
  8
  9import numpy as np
 10import torch_em
 11from scipy.ndimage import gaussian_filter, convolve
 12from skimage.feature import peak_local_max
 13from sklearn.ensemble import RandomForestClassifier
 14from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
 15from tqdm import tqdm
 16
 17import vigra
 18try:
 19    import fastfilters as filter_impl
 20except ImportError:
 21    import vigra.filters as filter_impl
 22
 23
 24class RFSegmentationDataset(torch_em.data.SegmentationDataset):
 25    _patch_shape_min = None
 26    _patch_shape_max = None
 27
 28    @property
 29    def patch_shape_min(self):
 30        return self._patch_shape_min
 31
 32    @patch_shape_min.setter
 33    def patch_shape_min(self, value):
 34        self._patch_shape_min = value
 35
 36    @property
 37    def patch_shape_max(self):
 38        return self._patch_shape_max
 39
 40    @patch_shape_max.setter
 41    def patch_shape_max(self, value):
 42        self._patch_shape_max = value
 43
 44    def _sample_bounding_box(self):
 45        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 46        sample_shape = [
 47            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 48            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 49        ]
 50        bb_start = [
 51            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
 52            for sh, psh in zip(self.shape, sample_shape)
 53        ]
 54        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, sample_shape))
 55
 56
 57class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
 58    _patch_shape_min = None
 59    _patch_shape_max = None
 60
 61    @property
 62    def patch_shape_min(self):
 63        return self._patch_shape_min
 64
 65    @patch_shape_min.setter
 66    def patch_shape_min(self, value):
 67        self._patch_shape_min = value
 68
 69    @property
 70    def patch_shape_max(self):
 71        return self._patch_shape_max
 72
 73    @patch_shape_max.setter
 74    def patch_shape_max(self, value):
 75        self._patch_shape_max = value
 76
 77    def _sample_bounding_box(self, shape):
 78        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
 79            raise NotImplementedError("Image padding is not supported yet.")
 80        assert self._patch_shape_min is not None and self._patch_shape_max is not None
 81        patch_shape = [
 82            pmin if pmin == pmax else np.random.randint(pmin, pmax)
 83            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
 84        ]
 85        bb_start = [
 86            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
 87        ]
 88        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))
 89
 90
 91def _load_rf_segmentation_dataset(
 92    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
 93):
 94    rois = kwargs.pop("rois", None)
 95    sampler = kwargs.pop("sampler", None)
 96    sampler = sampler if sampler else torch_em.data.MinForegroundSampler(min_fraction=0.01)
 97    if isinstance(raw_paths, str):
 98        if rois is not None:
 99            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
100        ds = RFSegmentationDataset(
101            raw_paths, raw_key, label_paths, label_key, roi=rois, patch_shape=patch_shape_min, sampler=sampler, **kwargs
102        )
103        ds.patch_shape_min = patch_shape_min
104        ds.patch_shape_max = patch_shape_max
105    else:
106        assert len(raw_paths) > 0
107        if rois is not None:
108            assert len(rois) == len(label_paths)
109            assert all(isinstance(roi, tuple) for roi in rois)
110        n_samples = kwargs.pop("n_samples", None)
111
112        samples_per_ds = (
113            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
114        )
115        ds = []
116        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
117            roi = None if rois is None else rois[i]
118            dset = RFSegmentationDataset(
119                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i],
120                patch_shape=patch_shape_min, sampler=sampler, **kwargs
121            )
122            dset.patch_shape_min = patch_shape_min
123            dset.patch_shape_max = patch_shape_max
124            ds.append(dset)
125        ds = torch_em.data.ConcatDataset(*ds)
126    return ds
127
128
def _load_rf_image_collection_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, roi, **kwargs
):
    """Create an RFImageCollectionDataset, or a concatenation of them, for image-file data.

    Three input layouts are supported:
    - ``raw_paths`` is a str (a folder): the image files are found with the glob patterns
      ``raw_key`` / ``label_key`` inside ``raw_paths`` / ``label_paths``.
    - ``raw_key`` is None: ``raw_paths`` / ``label_paths`` are explicit lists of image files.
    - otherwise: ``raw_paths`` / ``label_paths`` are lists of folders, each globbed with the keys.

    ``roi`` indexes the sorted file list of a folder (e.g. a slice). In the last layout it is a
    list with one roi (or None) per folder. Patch shapes must be 2d, or 3d with a singleton
    first axis.

    Raises:
        ValueError: if no raw images are found for a pattern, if raw and label file counts
            differ, or if a patch shape is 3d with a non-singleton first axis.
    """
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        raw_pattern = os.path.join(rpath, rkey)
        # sort both lists so that raw and label files are matched by position
        raw_files = sorted(glob(raw_pattern))
        if len(raw_files) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")
        label_files = sorted(glob(os.path.join(lpath, lkey)))
        if len(raw_files) != len(label_files):
            raise ValueError(f"Expect same number of raw and label images, got {len(raw_files)}, {len(label_files)}")

        # apply the roi that belongs to this folder
        if this_roi is not None:
            raw_files, label_files = raw_files[this_roi], label_files[this_roi]

        return raw_files, label_files

    def _check_patch(patch_shape):
        # accept a 3d patch shape only if it has a singleton first axis, and drop that axis
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2
        return patch_shape

    patch_shape_min = _check_patch(patch_shape_min)
    patch_shape_max = _check_patch(patch_shape_max)

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = RFImageCollectionDataset(
                rpath, lpath, patch_shape=patch_shape_min, n_samples=samples_per_ds[i], **kwargs
            )
            dset.patch_shape_min = patch_shape_min
            dset.patch_shape_max = patch_shape_max
            ds.append(dset)
        ds = torch_em.data.ConcatDataset(*ds)
    return ds
189
190
def _get_filters(ndim, filters_and_sigmas):
    """Return a validated list of (filter, sigma) pairs used for feature computation.

    Args:
        ndim: number of spatial dimensions; needed when a sigma is given per axis (tuple).
        filters_and_sigmas: list/tuple of (filter, sigma) pairs. ``filter`` is a callable or the
            name of a function in the filter backend (fastfilters, or vigra.filters as fallback).
            ``sigma`` is a float, or a tuple of floats with one entry per axis.
            If None, a subset of the ilastik default features is used.

    Returns:
        ``filters_and_sigmas`` itself, or the default list if it was None.

    Raises:
        AssertionError: if the configuration is invalid.
    """
    if filters_and_sigmas is None:
        # subset of the ilastik default features; every filter is evaluated at every sigma
        default_filters = [
            filter_impl.gaussianSmoothing,
            filter_impl.laplacianOfGaussian,
            filter_impl.gaussianGradientMagnitude,
            filter_impl.hessianOfGaussianEigenvalues,
            filter_impl.structureTensorEigenvalues,
        ]
        default_sigmas = [0.7, 1.6, 3.5, 5.0]
        structure_tensor_index = len(default_filters) - 1
        filters_and_sigmas = []
        for index, filter_function in enumerate(default_filters):
            for sigma in default_sigmas:
                if index == structure_tensor_index:
                    # the structure tensor additionally needs an outer scale, here half of sigma.
                    # NOTE(review): the name-based path in _calculate_response uses outerScale=2*sigma
                    # for "structureTensorEigenvalues"; confirm which of the two is intended.
                    filters_and_sigmas.append((partial(filter_function, outerScale=0.5 * sigma), sigma))
                else:
                    filters_and_sigmas.append((filter_function, sigma))

    # validate the filter config
    assert isinstance(filters_and_sigmas, (list, tuple))
    for filter_, sigma in filters_and_sigmas:
        assert callable(filter_) or (isinstance(filter_, str) and hasattr(filter_impl, filter_))
        assert isinstance(sigma, (float, tuple))
        if isinstance(sigma, tuple):
            # anisotropic sigma: exactly one float per axis
            assert ndim is not None and len(sigma) == ndim
            assert all(isinstance(axis_sigma, float) for axis_sigma in sigma)
    return filters_and_sigmas
214
215
def _calculate_response(raw, filter_, sigma):
    """Apply a single filter to ``raw`` at scale ``sigma`` and return its response.

    ``filter_`` is either a callable, invoked as ``filter_(raw, sigma)``, or the name of a filter
    function. Names are resolved in ``vigra.filters`` when ``sigma`` is a tuple (fastfilters does
    not support per-axis sigmas), and in the default filter backend otherwise.
    """
    if not callable(filter_):
        backend = vigra.filters if isinstance(sigma, tuple) else filter_impl
        func = getattr(backend, filter_)

        if filter_ == "structureTensorEigenvalues":
            # the structure tensor needs an additional outer scale: twice the (per-axis) sigma
            if isinstance(sigma, tuple):
                outer_scale = tuple(s * 2 for s in sigma)
            else:
                outer_scale = 2 * sigma
            return func(raw, sigma, outerScale=outer_scale)

        return func(raw, sigma)

    return filter_(raw, sigma)
231
232
def _apply_filters(raw, filters_and_sigmas):
    """Compute dense features for all pixels of ``raw``.

    Each (filter, sigma) pair contributes one feature column. A response with more dimensions
    than ``raw`` is treated as multi-channel (channels in the last axis) and contributes one
    column per channel. Columns are the flattened responses, stacked along axis 1.
    """
    columns = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim <= raw.ndim:
            columns.append(response.flatten())
            continue
        for channel in range(response.shape[-1]):
            columns.append(response[..., channel].flatten())
    return np.stack(columns, axis=1)
244
245
def _apply_filters_with_mask(raw, filters_and_sigmas, mask):
    """Compute features only at the pixels selected by ``mask``.

    Same as ``_apply_filters``, except that each response is indexed with ``mask`` instead of
    being flattened. Assuming ``mask`` is a boolean array of ``raw``'s shape, the result has one
    row per selected pixel and one column per (channel of each) filter response.
    """
    columns = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim <= raw.ndim:
            columns.append(response[mask])
            continue
        # extra trailing axis: channels of a multi-channel filter, each becomes its own feature
        for channel in range(response.shape[-1]):
            columns.append(response[..., channel][mask])
    return np.stack(columns, axis=1)
257
258
def _balance_labels(labels, mask):
    """Subsample ``mask`` so that every class keeps as many pixels as the rarest class.

    Pixels of the non-rarest classes are dropped at random. ``mask`` is modified in place
    and also returned.

    Raises:
        AssertionError: if the labels inside the mask are not exactly 0, ..., n_classes - 1.
    """
    ids, counts = np.unique(labels[mask], return_counts=True)
    num_classes = len(ids)
    # class ids must form the contiguous range 0, ..., num_classes - 1
    assert ids.tolist() == list(range(num_classes))

    rarest = ids[np.argmin(counts)]
    keep_per_class = counts[rarest]

    for cid in ids:
        if cid == rarest:
            continue
        to_drop = counts[cid] - keep_per_class
        # shuffle the positions of this class and unset the first `to_drop` of them in the mask
        positions = np.where(labels == cid)
        order = np.arange(len(positions[0]))
        np.random.shuffle(order)
        dropped = order[:to_drop]
        mask[tuple(axis_positions[dropped] for axis_positions in positions)] = False

    assert mask.sum() == num_classes * keep_per_class
    return mask
282
283
def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False):
    """Compute filter features and the matching labels for all labeled pixels.

    Pixels with label -1 are excluded. If ``balance_labels`` is set, the remaining pixels are
    subsampled so that each class is equally frequent.

    Returns:
        ``(features, labels)`` with features of shape (n, n_features) and 1d labels of length n,
        plus the boolean pixel mask as third element if ``return_mask`` is True.
    """
    assert labels.shape == raw.shape
    valid = labels != -1
    if balance_labels:
        valid = _balance_labels(labels, valid)
    selected_labels = labels[valid]
    assert selected_labels.ndim == 1
    features = _apply_filters_with_mask(raw, filters_and_sigmas, valid)
    assert features.ndim == 2
    assert len(features) == len(selected_labels)
    if not return_mask:
        return features, selected_labels
    return features, selected_labels, valid
300
301
def _prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    ndim,
    raw_transform,
    label_transform,
    rois,
    is_seg_dataset,
    filter_config,
    sampler,
):
    """Build the patch dataset and filter configuration shared by the prepare functions.

    If ``is_seg_dataset`` is None, the dataset type is inferred with ``is_segmentation_dataset``.
    Note that ``sampler`` is only forwarded to the segmentation dataset.

    Returns:
        ds: the dataset, which must contain exactly ``n_forests`` samples.
        filters_and_sigmas: the validated (filter, sigma) pairs from ``_get_filters``.
    """
    assert len(patch_shape_min) == len(patch_shape_max)
    assert all(maxs >= mins for maxs, mins in zip(patch_shape_max, patch_shape_min))
    check_paths(raw_paths, label_paths)

    # get the correct dataset
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
    if is_seg_dataset:
        ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
                                           patch_shape_min, patch_shape_max,
                                           raw_transform=raw_transform, label_transform=label_transform,
                                           rois=rois, n_samples=n_forests, sampler=sampler)
    else:
        ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
                                               patch_shape_min, patch_shape_max, roi=rois,
                                               raw_transform=raw_transform, label_transform=label_transform,
                                               n_samples=n_forests)

    assert len(ds) == n_forests, f"Expected a dataset of length {n_forests}, got {len(ds)}"
    filters_and_sigmas = _get_filters(ndim, filter_config)
    return ds, filters_and_sigmas
339
340
def _serialize_feature_config(filters_and_sigmas):
    """Convert (filter, sigma) pairs into (filter name, sigma) pairs.

    Filters given by name are kept as-is. For ``functools.partial`` objects the name of the
    wrapped function is used, and for other callables their ``__name__``.
    """
    def _filter_name(filt):
        if isinstance(filt, str):
            return filt
        if isinstance(filt, partial):
            return filt.func.__name__
        return filt.__name__

    return [(_filter_name(filt), sigma) for filt, sigma in filters_and_sigmas]
349
350
def prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    balance_labels=True,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Train random forests on randomly sampled patches for shallow2deep training.

    One patch is drawn from the dataset per forest. A ``RandomForestClassifier(**rf_kwargs)``
    is fit on the filter features of its labeled pixels (pixels with label -1 are ignored,
    and classes are balanced if ``balance_labels`` is set). The forest is then pickled to
    ``output_folder/rf_XXXX.pkl``. Forests are trained with ``n_threads`` threads.

    Args:
        raw_paths, raw_key, label_paths, label_key: location of the raw and label data.
        patch_shape_min, patch_shape_max: per-axis bounds for the random patch shape.
        n_forests: number of forests to train.
        n_threads: number of threads for training.
        output_folder: folder the pickled forests are written to (created if missing).
        ndim: dimensionality of the data; patches must have this many (squeezed) axes.
        raw_transform, label_transform: transforms forwarded to the dataset.
        rois: region(s) of interest forwarded to the dataset.
        is_seg_dataset: dataset type; inferred from the inputs if None.
        balance_labels: whether to balance the classes before training each forest.
        filter_config: (filter, sigma) pairs; the default filter set is used if None.
        sampler: patch sampler forwarded to the segmentation dataset.
        **rf_kwargs: keyword arguments for ``RandomForestClassifier``.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    feature_config = _serialize_feature_config(filters_and_sigmas)

    def _train_rf(rf_id):
        raw_sample, label_sample = ds[rf_id]
        # convert to numpy and drop the channel axis
        # (multi-channel input and multi-class prediction are not supported yet)
        raw = raw_sample.numpy().squeeze()
        labels = label_sample.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        forest = RandomForestClassifier(**rf_kwargs)
        forest.fit(features, labels)
        # attach the feature settings so that they are stored together with the forest
        forest.feature_ndim = ndim
        forest.feature_config = feature_config

        with open(os.path.join(output_folder, f"rf_{rf_id:04d}.pkl"), "wb") as f:
            pickle.dump(forest, f)

    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests))
399
400
def _score_based_points(
    score_function,
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples,
):
    """Select the highest-scoring points per class, scored with the previous stage's forest.

    The forest ``forests[rf_id - forests_per_stage]`` predicts class probabilities for
    ``features``, and ``score_function(pred, labels)`` assigns one score per point. For each
    class id in ``range(n_classes)`` the ``int(sample_fraction_per_stage * len(features)) // n_classes``
    highest-scoring points are kept. If ``accumulate_samples`` is set, the training data of the
    previous forest (``train_features`` / ``train_labels``) is prepended.

    Returns:
        The selected features and labels.
    """
    previous_forest = forests[rf_id - forests_per_stage]
    scores = score_function(previous_forest.predict_proba(features), labels)
    assert len(scores) == len(features)

    n_classes = len(np.unique(labels))
    per_class = int(sample_fraction_per_stage * len(features)) // n_classes

    chosen = []
    for class_id in range(n_classes):
        members = np.where(labels == class_id)[0]
        ranking = np.argsort(scores[members])[::-1]
        chosen.append(members[ranking[:per_class]])
    chosen = np.concatenate(chosen)

    features, labels = features[chosen], labels[chosen]
    if accumulate_samples:
        features = np.concatenate([previous_forest.train_features, features], axis=0)
        labels = np.concatenate([previous_forest.train_labels, labels], axis=0)
    return features, labels
435
436
def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample, per class, the points the previous stage's forest gets most wrong.

    A point's score is the L1 distance between the predicted class probabilities and its one-hot
    encoded label. Selection and accumulation are done by ``_score_based_points``. Additional
    keyword arguments are accepted and ignored.
    """
    def score(pred, labels):
        classes, class_index = np.unique(labels, return_inverse=True)
        target = np.eye(len(classes))[class_index]
        return np.sum(np.abs(target - pred), axis=1)

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
454
455
def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample, per class, the points on which the previous stage's forest is most uncertain.

    Uncertainty is ``1 - (p_best - p_second)``, i.e. one minus the margin between the two most
    probable classes. The highest-scoring (most uncertain) points are selected by
    ``_score_based_points``. Additional keyword arguments are accepted and ignored.
    """
    def score(pred, labels):
        assert pred.ndim == 2
        channel_sorted = np.sort(pred, axis=1)
        # a small margin between the two most likely classes means an uncertain prediction
        margin = channel_sorted[:, -1] - channel_sorted[:, -2]
        # points with the highest score are selected, so invert the margin
        uncertainty = 1.0 - margin
        return uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
472
473
def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """Sample points that are wrongly and/or uncertainly predicted by the previous stage's forest.

    The score is ``alpha * diff + (1 - alpha) * uncertainty``, where ``diff`` is the L1 distance
    between predicted probabilities and the one-hot label, and ``uncertainty`` is one minus the
    margin between the two most probable classes. The highest-scoring points are selected by
    ``_score_based_points``. Additional keyword arguments are accepted and ignored.
    """
    def score(pred, labels):
        assert pred.ndim == 2

        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        diff = np.abs(onehot - pred).sum(axis=1)

        channel_sorted = np.sort(pred, axis=1)
        # a small top-2 margin means an uncertain prediction; invert it so that high = uncertain
        margin = channel_sorted[:, -1] - channel_sorted[:, -2]
        uncertainty = 1.0 - margin
        return alpha * diff + (1.0 - alpha) * uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )
498
499
def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Randomly sample the same number of points for every class.

    For each class id in ``range(n_classes)``, ``int(sample_fraction_per_stage * len(features)) // n_classes``
    points are drawn, with replacement only if the class has fewer points than that. If
    ``accumulate_samples`` is set and a forest of the previous stage exists, its training data
    (``train_features`` / ``train_labels``) is prepended. Extra keyword arguments are ignored.
    """
    n_classes = len(np.unique(labels))
    per_class = int(sample_fraction_per_stage * len(features)) // n_classes

    picks = []
    for class_id in range(n_classes):
        members = np.where(labels == class_id)[0]
        too_few = len(members) < per_class
        picks.append(np.random.choice(members, size=per_class, replace=too_few))
    picks = np.concatenate(picks)
    features, labels = features[picks], labels[picks]

    has_previous_stage = rf_id >= forests_per_stage
    if accumulate_samples and has_previous_stage:
        previous_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([previous_forest.train_features, features], axis=0)
        labels = np.concatenate([previous_forest.train_labels, labels], axis=0)

    return features, labels
527
528
def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=(25, 25),
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """Sample tiles centered on the maxima of the label-prediction difference of the previous forest.

    The forest ``forests[rf_id - forests_per_stage]`` predicts on ``features``. The per-class
    absolute difference to the one-hot labels is mapped back to image space and smoothed, either
    with a gaussian (``smoothing_sigma``) or with a box kernel of ``tile_shape``. Tiles of
    ``tile_shape`` are taken around the local maxima of the smoothed difference. For each class,
    at most ``int(sample_fraction_per_stage * len(features)) // n_classes`` points of that class
    inside the tiles are kept. Here ``len(features)`` is the full image size when ``mask`` does
    not select every pixel.

    Args:
        features: features of the pixels selected by ``mask``, one row per pixel.
        labels: labels of the pixels selected by ``mask``.
        rf_id: index of the forest being trained.
        forests: forests trained so far.
        forests_per_stage: number of forests per training stage.
        sample_fraction_per_stage: fraction of points to sample.
        img_shape: shape of the (2d or 3d) image the features were computed from.
        mask: boolean mask of shape ``img_shape`` that selects the pixels in features / labels.
        tile_shape: shape of the sampled tiles, one entry per image axis.
        smoothing_sigma: if given, smooth with a gaussian of this sigma instead of the box kernel.
        accumulate_samples: whether to prepend the training data of the previous forest.
        **kwargs: ignored.

    Returns:
        The sampled features and labels. If no point could be sampled, the training data of the
        previous forest is returned and a warning is emitted.
    """
    # check inputs
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape

    # get the corresponding random forest from the last stage and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # compute the difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # reshape diff to image shape; with a partial mask we cannot reshape directly and have to
    # scatter the values back into the full image instead
    if mask.sum() != mask.size:
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        # inflate the features
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # inflate the labels (with -1 so that unlabeled pixels are never sampled)
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # get the number of classes (not counting the ignore label)
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # sample in a class balanced way
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    samples = []
    for class_id in range(nc):
        # smooth either with a gaussian or with a box kernel of the tile shape
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
        else:
            kernel = np.ones(tile_shape)
            diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode="constant")

        # get maxima of the smoothed label-prediction diff, away from the border so that
        # full tiles fit around them
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=tuple([s // 2 for s in tile_shape])
        )

        # get the flat pixel indices of the tiles around the maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(center[d] - tile_shape[d] // 2, center[d] + tile_shape[d] // 2 + 1, None)
                for d in range(ndim)
            )
            grid = np.mgrid[tile_slice]
            samples_in_tile = grid.reshape(ndim, -1)
            tiles.append(np.ravel_multi_index(samples_in_tile, img_shape))

        # (very rarely) no maxima are found for a class; since features are usually
        # accumulated this does not hurt much, so skip the class
        if not tiles:
            continue
        tiles = np.concatenate(tiles)
        # take samples that belong to the current class
        samples.append(tiles[labels[tiles] == class_id][:n_samples_class])

    if not samples:
        warnings.warn(
            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
        )
        return last_forest.train_features, last_forest.train_labels

    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # add the features and labels from the previous rf if specified
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
640
641
def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample as many points per class as the rarest class has.

    Every class id in ``range(n_classes)`` contributes exactly ``min(class counts)`` randomly
    chosen points (all points of the rarest class). ``sample_fraction_per_stage`` is not used.
    If ``accumulate_samples`` is set and a forest of the previous stage exists, its training
    data is prepended. Extra keyword arguments are ignored.
    """
    _, class_counts = np.unique(labels, return_counts=True)
    n_classes = len(class_counts)
    per_class = class_counts.min()

    picks = []
    for class_id in range(n_classes):
        members = np.where(labels == class_id)[0]
        picks.append(np.random.choice(members, size=per_class, replace=len(members) < per_class))
    picks = np.concatenate(picks)
    features, labels = features[picks], labels[picks]

    if accumulate_samples and rf_id >= forests_per_stage:
        previous_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([previous_forest.train_features, features], axis=0)
        labels = np.concatenate([previous_forest.train_labels, labels], axis=0)

    return features, labels
671
672
# Built-in sampling strategies keyed by their function name; a string 'sampling_strategy'
# passed to prepare_shallow2deep_advanced is looked up here.
SAMPLING_STRATEGIES = {
    strategy.__name__: strategy
    for strategy in (
        random_points,
        uncertain_points,
        uncertain_worst_points,
        worst_points,
        worst_tiles,
        balanced_dense_accumulate,
    )
}
681
682
def prepare_shallow2deep_advanced(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    forests_per_stage,
    sample_fraction_per_stage,
    sampling_strategy="worst_points",
    sampling_kwargs=None,
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Advanced training of random forests for shallow2deep enhancer training.

    This function accepts the 'sampling_strategy' parameter, which allows implementing custom
    sampling strategies for the samples used to train the random forests.
    Training operates in stages. The parameter 'forests_per_stage' determines how many forests
    are trained in each stage (the last stage may be smaller if 'n_forests' is not a multiple of
    it), and 'sample_fraction_per_stage' which fraction of the samples is taken per stage.
    The random forests in stage 0 are trained on randomly sampled, class-balanced points
    (see 'random_points'). For the later stages 'sampling_strategy' determines the strategy.
    It is either the name of a strategy in SAMPLING_STRATEGIES or a function that is called as
    'sampling_strategy(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage,
    mask=mask, img_shape=img_shape, **sampling_kwargs)' and returns the sampled features and labels.
    See the 'worst_points' function in this file for an example implementation.

    The forests are pickled to 'output_folder/rf_XXXX.pkl'. 'sampling_kwargs' (default: no
    extra arguments) is passed on to the sampling strategy, and 'rf_kwargs' to RandomForestClassifier.
    """
    if sampling_kwargs is None:
        sampling_kwargs = {}

    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    forests = []
    # ceil division: a final, partial stage is needed if n_forests is not a multiple of forests_per_stage
    n_stages = (n_forests + forests_per_stage - 1) // forests_per_stage

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES,\
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # sample random patch with dataset
            raw, labels = ds[rf_id]

            # cast to numpy and remove channel axis
            # need to update this to support multi-channel input data and/or multi class prediction
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # pass the image shape to the sampling strategy;
            # copy so that concurrent threads do not share (and mutate) the same dict
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # labels are not balanced here; class balance is handled by the sampling functions
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            if forests:  # we have forests from a previous stage: apply the sampling strategy
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # first stage: sample randomly
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # fit the random forest
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # monkey patch these so that we know the feature config and dimensionality
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # save the random forest, update pbar, return it
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # monkey patch the training data and labels so we can re-use it in later stages
            # (set after pickling, so they are not part of the saved file)
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            stage_start = forests_per_stage * stage
            # do not train more than n_forests forests in the (possibly partial) last stage
            stage_stop = min(forests_per_stage * (stage + 1), n_forests)
            with futures.ThreadPoolExecutor(n_threads) as tp:
                # collect the whole stage before extending 'forests', since the worker threads read it
                this_forests = list(tp.map(_train_rf, range(stage_start, stage_stop)))
                forests.extend(this_forests)
class RFSegmentationDataset(typing.Generic[+T_co]):
class RFSegmentationDataset(torch_em.data.SegmentationDataset):
    """Segmentation dataset whose patches have a randomly drawn shape.

    Both `patch_shape_min` and `patch_shape_max` must be set before sampling.
    Per axis, the patch extent equals `patch_shape_min[i]` when it matches
    `patch_shape_max[i]`. Otherwise it is drawn with
    `np.random.randint(patch_shape_min[i], patch_shape_max[i])`, so the upper
    bound is exclusive.
    """

    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        """Per-axis lower bound of the sampled patch shape, or None if unset."""
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        """Per-axis upper bound of the sampled patch shape, or None if unset."""
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self):
        """Draw a random patch shape, then a random start for it inside `self.shape`.

        Returns:
            A tuple of slices, one per axis. Axes where the patch is at least as
            large as the data start at 0, so the slice may exceed the data extent.
        """
        assert self._patch_shape_min is not None and self._patch_shape_max is not None

        # First pass: the extent per axis. Both passes draw from the global numpy
        # RNG; they are kept separate so the draws happen in a fixed order.
        extents = []
        for lower, upper in zip(self._patch_shape_min, self._patch_shape_max):
            extents.append(lower if lower == upper else np.random.randint(lower, upper))

        # Second pass: the start offset per axis.
        offsets = []
        for full_size, extent in zip(self.shape, extents):
            room = full_size - extent
            offsets.append(np.random.randint(0, room) if room > 0 else 0)

        return tuple(slice(offset, offset + extent) for offset, extent in zip(offsets, extents))
patch_shape_min
patch_shape_max
class RFImageCollectionDataset(typing.Generic[+T_co]):
class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
    """Image-collection dataset whose patches have a randomly drawn shape.

    Both `patch_shape_min` and `patch_shape_max` must be set before sampling.
    Per axis, the patch extent equals `patch_shape_min[i]` when it matches
    `patch_shape_max[i]`. Otherwise it is drawn with
    `np.random.randint(patch_shape_min[i], patch_shape_max[i])`, so the upper
    bound is exclusive.
    """

    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        """Per-axis lower bound of the sampled patch shape, or None if unset."""
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        """Per-axis upper bound of the sampled patch shape, or None if unset."""
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self, shape):
        """Sample a random patch shape and start position inside an image of shape `shape`.

        Args:
            shape: the spatial shape of the image to sample from.

        Returns:
            A tuple of slices, one per axis.

        Raises:
            AssertionError: if `patch_shape_min` or `patch_shape_max` is unset.
            NotImplementedError: if the image is smaller than `patch_shape_max`
                along any axis, because padding is not supported.
        """
        # Validate the configuration before using it. With an unset patch_shape_max,
        # the size check below would otherwise fail with an opaque TypeError from zip.
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
            raise NotImplementedError("Image padding is not supported yet.")
        patch_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        # Random start per axis; axes without room to shift start at 0.
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(shape, patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite `__getitem__()`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite `__len__()`, which is expected to return the size of the dataset by many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader`. Subclasses could also optionally implement `__getitems__()` to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns a list of samples.

`torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

patch_shape_min
patch_shape_max
def prepare_shallow2deep( raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, n_forests, n_threads, output_folder, ndim, raw_transform=None, label_transform=None, rois=None, is_seg_dataset=None, balance_labels=True, filter_config=None, sampler=None, **rf_kwargs):
def prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    balance_labels=True,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Train `n_forests` random forests on dataset samples and pickle each one.

    Forest `i` is fit on the features and labels computed from dataset sample `i`.
    It is written to `<output_folder>/rf_<i:04d>.pkl`. The forests are trained in a
    thread pool with `n_threads` workers.

    Args:
        raw_paths, raw_key, label_paths, label_key: location of the raw data and labels.
        patch_shape_min, patch_shape_max: bounds for the sampled patch shape.
        n_forests: number of forests to train.
        n_threads: number of worker threads.
        output_folder: target directory for the pickled forests; created if missing.
        ndim: expected dimensionality of every squeezed raw / label sample.
        raw_transform, label_transform, rois, is_seg_dataset, filter_config, sampler:
            forwarded to `_prepare_shallow2deep`.
        balance_labels: forwarded to `_get_features_and_labels`.
        **rf_kwargs: keyword arguments for `RandomForestClassifier`.
    """
    os.makedirs(output_folder, exist_ok=True)
    dataset, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    feature_config = _serialize_feature_config(filters_and_sigmas)

    def _fit_and_save(rf_id):
        raw_sample, label_sample = dataset[rf_id]
        # Convert to numpy and drop singleton axes such as the channel axis.
        # Multi-channel input and multi-class prediction are not supported yet.
        raw = raw_sample.numpy().squeeze()
        labels = label_sample.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        forest = RandomForestClassifier(**rf_kwargs)
        forest.fit(features, labels)
        # Attach the feature setup to the forest so it is stored in the pickle.
        forest.feature_ndim = ndim
        forest.feature_config = feature_config

        target = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
        with open(target, "wb") as f:
            pickle.dump(forest, f)

    forest_ids = range(n_forests)
    with futures.ThreadPoolExecutor(n_threads) as tp:
        progress = tqdm(tp.map(_fit_and_save, forest_ids), desc="Train RFs", total=n_forests)
        # Consuming the iterator waits for every task and re-raises worker errors.
        list(progress)
def worst_points( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, **kwargs):
def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sampling strategy that scores points by the prediction error.

    A point's score is the L1 distance between the one-hot encoding of its label
    and its row in `pred` (presumably the class probabilities of a previous forest).
    The actual selection is delegated to `_score_based_points`.

    NOTE(review): the one-hot encoding has one column per label value present in
    `labels` and is compared column-wise with `pred`. This presumably assumes every
    class known to the forest occurs in `labels`; confirm.
    """
    def _l1_to_onehot(pred, labels):
        # Rank each label among the present label values, then one-hot encode it.
        present, rank = np.unique(labels, return_inverse=True)
        onehot = np.eye(len(present))[rank]
        return np.sum(np.abs(onehot - pred), axis=1)

    return _score_based_points(
        _l1_to_onehot, features, labels, rf_id, forests, forests_per_stage,
        sample_fraction_per_stage, accumulate_samples,
    )
def uncertain_points( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, **kwargs):
def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sampling strategy that scores points by the top-2 probability margin.

    A point's score is its largest minus its second-largest entry in `pred`.
    The actual selection is delegated to `_score_based_points`.

    NOTE(review): this margin is large for confident predictions. Whether uncertain
    points are selected depends on whether `_score_based_points` keeps the lowest
    or the highest scores; confirm.
    """
    def _top2_margin(pred, labels):
        assert pred.ndim == 2
        ranked = np.sort(pred, axis=1)
        best, runner_up = ranked[:, -1], ranked[:, -2]
        return best - runner_up

    return _score_based_points(
        _top2_margin, features, labels, rf_id, forests, forests_per_stage,
        sample_fraction_per_stage, accumulate_samples,
    )
def uncertain_worst_points( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, alpha=0.5, **kwargs):
def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """Sampling strategy that mixes the prediction error and the top-2 margin.

    score = alpha * error + (1 - alpha) * margin, where `error` is the L1 distance
    between the one-hot labels and `pred`, and `margin` is the largest minus the
    second-largest entry of `pred`. The actual selection is delegated to
    `_score_based_points`.

    NOTE(review): `error` is large for wrong predictions, while `margin` is large for
    confident ones. Confirm this weighting is intended.
    """
    def _mixed_score(pred, labels):
        assert pred.ndim == 2

        # Rank each label among the present label values, then one-hot encode it.
        present, rank = np.unique(labels, return_inverse=True)
        onehot = np.eye(present.shape[0])[rank]
        error = np.abs(onehot - pred).sum(axis=1)

        ranked = np.sort(pred, axis=1)
        margin = ranked[:, -1] - ranked[:, -2]

        return alpha * error + (1.0 - alpha) * margin

    return _score_based_points(
        _mixed_score, features, labels, rf_id, forests, forests_per_stage,
        sample_fraction_per_stage, accumulate_samples,
    )
def random_points( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, **kwargs):
def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample a class-balanced random subset of the points.

    `int(sample_fraction_per_stage * len(features))` samples are split evenly over the
    label values present in `labels`. Classes with fewer points than their share are
    sampled with replacement. If `accumulate_samples` is set and `rf_id >= forests_per_stage`,
    the training data of the corresponding forest from the previous stage
    (`train_features` / `train_labels`) is prepended.

    Args:
        features: per-point features, indexed along the first axis.
        labels: per-point labels, same length as `features`.
        rf_id: index of the forest being trained.
        forests: forests trained so far; only read when accumulating.
        forests_per_stage: number of forests per stage.
        sample_fraction_per_stage: fraction of the points to sample.
        accumulate_samples: whether to prepend the previous forest's training data.
        **kwargs: ignored; accepted so all sampling strategies share one signature.

    Returns:
        The sampled (and possibly accumulated) features and labels.
    """
    # Iterate over the label values that actually occur instead of 0..n-1, so
    # non-contiguous ids such as {0, 2} do not yield an empty class to sample from.
    class_ids = np.unique(labels)
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // len(class_ids)

    samples = []
    for class_id in class_ids:
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
def worst_tiles( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, img_shape, mask, tile_shape=[25, 25], smoothing_sigma=None, accumulate_samples=True, **kwargs):
def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=(25, 25),
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """Sampling strategy that picks tiles around the maxima of the prediction error.

    Steps:
    1. The forest from the previous stage (`forests[rf_id - forests_per_stage]`)
       predicts on `features`.
    2. The per-class absolute difference to the one-hot labels is placed back into an
       image of shape `img_shape`.
    3. The image is smoothed: gaussian with `smoothing_sigma` if given, otherwise a box
       filter of size `tile_shape`.
    4. For every class, tiles of size `tile_shape` are cut around the local maxima of
       the smoothed error, and up to `n_samples_class` points of that class are taken
       from them.

    Args:
        features: per-point features, indexed along the first axis.
        labels: per-point labels, same length as `features`.
        rf_id: index of the forest being trained.
        forests: forests trained so far.
        forests_per_stage: number of forests per stage.
        sample_fraction_per_stage: fraction of the points to sample.
        img_shape: spatial shape of the patch (2d or 3d).
        mask: mask over `img_shape` selecting the points present in `features`
            (presumably boolean, since it is used for indexing).
        tile_shape: tile extent per axis; must have `len(img_shape)` entries.
        smoothing_sigma: if truthy, gaussian smoothing is used instead of the box filter.
        accumulate_samples: whether to prepend the previous forest's training data.
        **kwargs: ignored; accepted so all sampling strategies share one signature.

    Returns:
        The sampled features and labels, prepended with the previous forest's training
        data if `accumulate_samples` is set. If no point can be sampled, the previous
        forest's training data is returned and a warning is emitted.
    """
    img_shape = tuple(img_shape)
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape

    # predict with the corresponding forest from the previous stage
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding (columns = label values present in `labels`)
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # per-class absolute difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # Reshape diff to image shape. If the mask removed points, a plain reshape is not
    # possible, so scatter diff, features and labels back into full-size arrays.
    if mask.sum() != mask.size:
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # masked-out points get label -1 so they are never sampled
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # number of classes, not counting the ignore label
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # sample in a class-balanced way
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    box_kernel = np.ones(tile_shape)
    half_tile = [s // 2 for s in tile_shape]
    samples = []
    for class_id in range(nc):
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
        else:
            diff_img_smooth = convolve(diff_img[..., class_id], box_kernel, mode="constant")

        # Maxima of the smoothed label-prediction difference. Excluding a half-tile
        # border keeps every tile below inside the image.
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=tuple(half_tile),
        )

        # flat indices of all points in the tiles around the maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(center[d] - half_tile[d], center[d] + half_tile[d] + 1)
                for d in range(ndim)
            )
            grid = np.mgrid[tile_slice]
            tiles.append(np.ravel_multi_index(grid.reshape(ndim, -1), img_shape))

        # No maxima can (rarely) be found for a class. Skip it then; the accumulated
        # samples of the previous forest usually still cover it.
        if not tiles:
            continue
        tiles = np.concatenate(tiles)
        # keep only points of the current class
        samples.append(tiles[labels[tiles] == class_id][:n_samples_class])

    if not samples:
        warnings.warn(
            f"No features were sampled for forest {rf_id}; using features of forest {rf_id - forests_per_stage}"
        )
        return last_forest.train_features, last_forest.train_labels

    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # add the samples of the previous forest if specified
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
def balanced_dense_accumulate( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, **kwargs):
def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample all points of the minority class and as many of every other class.

    If `accumulate_samples` is set and `rf_id >= forests_per_stage`, the training data
    of the corresponding forest from the previous stage (`train_features` /
    `train_labels`) is prepended.

    Args:
        features: per-point features, indexed along the first axis.
        labels: per-point labels, same length as `features`.
        rf_id: index of the forest being trained.
        forests: forests trained so far; only read when accumulating.
        forests_per_stage: number of forests per stage.
        sample_fraction_per_stage: unused; accepted so all sampling strategies
            share one signature.
        accumulate_samples: whether to prepend the previous forest's training data.
        **kwargs: ignored.

    Returns:
        The sampled (and possibly accumulated) features and labels.
    """
    # Iterate over the label values that actually occur instead of 0..n-1, so
    # non-contiguous ids such as {0, 2} do not yield an empty class to sample from.
    class_ids, class_counts = np.unique(labels, return_counts=True)
    n_samples_class = class_counts.min()

    samples = []
    for class_id in class_ids:
        class_indices = np.where(labels == class_id)[0]
        # every class has at least n_samples_class points, so no replacement is needed
        this_samples = np.random.choice(class_indices, size=n_samples_class, replace=False)
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
SAMPLING_STRATEGIES = {'random_points': <function random_points>, 'uncertain_points': <function uncertain_points>, 'uncertain_worst_points': <function uncertain_worst_points>, 'worst_points': <function worst_points>, 'worst_tiles': <function worst_tiles>, 'balanced_dense_accumulate': <function balanced_dense_accumulate>}
def prepare_shallow2deep_advanced( raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, n_forests, n_threads, output_folder, ndim, forests_per_stage, sample_fraction_per_stage, sampling_strategy='worst_points', sampling_kwargs={}, raw_transform=None, label_transform=None, rois=None, is_seg_dataset=None, filter_config=None, sampler=None, **rf_kwargs):
def prepare_shallow2deep_advanced(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    forests_per_stage,
    sample_fraction_per_stage,
    sampling_strategy="worst_points",
    sampling_kwargs=None,
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Advanced training of random forests for shallow2deep enhancer training.

    This function accepts the 'sampling_strategy' parameter, which allows implementing custom
    sampling strategies for the samples used to train the random forests.
    Training operates in stages. The parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of the
    samples is taken per stage. The random forests in stage 0 are trained from random balanced
    labels. For the other stages, 'sampling_strategy' specifies the strategy. It is either a key
    of SAMPLING_STRATEGIES or a function with signature
    '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)'
    that returns the sampled features and labels. The keyword arguments passed to it are 'mask',
    'img_shape' and the entries of 'sampling_kwargs'. See the 'worst_points' function in this
    file for an example implementation.

    If 'n_forests' is not a multiple of 'forests_per_stage', the last stage trains only the
    remaining forests, so exactly 'n_forests' forests are trained and written to 'output_folder'.

    Args:
        raw_paths, raw_key, label_paths, label_key: location of the raw data and labels.
        patch_shape_min, patch_shape_max: bounds for the sampled patch shape.
        n_forests: total number of forests to train.
        n_threads: number of worker threads per stage.
        output_folder: target directory for the pickled forests; created if missing.
        ndim: expected dimensionality of every squeezed raw / label sample.
        forests_per_stage: number of forests trained per stage.
        sample_fraction_per_stage: fraction of the points sampled per forest.
        sampling_strategy: name in SAMPLING_STRATEGIES or a callable (see above).
        sampling_kwargs: extra keyword arguments for the sampling strategy (default: none).
        raw_transform, label_transform, rois, is_seg_dataset, filter_config, sampler:
            forwarded to `_prepare_shallow2deep`.
        **rf_kwargs: keyword arguments for `RandomForestClassifier`.
    """
    if sampling_kwargs is None:
        sampling_kwargs = {}

    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    forests = []
    # ceil(n_forests / forests_per_stage)
    n_stages = (n_forests + forests_per_stage - 1) // forests_per_stage

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES,\
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # sample random patch with dataset
            raw, labels = ds[rf_id]

            # cast to numpy and remove channel axis
            # need to update this to support multi-channel input data and/or multi class prediction
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # pass the patch shape to the sampling strategy;
            # copy per call because several threads run this concurrently
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # no balancing at extraction time: random_points / the sampling strategy
            # select the training samples
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            # Decide by rf_id rather than by the shared `forests` list: forests of
            # stage > 0 apply the sampling strategy on the previous stage's forests.
            if rf_id >= forests_per_stage:
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # stage 0: sample randomly
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # fit the random forest
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # attach the feature config and dimensionality so they are stored with the forest
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # save the random forest
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # Attach the training data after pickling so later stages can reuse it
            # without it being written to disk.
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            # The last stage may be partial: never train past n_forests.
            stage_ids = range(forests_per_stage * stage, min(forests_per_stage * (stage + 1), n_forests))
            with futures.ThreadPoolExecutor(n_threads) as tp:
                # Collect the whole stage before extending `forests`, so no running
                # worker sees a partially filled list.
                this_forests = list(tp.map(_train_rf, stage_ids))
            forests.extend(this_forests)

Advanced training of random forests for shallow2deep enhancer training.

This function accepts the 'sampling_strategy' parameter, which allows implementing custom sampling strategies for the samples used to train the random forests. Training operates in stages. The parameter 'forests_per_stage' determines how many forests are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of the samples is taken per stage. The random forests in stage 0 are trained from random balanced labels. For the other stages, 'sampling_strategy' specifies the strategy. It has to be a function with signature '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)' that returns the sampled features and labels. The keyword arguments passed to it are 'mask', 'img_shape' and the entries of 'sampling_kwargs'. See the 'worst_points' function in this file for an example implementation.