torch_em.shallow2deep.prepare_shallow2deep
import os
import copy
import pickle
import warnings
from concurrent import futures
from glob import glob
from functools import partial
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch_em
from scipy.ndimage import gaussian_filter, convolve
from skimage.feature import peak_local_max
from sklearn.ensemble import RandomForestClassifier
from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
from tqdm import tqdm

import bioimage_cpp as bic

# Map legacy vigra/fastfilters CamelCase filter names to bioimage_cpp.filters functions.
_FILTER_NAMES = {
    "gaussianSmoothing": "gaussian_smoothing",
    "laplacianOfGaussian": "laplacian_of_gaussian",
    "gaussianGradientMagnitude": "gaussian_gradient_magnitude",
    "hessianOfGaussianEigenvalues": "hessian_of_gaussian_eigenvalues",
    "structureTensorEigenvalues": "structure_tensor_eigenvalues",
}


def _resolve_filter(name):
    """Return the bioimage_cpp.filters function for a legacy CamelCase or a snake_case filter name."""
    return getattr(bic.filters, _FILTER_NAMES.get(name, name))


class RFSegmentationDataset(torch_em.data.SegmentationDataset):
    """@private
    """
    # Per-axis bounds for the randomly sampled patch shape; set by the loader functions below.
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self):
        # Draw a patch shape per axis from [pmin, pmax) (np.random.randint excludes the upper bound),
        # then a random start so that the patch lies inside self.shape where possible.
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        sample_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, sample_shape))


class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
    """@private
    """
    # Per-axis bounds for the randomly sampled patch shape; set by the loader functions below.
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self, shape):
        # Validate the configuration first: the padding check below reads patch_shape_max.
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
            raise NotImplementedError("Image padding is not supported yet.")
        patch_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))


def _load_rf_segmentation_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
):
    """Create an RFSegmentationDataset (or a ConcatDataset of them) with variable patch shapes.

    The optional kwargs 'rois' and 'sampler' are consumed here. If no sampler is given,
    a MinForegroundSampler(min_fraction=0.01) is used. For a list of paths, 'n_samples'
    is distributed over the individual datasets via samples_to_datasets.
    """
    rois = kwargs.pop("rois", None)
    sampler = kwargs.pop("sampler", None)
    sampler = sampler if sampler else torch_em.data.MinForegroundSampler(min_fraction=0.01)
    if isinstance(raw_paths, str):
        if rois is not None:
            # NOTE(review): this only accepts 3d rois (three slices).
            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
        ds = RFSegmentationDataset(
            raw_paths, raw_key, label_paths, label_key, roi=rois, patch_shape=patch_shape_min, sampler=sampler, **kwargs
        )
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois)
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = RFSegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i],
                patch_shape=patch_shape_min, sampler=sampler, **kwargs
            )
            dset.patch_shape_min = patch_shape_min
            dset.patch_shape_max = patch_shape_max
            ds.append(dset)
        ds = torch_em.data.ConcatDataset(*ds)
    return ds


def _load_rf_image_collection_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, roi, **kwargs
):
    """Create an RFImageCollectionDataset (or a ConcatDataset of them) for 2d image collections.

    - raw_paths is a str: it is a folder, and raw_key / label_key are glob patterns inside raw_paths / label_paths.
    - raw_key is None: raw_paths and label_paths are explicit lists of image files.
    - otherwise: raw_paths / label_paths are lists of folders, each globbed with raw_key / label_key;
      roi (if given) holds one entry per folder that is used to index the sorted file lists.
    """
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        # Keep the pattern so the error message shows it (rpath is rebound to the glob result).
        raw_pattern = os.path.join(rpath, rkey)
        rpath = sorted(glob(raw_pattern))
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")
        lpath = sorted(glob(os.path.join(lpath, lkey)))
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        # Use the roi for *this* folder, not the outer (possibly per-folder list of) rois.
        if this_roi is not None:
            rpath, lpath = rpath[this_roi], lpath[this_roi]

        return rpath, lpath

    def _check_patch(patch_shape):
        # Accept a 2d patch shape or a 3d one with a singleton leading axis; return the 2d shape.
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2
        return patch_shape

    patch_shape_min = _check_patch(patch_shape_min)
    patch_shape_max = _check_patch(patch_shape_max)

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = RFImageCollectionDataset(
                rpath, lpath, patch_shape=patch_shape_min, n_samples=samples_per_ds[i], **kwargs
            )
            dset.patch_shape_min = patch_shape_min
            dset.patch_shape_max = patch_shape_max
            ds.append(dset)
        ds = torch_em.data.ConcatDataset(*ds)
    return ds


def _get_filters(ndim, filters_and_sigmas):
    """Return a validated list of (filter, sigma) pairs.

    If filters_and_sigmas is None, a subset of the ilastik default features is used:
    five filters evaluated at sigmas 0.7, 1.6, 3.5 and 5.0, where the structure tensor
    eigenvalue filter is wrapped with outer_sigma = 0.5 * sigma.
    A filter may be a callable or the (legacy or snake_case) name of a bioimage_cpp.filters function;
    a sigma may be a float or a tuple of ndim floats.
    """
    # subset of ilastik default features
    if filters_and_sigmas is None:
        filters = [bic.filters.gaussian_smoothing,
                   bic.filters.laplacian_of_gaussian,
                   bic.filters.gaussian_gradient_magnitude,
                   bic.filters.hessian_of_gaussian_eigenvalues,
                   bic.filters.structure_tensor_eigenvalues]
        sigmas = [0.7, 1.6, 3.5, 5.0]
        filters_and_sigmas = [
            (filt, sigma) if i != len(filters) - 1 else (partial(filt, outer_sigma=0.5*sigma), sigma)
            for i, filt in enumerate(filters) for sigma in sigmas
        ]
    # validate the filter config
    assert isinstance(filters_and_sigmas, (list, tuple))
    for filt_and_sig in filters_and_sigmas:
        filt, sig = filt_and_sig
        assert callable(filt) or (isinstance(filt, str) and hasattr(bic.filters, _FILTER_NAMES.get(filt, filt)))
        assert isinstance(sig, (float, tuple))
        if isinstance(sig, tuple):
            assert ndim is not None and len(sig) == ndim
            assert all(isinstance(sigg, float) for sigg in sig)
    return filters_and_sigmas


def _calculate_response(raw, filter_, sigma):
    """Apply one filter (callable or filter name) with the given sigma to raw."""
    if callable(filter_):
        return filter_(raw, sigma)

    # filter_ is still a string, convert it to a bioimage_cpp.filters function.
    # bioimage_cpp.filters supports passing sigma as a scalar or as a per-axis tuple.
    func = _resolve_filter(filter_)

    # special case since the additional argument outer_sigma
    # is needed for the structure tensor eigenvalues filter
    # NOTE(review): named filters use outer_sigma = 2 * sigma, while the default config
    # in _get_filters uses 0.5 * sigma — confirm which is intended.
    if filter_ in ("structureTensorEigenvalues", "structure_tensor_eigenvalues"):
        outer_sigma = tuple(s*2 for s in sigma) if isinstance(sigma, tuple) else 2*sigma
        return func(raw, sigma, outer_sigma=outer_sigma)

    return func(raw, sigma)


def _apply_filters(raw, filters_and_sigmas):
    """Compute all filter responses for every pixel and stack them into an (n_pixels, n_features) array.

    Responses with more dimensions than raw contribute one feature per entry of their last axis.
    """
    features = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim > raw.ndim:
            for c in range(response.shape[-1]):
                features.append(response[..., c].flatten())
        else:
            features.append(response.flatten())
    features = np.concatenate([ff[:, None] for ff in features], axis=1)
    return features


def _apply_filters_with_mask(raw, filters_and_sigmas, mask):
    """Like _apply_filters, but only keep the pixels where the boolean mask is True."""
    features = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim > raw.ndim:
            for c in range(response.shape[-1]):
                features.append(response[..., c][mask])
        else:
            features.append(response[mask])
    features = np.concatenate([ff[:, None] for ff in features], axis=1)
    return features


def _balance_labels(labels, mask):
    """Set mask entries to False so that every class keeps as many pixels as the rarest class.

    The class ids inside mask must be exactly 0, ..., n_classes - 1. The mask is modified in place and returned.
    """
    class_ids, label_counts = np.unique(labels[mask], return_counts=True)
    n_classes = len(class_ids)
    assert class_ids.tolist() == list(range(n_classes))

    min_class = class_ids[np.argmin(label_counts)]
    n_labels = label_counts[min_class]

    for class_id in class_ids:
        if class_id == min_class:
            continue
        n_discard = label_counts[class_id] - n_labels
        # sample from the current class
        # shuffle the positions and only keep up to n_labels in the mask
        label_pos = np.where(labels == class_id)
        discard_ids = np.arange(len(label_pos[0]))
        np.random.shuffle(discard_ids)
        discard_ids = discard_ids[:n_discard]
        discard_mask = tuple(pos[discard_ids] for pos in label_pos)
        mask[discard_mask] = False

    assert mask.sum() == n_classes * n_labels
    return mask


def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False):
    """Compute features and flat labels for all pixels whose label is not -1.

    If balance_labels is True, the pixels are subsampled to be class balanced first.
    Returns (features, labels) or, if return_mask is True, (features, labels, mask).
    """
    # find the mask for where we compute filters and labels
    # by default we exclude everything that has label -1
    assert labels.shape == raw.shape
    mask = labels != -1
    if balance_labels:
        mask = _balance_labels(labels, mask)
    labels = labels[mask]
    assert labels.ndim == 1
    features = _apply_filters_with_mask(raw, filters_and_sigmas, mask)
    assert features.ndim == 2
    assert len(features) == len(labels)
    if return_mask:
        return features, labels, mask
    else:
        return features, labels


def _prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    ndim,
    raw_transform,
    label_transform,
    rois,
    is_seg_dataset,
    filter_config,
    sampler,
):
    """Validate the inputs, build a dataset with one sample per forest and resolve the filter config.

    Returns the dataset and the validated list of (filter, sigma) pairs.
    """
    assert len(patch_shape_min) == len(patch_shape_max)
    assert all(maxs >= mins for maxs, mins in zip(patch_shape_max, patch_shape_min))
    check_paths(raw_paths, label_paths)

    # get the correct dataset
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
    if is_seg_dataset:
        ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
                                           patch_shape_min, patch_shape_max,
                                           raw_transform=raw_transform, label_transform=label_transform,
                                           rois=rois, n_samples=n_forests, sampler=sampler)
    else:
        ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
                                               patch_shape_min, patch_shape_max, roi=rois,
                                               raw_transform=raw_transform, label_transform=label_transform,
                                               n_samples=n_forests)

    assert len(ds) == n_forests, f"The dataset has {len(ds)} samples, expected {n_forests}"
    filters_and_sigmas = _get_filters(ndim, filter_config)
    return ds, filters_and_sigmas


def _serialize_feature_config(filters_and_sigmas):
    """Replace every filter by its name (unwrapping functools.partial), keeping the sigmas."""
    feature_config = [
        (filt if isinstance(filt, str) else (filt.func.__name__ if isinstance(filt, partial) else filt.__name__), sigma)
        for filt, sigma in filters_and_sigmas
    ]
    return feature_config


def prepare_shallow2deep(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    Each random forest is trained on one randomly sampled patch and saved as 'rf_XXXX.pkl' in the output folder.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
        rois: Region of interests for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    def _train_rf(rf_id):
        # Sample random patch with dataset.
        raw, labels = ds[rf_id]
        # Cast to numpy and remove channel axis.
        # Need to update this to support multi-channel input data and/or multi class prediction.
        raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        rf = RandomForestClassifier(**rf_kwargs)
        rf.fit(features, labels)
        # Monkey patch these so that we know the feature config and dimensionality.
        rf.feature_ndim = ndim
        rf.feature_config = serialized_feature_config
        out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(rf, f)

    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests))


def _score_based_points(
    score_function,
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples,
):
    """Select class-balanced samples with the highest score under the forest of the previous stage.

    score_function(pred, labels) must return one score per sample. For each class id in
    range(number of unique labels), the int(sample_fraction_per_stage * len(features)) // n_classes
    samples with the highest score are kept. If accumulate_samples is True, the training data of the
    previous-stage forest (its train_features / train_labels attributes) is prepended.
    """
    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    score = score_function(pred, labels)
    assert len(score) == len(features)

    # get training samples based on the label-prediction diff
    samples = []
    nc = len(np.unique(labels))
    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // nc
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        # indices of this class, sorted by descending score
        this_samples = class_indices[np.argsort(score[class_indices])[::-1][:n_samples_class]]
        samples.append(this_samples)
    samples = np.concatenate(samples)

    # get the features and labels, add from previous rf if specified
    features, labels = features[samples], labels[samples]
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    def score(pred, labels):
        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        return np.abs(onehot - pred).sum(axis=1)

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    def score(pred, labels):
        assert pred.ndim == 2
        channel_sorted = np.sort(pred, axis=1)
        # NOTE(review): this is the margin between the two highest class probabilities, which is
        # large for confident predictions. _score_based_points keeps the highest scores, so this
        # selects the most *certain* points rather than the most uncertain ones — confirm intent.
        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
        return uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """@private
    """
    def score(pred, labels):
        assert pred.ndim == 2

        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        diff = np.abs(onehot - pred).sum(axis=1)

        # NOTE(review): same top-2 margin as in uncertain_points (large = confident) — confirm intent.
        channel_sorted = np.sort(pred, axis=1)
        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
        return alpha * diff + (1.0 - alpha) * uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    samples = []
    nc = len(np.unique(labels))
    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // nc
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        # only sample with replacement if the class has too few samples
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=(25, 25),
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    # check inputs
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape

    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # compute the difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # reshape diff to image shape
    # we need to also take into account the mask here, and if we apply any masking
    # because we can't directly reshape if we have it
    if mask.sum() != mask.size:
        # get the diff image
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        # inflate the features
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # inflate the labels (with -1 so this will not be sampled)
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # get the number of classes (not counting ignore label)
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # sample in a class balanced way
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    samples = []
    for class_id in range(nc):
        # smooth either with gaussian or 1-kernel
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
        else:
            kernel = np.ones(tile_shape)
            diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode="constant")

        # get training samples based on tiles around maxima of the label-prediction diff
        # do this in a class-specific way to ensure that each class is sampled
        # get maxima of the label-prediction diff (they seem to be sorted already)
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=tuple([s // 2 for s in tile_shape])
        )

        # get indices of tiles around maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(
                    center[d]-tile_shape[d]//2,
                    center[d]+tile_shape[d]//2 + 1,
                    None
                ) for d in range(ndim)
            )
            grid = np.mgrid[tile_slice]
            samples_in_tile = grid.reshape(ndim, -1)
            samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape)
            tiles.append(samples_in_tile)

        # this (very rarely) fails due to empty tile list. Since we usually
        # accumulate the features this doesn't hurt much and we can continue
        try:
            tiles = np.concatenate(tiles)
            # take samples that belong to the current class
            this_samples = tiles[labels[tiles] == class_id][:n_samples_class]
            samples.append(this_samples)
        except ValueError:
            pass

    try:
        samples = np.concatenate(samples)
        features, labels = features[samples], labels[samples]

        # get the features and labels, add from previous rf if specified
        if accumulate_samples:
            features = np.concatenate([last_forest.train_features, features], axis=0)
            labels = np.concatenate([last_forest.train_labels, labels], axis=0)
    except ValueError:
        # no class produced any tiles: fall back to the previous forest's training data
        features, labels = last_forest.train_features, last_forest.train_labels
        warnings.warn(
            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
        )

    return features, labels


def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    samples = []
    class_ids, class_counts = np.unique(labels, return_counts=True)
    nc = len(class_ids)
    # sample in a class balanced way
    # take all pixels from minority class
    # and choose same amount from other classes randomly
    n_samples_class = class_counts.min()
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # accumulate
    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


SAMPLING_STRATEGIES = {
    "random_points": random_points,
    "uncertain_points": uncertain_points,
    "uncertain_worst_points": uncertain_worst_points,
    "worst_points": worst_points,
    "worst_tiles": worst_tiles,
    "balanced_dense_accumulate": balanced_dense_accumulate,
}
"""@private
"""


def prepare_shallow2deep_advanced(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    forests_per_stage: int,
    sample_fraction_per_stage: float,
    sampling_strategy: Union[str, Callable] = "worst_points",
    sampling_kwargs: Optional[Dict] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    This function implements an advanced training procedure compared to `prepare_shallow2deep`.
    The 'sampling_strategy' argument determines an advanced sampling strategy,
    which selects the samples to use for training the random forests.
    The random forest training operates in stages, the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage.
    If 'n_forests' is not divisible by 'forests_per_stage', the last stage trains the remaining forests.
    The random forests in stage 0 are trained from random balanced labels.
    For the other stages 'sampling_strategy' determines the strategy; it has to be a function with signature
    '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)',
    and return the sampled features and labels. See for example the 'worst_points' function.
    The keyword arguments passed to it are 'mask' (the pixels with a label other than -1 for which
    features were computed), 'img_shape' (the shape of the sampled patch) and the entries of 'sampling_kwargs'.
    Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
    - "random_points": Select random points.
    - "uncertain_points": Select points with the highest uncertainty.
    - "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
    - "worst_points": Select the points with the worst accuracies.
    - "worst_tiles": Select the tiles with the worst accuracies.
    - "balanced_dense_accumulate": Balanced dense accumulation.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        forests_per_stage: The number of forests to train per stage.
        sample_fraction_per_stage: The fraction of samples to use per stage.
        sampling_strategy: The sampling strategy.
        sampling_kwargs: The keyword arguments for the sampling strategy. Defaults to no extra arguments.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest for it.
        rois: Region of interests for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    # Avoid a shared mutable default; the kwargs are deep-copied per forest below anyway.
    sampling_kwargs = {} if sampling_kwargs is None else sampling_kwargs

    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    forests = []
    # ceil(n_forests / forests_per_stage)
    n_stages = (n_forests + forests_per_stage - 1) // forests_per_stage

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES, \
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # sample random patch with dataset
            raw, labels = ds[rf_id]

            # cast to numpy and remove channel axis
            # need to update this to support multi-channel input data and/or multi class prediction
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # monkey patch original shape to sampling_kwargs
            # deepcopy needed due to multithreading
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # compute dense (unbalanced) features; the balancing is done by the sampling below
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            if forests:  # we have forests: apply the sampling strategy
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # sample randomly
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # fit the random forest
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # monkey patch these so that we know the feature config and dimensionality
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # save the random forest, update pbar, return it
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # monkey patch the training data and labels so we can re-use it in later stages
            # (done after pickling, so the saved forest does not contain the training data)
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            # clamp the last stage so we never train more than n_forests forests
            stage_start = forests_per_stage * stage
            stage_stop = min(forests_per_stage * (stage + 1), n_forests)
            with futures.ThreadPoolExecutor(n_threads) as tp:
                this_forests = list(tp.map(_train_rf, range(stage_start, stage_stop)))
            forests.extend(this_forests)
def prepare_shallow2deep(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    For every forest id in ``range(n_forests)`` one patch is drawn from the training data,
    features are computed for it, a random forest is fit and then pickled to
    ``<output_folder>/rf_<id:04d>.pkl``. The forests are trained in a thread pool.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data.
            Set to None for regular image formats such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests. Created if it does not exist.
        ndim: The dimensionality of the data.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving the random forest targets from it.
        rois: Regions of interest for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    os.makedirs(output_folder, exist_ok=True)
    dataset, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    feature_config = _serialize_feature_config(filters_and_sigmas)

    def _fit_and_save(forest_id):
        raw_patch, label_patch = dataset[forest_id]

        # Convert to numpy and drop the channel axis.
        # Multi-channel input data and/or multi-class prediction are not supported yet.
        raw_patch = raw_patch.numpy().squeeze()
        label_patch = label_patch.numpy().astype("int8").squeeze()
        assert raw_patch.ndim == label_patch.ndim == ndim, f"{raw_patch.ndim}, {label_patch.ndim}, {ndim}"

        features, targets = _get_features_and_labels(raw_patch, label_patch, filters_and_sigmas, balance_labels)
        forest = RandomForestClassifier(**rf_kwargs)
        forest.fit(features, targets)

        # Record the feature config and dimensionality on the forest so they are stored in the pickle.
        forest.feature_ndim = ndim
        forest.feature_config = feature_config

        save_path = os.path.join(output_folder, f"rf_{forest_id:04d}.pkl")
        with open(save_path, "wb") as f:
            pickle.dump(forest, f)

    # Consuming the map iterator drives the training and re-raises any worker exception.
    with futures.ThreadPoolExecutor(n_threads) as pool:
        for _ in tqdm(pool.map(_fit_and_save, range(n_forests)), desc="Train RFs", total=n_forests):
            pass
Prepare shallow2deep enhancer training by pre-training random forests.
Arguments:
- raw_paths: The file paths to the raw data. May also be a single file.
- raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
- label_paths: The file paths to the label data. May also be a single file.
- label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
- patch_shape_min: The minimal patch shape loaded for training a random forest.
- patch_shape_max: The maximal patch shape loaded for training a random forest.
- n_forests: The number of random forests to train.
- n_threads: The number of threads for parallelizing the training.
- output_folder: The folder for saving the random forests.
- ndim: The dimensionality of the data.
- raw_transform: A transform to apply to the raw data before computing features on it.
- label_transform: A transform to apply to the label data before deriving the random forest targets from it.
- rois: Regions of interest for the training data.
- is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
- balance_labels: Whether to balance the training labels for the random forest.
- filter_config: The configuration for the image filters that are used to compute features for the random forest.
- sampler: A sampler to reject samples from training.
- rf_kwargs: Keyword arguments for creating the random forest.
def prepare_shallow2deep_advanced(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    forests_per_stage: int,
    sample_fraction_per_stage: float,
    sampling_strategy: Union[str, Callable] = "worst_points",
    sampling_kwargs: Optional[Dict] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    This function implements an advanced training procedure compared to `prepare_shallow2deep`.
    The 'sampling_strategy' argument determines the advanced sampling strategy,
    which selects the samples used for training the random forests.
    The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' determines what fraction of samples is used per stage.
    If 'n_forests' is not divisible by 'forests_per_stage', the last stage only trains the remaining forests,
    so that exactly 'n_forests' forests are trained.
    The random forests in stage 0 are trained on points selected with the 'random_points' strategy.
    For the other stages 'sampling_strategy' determines the strategy; it has to be a function with signature
    '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)',
    where the keyword arguments are 'mask', 'img_shape' (the shape of the sampled raw patch) and the
    entries of 'sampling_kwargs'. It must return the sampled features and labels.
    See for example the 'worst_points' function.
    Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
        - "random_points": Select random points.
        - "uncertain_points": Select points with the highest uncertainty.
        - "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
        - "worst_points": Select the points with the worst accuracies.
        - "worst_tiles": Select the tiles with the worst accuracies.
        - "balanced_dense_accumulate": Balanced dense accumulation.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data.
            Set to None for regular image formats such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training within a stage.
        output_folder: The folder for saving the random forests. Created if it does not exist.
        ndim: The dimensionality of the data.
        forests_per_stage: The number of forests to train per stage.
        sample_fraction_per_stage: The fraction of samples to use per stage.
        sampling_strategy: The sampling strategy, either a callable or the name of a pre-defined strategy.
        sampling_kwargs: The keyword arguments for the sampling strategy. Defaults to no extra arguments.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving the random forest targets from it.
        rois: Regions of interest for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Accepted for API compatibility with `prepare_shallow2deep`.
            NOTE(review): it is currently not used by this function; features are always extracted
            with `balance_labels=False`.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    # Avoid a shared mutable default; the dict is deep-copied per forest below anyway.
    if sampling_kwargs is None:
        sampling_kwargs = {}

    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    forests = []
    # Ceiling division: a partial last stage still counts as a stage.
    n_stages = n_forests // forests_per_stage if n_forests % forests_per_stage == 0 else\
        n_forests // forests_per_stage + 1

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES, \
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # Sample a random patch with the dataset.
            raw, labels = ds[rf_id]

            # Cast to numpy and remove the channel axis.
            # Need to update this to support multi-channel input data and/or multi class prediction.
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # Pass the patch shape to the sampling strategy via its kwargs.
            # A per-call deepcopy is used so concurrent threads never share (or mutate) the same dict.
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            # `forests` is only extended between stages (after the thread pool has finished),
            # so all threads of one stage see the same list of previously trained forests.
            if forests:  # We have forests from earlier stages: apply the sampling strategy.
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # Stage 0: sample randomly.
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # Fit the random forest.
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # Monkey patch these so that we know the feature config and dimensionality.
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # Save the random forest. This happens before the training data is attached below,
            # so the pickled file does not contain the training features / labels.
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # Monkey patch the training data and labels so we can re-use them in later stages.
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            # Clip the last stage at n_forests so that exactly n_forests forests are trained
            # (previously a partial last stage trained a full `forests_per_stage` forests).
            stage_start = forests_per_stage * stage
            stage_stop = min(forests_per_stage * (stage + 1), n_forests)
            with futures.ThreadPoolExecutor(n_threads) as tp:
                this_forests = list(tp.map(_train_rf, range(stage_start, stage_stop)))
            forests.extend(this_forests)
Prepare shallow2deep enhancer training by pre-training random forests.
This function implements an advanced training procedure compared to prepare_shallow2deep.
The 'sampling_strategy' argument determines the advanced sampling strategy,
which selects the samples used for training the random forests.
The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests
are trained in each stage, and 'sample_fraction_per_stage' determines what fraction of samples is used per stage.
The random forests in stage 0 are trained from random balanced labels.
For the other stages 'sampling_strategy' determines the strategy; it has to be a function with signature
'(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, **kwargs)' (the keyword arguments are 'mask', 'img_shape' and the entries of 'sampling_kwargs'),
and return the sampled features and labels. See for example the 'worst_points' function.
Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
- "random_points": Select random points.
- "uncertain_points": Select points with the highest uncertainty.
- "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
- "worst_points": Select the points with the worst accuracies.
- "worst_tiles": Select the tiles with the worst accuracies.
- "balanced_dense_accumulate": Balanced dense accumulation.
Arguments:
- raw_paths: The file paths to the raw data. May also be a single file.
- raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
- label_paths: The file paths to the label data. May also be a single file.
- label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
- patch_shape_min: The minimal patch shape loaded for training a random forest.
- patch_shape_max: The maximal patch shape loaded for training a random forest.
- n_forests: The number of random forests to train.
- n_threads: The number of threads for parallelizing the training.
- output_folder: The folder for saving the random forests.
- ndim: The dimensionality of the data.
- forests_per_stage: The number of forests to train per stage.
- sample_fraction_per_stage: The fraction of samples to use per stage.
- sampling_strategy: The sampling strategy.
- sampling_kwargs: The keyword arguments for the sampling strategy.
- raw_transform: A transform to apply to the raw data before computing features on it.
- label_transform: A transform to apply to the label data before deriving the random forest targets from it.
- rois: Regions of interest for the training data.
- is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
- balance_labels: Whether to balance the training labels for the random forest.
- filter_config: The configuration for the image filters that are used to compute features for the random forest.
- sampler: A sampler to reject samples from training.
- rf_kwargs: Keyword arguments for creating the random forest.