torch_em.shallow2deep.prepare_shallow2deep
import os
import copy
import pickle
import warnings
from concurrent import futures
from glob import glob
from functools import partial
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch_em
from scipy.ndimage import gaussian_filter, convolve
from skimage.feature import peak_local_max
from sklearn.ensemble import RandomForestClassifier
from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
from tqdm import tqdm

import vigra
try:
    import fastfilters as filter_impl
except ImportError:
    import vigra.filters as filter_impl


class RFSegmentationDataset(torch_em.data.SegmentationDataset):
    """@private
    """
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self):
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        sample_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, sample_shape))


class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
    """@private
    """
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self, shape):
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
            raise NotImplementedError("Image padding is not supported yet.")
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        patch_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))


def _load_rf_segmentation_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
):
    rois = kwargs.pop("rois", None)
    sampler = kwargs.pop("sampler", None)
    sampler = sampler if sampler else torch_em.data.MinForegroundSampler(min_fraction=0.01)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
        ds = RFSegmentationDataset(
            raw_paths, raw_key, label_paths, label_key, roi=rois, patch_shape=patch_shape_min, sampler=sampler, **kwargs
        )
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois)
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = RFSegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i],
                patch_shape=patch_shape_min, sampler=sampler, **kwargs
            )
            dset.patch_shape_min = patch_shape_min
            dset.patch_shape_max = patch_shape_max
            ds.append(dset)
        ds = torch_em.data.ConcatDataset(*ds)
    return ds


def _load_rf_image_collection_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, roi, **kwargs
):
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        rpath = glob(os.path.join(rpath, rkey))
        rpath.sort()
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {os.path.join(rpath, rkey)}")
        lpath = glob(os.path.join(lpath, lkey))
        lpath.sort()
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        if this_roi is not None:
            rpath, lpath = rpath[roi], lpath[roi]

        return rpath, lpath

    def _check_patch(patch_shape):
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2
        return patch_shape

    patch_shape_min = _check_patch(patch_shape_min)
    patch_shape_max = _check_patch(patch_shape_max)

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = RFImageCollectionDataset(
                rpath, lpath, patch_shape=patch_shape_min, n_samples=samples_per_ds[i], **kwargs
            )
            dset.patch_shape_min = patch_shape_min
            dset.patch_shape_max = patch_shape_max
            ds.append(dset)
        ds = torch_em.data.ConcatDataset(*ds)
    return ds


def _get_filters(ndim, filters_and_sigmas):
    # subset of ilastik default features
    if filters_and_sigmas is None:
        filters = [filter_impl.gaussianSmoothing,
                   filter_impl.laplacianOfGaussian,
                   filter_impl.gaussianGradientMagnitude,
                   filter_impl.hessianOfGaussianEigenvalues,
                   filter_impl.structureTensorEigenvalues]
        sigmas = [0.7, 1.6, 3.5, 5.0]
        filters_and_sigmas = [
            (filt, sigma) if i != len(filters) - 1 else (partial(filt, outerScale=0.5*sigma), sigma)
            for i, filt in enumerate(filters) for sigma in sigmas
        ]
    # validate the filter config
    assert isinstance(filters_and_sigmas, (list, tuple))
    for filt_and_sig in filters_and_sigmas:
        filt, sig = filt_and_sig
        assert callable(filt) or (isinstance(filt, str) and hasattr(filter_impl, filt))
        assert isinstance(sig, (float, tuple))
        if isinstance(sig, tuple):
            assert ndim is not None and len(sig) == ndim
            assert all(isinstance(sigg, float) for sigg in sig)
    return filters_and_sigmas


def _calculate_response(raw, filter_, sigma):
    if callable(filter_):
        return filter_(raw, sigma)

    # filter_ is still string, convert it to function
    # fastfilters does not support passing sigma as tuple
    func = getattr(vigra.filters, filter_) if isinstance(sigma, tuple) else getattr(filter_impl, filter_)

    # special case since additional argument outerScale
    # is needed for structureTensorEigenvalues functions
    if filter_ == "structureTensorEigenvalues":
        outerScale = tuple([s*2 for s in sigma]) if isinstance(sigma, tuple) else 2*sigma
        return func(raw, sigma, outerScale=outerScale)

    return func(raw, sigma)


def _apply_filters(raw, filters_and_sigmas):
    features = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim > raw.ndim:
            for c in range(response.shape[-1]):
                features.append(response[..., c].flatten())
        else:
            features.append(response.flatten())
    features = np.concatenate([ff[:, None] for ff in features], axis=1)
    return features


def _apply_filters_with_mask(raw, filters_and_sigmas, mask):
    features = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim > raw.ndim:
            for c in range(response.shape[-1]):
                features.append(response[..., c][mask])
        else:
            features.append(response[mask])
    features = np.concatenate([ff[:, None] for ff in features], axis=1)
    return features


def _balance_labels(labels, mask):
    class_ids, label_counts = np.unique(labels[mask], return_counts=True)
    n_classes = len(class_ids)
    assert class_ids.tolist() == list(range(n_classes))

    min_class = class_ids[np.argmin(label_counts)]
    n_labels = label_counts[min_class]

    for class_id in class_ids:
        if class_id == min_class:
            continue
        n_discard = label_counts[class_id] - n_labels
        # sample from the current class
        # shuffle the positions and only keep up to n_labels in the mask
        label_pos = np.where(labels == class_id)
        discard_ids = np.arange(len(label_pos[0]))
        np.random.shuffle(discard_ids)
        discard_ids = discard_ids[:n_discard]
        discard_mask = tuple(pos[discard_ids] for pos in label_pos)
        mask[discard_mask] = False

    assert mask.sum() == n_classes * n_labels
    return mask


def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False):
    # find the mask for where we compute filters and labels
    # by default we exclude everything that has label -1
    assert labels.shape == raw.shape
    mask = labels != -1
    if balance_labels:
        mask = _balance_labels(labels, mask)
    labels = labels[mask]
    assert labels.ndim == 1
    features = _apply_filters_with_mask(raw, filters_and_sigmas, mask)
    assert features.ndim == 2
    assert len(features) == len(labels)
    if return_mask:
        return features, labels, mask
    else:
        return features, labels


def _prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    ndim,
    raw_transform,
    label_transform,
    rois,
    is_seg_dataset,
    filter_config,
    sampler,
):
    assert len(patch_shape_min) == len(patch_shape_max)
    assert all(maxs >= mins for maxs, mins in zip(patch_shape_max, patch_shape_min))
    check_paths(raw_paths, label_paths)

    # get the correct dataset
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
    if is_seg_dataset:
        ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
                                           patch_shape_min, patch_shape_max,
                                           raw_transform=raw_transform, label_transform=label_transform,
                                           rois=rois, n_samples=n_forests, sampler=sampler)
    else:
        ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
                                               patch_shape_min, patch_shape_max, roi=rois,
                                               raw_transform=raw_transform, label_transform=label_transform,
                                               n_samples=n_forests)

    assert len(ds) == n_forests, f"{len(ds), {n_forests}}"
    filters_and_sigmas = _get_filters(ndim, filter_config)
    return ds, filters_and_sigmas


def _serialize_feature_config(filters_and_sigmas):
    feature_config = [
        (filt if isinstance(filt, str) else (filt.func.__name__ if isinstance(filt, partial) else filt.__name__), sigma)
        for filt, sigma in filters_and_sigmas
    ]
    return feature_config


def prepare_shallow2deep(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
        rois: Regions of interest for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    def _train_rf(rf_id):
        # Sample random patch with dataset.
        raw, labels = ds[rf_id]
        # Cast to numpy and remove channel axis.
        # Need to update this to support multi-channel input data and/or multi class prediction.
        raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        rf = RandomForestClassifier(**rf_kwargs)
        rf.fit(features, labels)
        # Monkey patch these so that we know the feature config and dimensionality.
        rf.feature_ndim = ndim
        rf.feature_config = serialized_feature_config
        out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(rf, f)

    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests))


def _score_based_points(
    score_function,
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples,
):
    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    score = score_function(pred, labels)
    assert len(score) == len(features)

    # get training samples based on the label-prediction diff
    samples = []
    nc = len(np.unique(labels))
    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // nc
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        this_samples = class_indices[np.argsort(score[class_indices])[::-1][:n_samples_class]]
        samples.append(this_samples)
    samples = np.concatenate(samples)

    # get the features and labels, add from previous rf if specified
    features, labels = features[samples], labels[samples]
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    def score(pred, labels):
        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        return np.abs(onehot - pred).sum(axis=1)

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    def score(pred, labels):
        assert pred.ndim == 2
        channel_sorted = np.sort(pred, axis=1)
        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
        return uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """@private
    """
    def score(pred, labels):
        assert pred.ndim == 2

        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        diff = np.abs(onehot - pred).sum(axis=1)

        channel_sorted = np.sort(pred, axis=1)
        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
        return alpha * diff + (1.0 - alpha) * uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    samples = []
    nc = len(np.unique(labels))
    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // nc
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=[25, 25],
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    # check inputs
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape

    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # compute the difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # reshape diff to image shape
    # we need to also take into account the mask here, and if we apply any masking
    # because we can't directly reshape if we have it
    if mask.sum() != mask.size:
        # get the diff image
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        # inflate the features
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # inflate the labels (with -1 so this will not be sampled)
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # get the number of classes (not counting ignore label)
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # sample in a class balanced way
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    samples = []
    for class_id in range(nc):
        # smooth either with gaussian or 1-kernel
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
        else:
            kernel = np.ones(tile_shape)
            diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode="constant")

        # get training samples based on tiles around maxima of the label-prediction diff
        # do this in a class-specific way to ensure that each class is sampled
        # get maxima of the label-prediction diff (they seem to be sorted already)
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=tuple([s // 2 for s in tile_shape])
        )

        # get indices of tiles around maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(
                    center[d]-tile_shape[d]//2,
                    center[d]+tile_shape[d]//2 + 1,
                    None
                ) for d in range(ndim)
            )
            grid = np.mgrid[tile_slice]
            samples_in_tile = grid.reshape(ndim, -1)
            samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape)
            tiles.append(samples_in_tile)

        # this (very rarely) fails due to empty tile list. Since we usually
        # accumulate the features this doesn't hurt much and we can continue
        try:
            tiles = np.concatenate(tiles)
            # take samples that belong to the current class
            this_samples = tiles[labels[tiles] == class_id][:n_samples_class]
            samples.append(this_samples)
        except ValueError:
            pass

    try:
        samples = np.concatenate(samples)
        features, labels = features[samples], labels[samples]

        # get the features and labels, add from previous rf if specified
        if accumulate_samples:
            features = np.concatenate([last_forest.train_features, features], axis=0)
            labels = np.concatenate([last_forest.train_labels, labels], axis=0)
    except ValueError:
        features, labels = last_forest.train_features, last_forest.train_labels
        warnings.warn(
            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
        )

    return features, labels


def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """@private
    """
    samples = []
    nc = len(np.unique(labels))
    # sample in a class balanced way
    # take all pixels from minority class
    # and choose same amount from other classes randomly
    n_samples_class = np.unique(labels, return_counts=True)[1].min()
    for class_id in range(nc):
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # accumulate
    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


SAMPLING_STRATEGIES = {
    "random_points": random_points,
    "uncertain_points": uncertain_points,
    "uncertain_worst_points": uncertain_worst_points,
    "worst_points": worst_points,
    "worst_tiles": worst_tiles,
    "balanced_dense_accumulate": balanced_dense_accumulate,
}
"""@private
"""


def prepare_shallow2deep_advanced(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    forests_per_stage: int,
    sample_fraction_per_stage: float,
    sampling_strategy: Union[str, Callable] = "worst_points",
    sampling_kwargs: Dict = {},
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None:
    """Prepare shallow2deep enhancer training by pre-training random forests.

    This function implements an advanced training procedure compared to `prepare_shallow2deep`.
    The 'sampling_strategy' argument determines an advanced sampling strategy,
    which selects the samples to use for training the random forests.
    The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage.
    The random forests in stage 0 are trained from random balanced labels.
    For the other stages 'sampling_strategy' determines the strategy; it has to be a function with the signature
    '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage)'
    that returns the sampled features and labels. See for example the 'worst_points' function.
    Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
    - "random_points": Select random points.
    - "uncertain_points": Select the points with the highest uncertainty.
    - "uncertain_worst_points": Select the points with the highest uncertainty and the worst accuracies.
    - "worst_points": Select the points with the worst accuracies.
    - "worst_tiles": Select the tiles with the worst accuracies.
    - "balanced_dense_accumulate": Balanced dense accumulation.

    Args:
        raw_paths: The file paths to the raw data. May also be a single file.
        raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
        label_paths: The file paths to the label data. May also be a single file.
        label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
        patch_shape_min: The minimal patch shape loaded for training a random forest.
        patch_shape_max: The maximal patch shape loaded for training a random forest.
        n_forests: The number of random forests to train.
        n_threads: The number of threads for parallelizing the training.
        output_folder: The folder for saving the random forests.
        ndim: The dimensionality of the data.
        forests_per_stage: The number of forests to train per stage.
        sample_fraction_per_stage: The fraction of samples to use per stage.
        sampling_strategy: The sampling strategy.
        sampling_kwargs: The keyword arguments for the sampling strategy.
        raw_transform: A transform to apply to the raw data before computing features on it.
        label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
        rois: Regions of interest for the training data.
        is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset.
            If None, this will be determined from the data.
        balance_labels: Whether to balance the training labels for the random forest.
        filter_config: The configuration for the image filters that are used to compute features for the random forest.
        sampler: A sampler to reject samples from training.
        rf_kwargs: Keyword arguments for creating the random forest.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    forests = []
    n_stages = n_forests // forests_per_stage if n_forests % forests_per_stage == 0 else\
        n_forests // forests_per_stage + 1

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES, \
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # sample random patch with dataset
            raw, labels = ds[rf_id]

            # cast to numpy and remove channel axis
            # need to update this to support multi-channel input data and/or multi class prediction
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # monkey patch original shape to sampling_kwargs
            # deepcopy needed due to multithreading
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # only balance samples for the first (densely trained) rfs
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            if forests:  # we have forests: apply the sampling strategy
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # sample randomly
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # fit the random forest
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # monkey patch these so that we know the feature config and dimensionality
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # save the random forest, update pbar, return it
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # monkey patch the training data and labels so we can re-use it in later stages
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            with futures.ThreadPoolExecutor(n_threads) as tp:
                this_forests = list(tp.map(
                    _train_rf, range(forests_per_stage * stage, forests_per_stage * (stage + 1))
                ))
            forests.extend(this_forests)
def prepare_shallow2deep(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None
Prepare shallow2deep enhancer training by pre-training random forests.
Arguments:
- raw_paths: The file paths to the raw data. May also be a single file.
- raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
- label_paths: The file paths to the label data. May also be a single file.
- label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
- patch_shape_min: The minimal patch shape loaded for training a random forest.
- patch_shape_max: The maximal patch shape loaded for training a random forest.
- n_forests: The number of random forests to train.
- n_threads: The number of threads for parallelizing the training.
- output_folder: The folder for saving the random forests.
- ndim: The dimensionality of the data.
- raw_transform: A transform to apply to the raw data before computing features on it.
- label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
- rois: Regions of interest for the training data.
- is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
- balance_labels: Whether to balance the training labels for the random forest.
- filter_config: The configuration for the image filters that are used to compute features for the random forest.
- sampler: A sampler to reject samples from training.
- rf_kwargs: Keyword arguments for creating the random forest.
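For orientation, a minimal usage sketch. The file names, dataset keys, patch shapes, and the n_estimators value are hypothetical; any extra keyword arguments are forwarded to sklearn's RandomForestClassifier.

    from torch_em.shallow2deep.prepare_shallow2deep import prepare_shallow2deep

    # Hypothetical HDF5 volume with internal datasets "raw" and "labels".
    # Labels are expected to be class ids 0, ..., N-1, with -1 marking ignored pixels
    # (pass a label_transform to convert your annotations accordingly).
    prepare_shallow2deep(
        raw_paths="train.h5", raw_key="raw",
        label_paths="train.h5", label_key="labels",
        patch_shape_min=(32, 256, 256),  # smallest patch sampled for a forest
        patch_shape_max=(64, 384, 384),  # largest patch sampled for a forest
        n_forests=25,                    # one patch is drawn per forest
        n_threads=4,
        output_folder="./rfs",
        ndim=3,
        n_estimators=200,                # forwarded to sklearn.ensemble.RandomForestClassifier
    )

    # Each forest is pickled to ./rfs/rf_0000.pkl, ./rfs/rf_0001.pkl, ... with the
    # feature configuration and dimensionality attached as attributes.
    import pickle
    with open("./rfs/rf_0000.pkl", "rb") as f:
        rf = pickle.load(f)
    print(rf.feature_ndim, rf.feature_config)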
def prepare_shallow2deep_advanced(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    label_paths: Union[str, Sequence[str]],
    label_key: Optional[str],
    patch_shape_min: Tuple[int, ...],
    patch_shape_max: Tuple[int, ...],
    n_forests: int,
    n_threads: int,
    output_folder: str,
    ndim: int,
    forests_per_stage: int,
    sample_fraction_per_stage: float,
    sampling_strategy: Union[str, Callable] = "worst_points",
    sampling_kwargs: Dict = {},
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    is_seg_dataset: Optional[bool] = None,
    balance_labels: bool = True,
    filter_config: Optional[Dict] = None,
    sampler: Optional[Callable] = None,
    **rf_kwargs,
) -> None
Prepare shallow2deep enhancer training by pre-training random forests.
This function implements an advanced training procedure compared to prepare_shallow2deep.
The 'sampling_strategy' argument determines an advanced sampling strategy,
which selects the samples to use for training the random forests.
The random forest training operates in stages: the parameter 'forests_per_stage' determines how many forests
are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of samples is used per stage.
The random forests in stage 0 are trained from random balanced labels.
For the other stages 'sampling_strategy' determines the strategy; it has to be a function with the signature
'(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage)'
that returns the sampled features and labels. It should also accept additional keyword arguments ('**kwargs'),
since the mask of labeled pixels and the entries of 'sampling_kwargs' are passed through to it.
See for example the 'worst_points' function, or the minimal sketch below the list of pre-defined strategies.
Alternatively, one of the pre-defined strategies can be selected by passing one of the following names:
- "random_poinst": Select random points.
- "uncertain_points": Select points with the highest uncertainty.
- "uncertain_worst_points": Select the points with the highest uncertainty and worst accuracies.
- "worst_points": Select the points with the worst accuracies.
- "worst_tiles": Selectt the tiles with the worst accuracies.
- "balanced_dense_accumulate": Balanced dense accumulation.
Arguments:
- raw_paths: The file paths to the raw data. May also be a single file.
- raw_key: The name of the internal dataset for the raw data. Set to None for regular image formats such as tif.
- label_paths: The file paths to the label data. May also be a single file.
- label_key: The name of the internal dataset for the label data. Set to None for regular image formats such as tif.
- patch_shape_min: The minimal patch shape loaded for training a random forest.
- patch_shape_max: The maximal patch shape loaded for training a random forest.
- n_forests: The number of random forests to train.
- n_threads: The number of threads for parallelizing the training.
- output_folder: The folder for saving the random forests.
- ndim: The dimensionality of the data.
- forests_per_stage: The number of forests to train per stage.
- sample_fraction_per_stage: The fraction of samples to use per stage.
- sampling_strategy: The sampling strategy.
- sampling_kwargs: The keyword arguments for the sampling strategy.
- raw_transform: A transform to apply to the raw data before computing features on it.
- label_transform: A transform to apply to the label data before deriving targets for the random forest from it.
- rois: Regions of interest for the training data.
- is_seg_dataset: Whether to create a segmentation dataset or an image collection dataset. If None, this will be determined from the data.
- balance_labels: Whether to balance the training labels for the random forest.
- filter_config: The configuration for the image filters that are used to compute features for the random forest.
- sampler: A sampler to reject samples from training.
- rf_kwargs: Keyword arguments for creating the random forest.
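For orientation, a minimal usage sketch of the staged training; as before, the file names, dataset keys, patch shapes, and forest parameters are hypothetical.

    from torch_em.shallow2deep.prepare_shallow2deep import prepare_shallow2deep_advanced

    # Hypothetical HDF5 volume with internal datasets "raw" and "labels".
    prepare_shallow2deep_advanced(
        raw_paths="train.h5", raw_key="raw",
        label_paths="train.h5", label_key="labels",
        patch_shape_min=(32, 256, 256), patch_shape_max=(64, 384, 384),
        n_forests=50, n_threads=8,
        output_folder="./rfs_advanced",
        ndim=3,
        forests_per_stage=10,            # 50 forests are trained in 5 stages
        sample_fraction_per_stage=0.10,  # each forest keeps 10% of the annotated pixels
        sampling_strategy="worst_points",
        n_estimators=100,                # forwarded to sklearn.ensemble.RandomForestClassifier
    )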