torch_em.shallow2deep.prepare_shallow2deep
"""Train random forests on image filter responses for shallow2deep enhancer training."""
import os
import copy
import pickle
import warnings
from concurrent import futures
from glob import glob
from functools import partial

import numpy as np
import torch_em
from scipy.ndimage import gaussian_filter, convolve
from skimage.feature import peak_local_max
from sklearn.ensemble import RandomForestClassifier
from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
from tqdm import tqdm

import vigra
try:
    import fastfilters as filter_impl
except ImportError:
    import vigra.filters as filter_impl


class RFSegmentationDataset(torch_em.data.SegmentationDataset):
    """Segmentation dataset whose patch extent is drawn at random for every sample.

    `patch_shape_min` and `patch_shape_max` have to be set after construction;
    `_sample_bounding_box` asserts that both are available.
    """
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self):
        """Return a tuple of slices for a randomly sized and randomly placed patch."""
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        # np.random.randint excludes the upper bound, so pmax itself is only
        # ever used when pmin == pmax
        sample_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        # axes where the patch is at least as large as the data start at 0
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, sample_shape))


class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
    """Image collection dataset whose patch extent is drawn at random for every sample.

    `patch_shape_min` and `patch_shape_max` have to be set after construction;
    `_sample_bounding_box` asserts that both are available.
    """
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self, shape):
        """Return a tuple of slices for a randomly sized patch inside an image of `shape`.

        Raises:
            NotImplementedError: If the image is smaller than `patch_shape_max` along any axis.
        """
        # check the configuration before it is used; previously an unset maximum
        # failed with a TypeError in zip() and the assert was never reached
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
            raise NotImplementedError("Image padding is not supported yet.")
        # np.random.randint excludes the upper bound
        patch_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))


def _load_rf_segmentation_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
):
    """Build an `RFSegmentationDataset` (single path) or a `ConcatDataset` of them (list of paths).

    `rois` and `sampler` are taken out of `kwargs`; without a sampler a
    `MinForegroundSampler(min_fraction=0.01)` is used. For a list of paths
    `n_samples` is distributed over the datasets with `samples_to_datasets`.
    """
    rois = kwargs.pop("rois", None)
    sampler = kwargs.pop("sampler", None)
    sampler = sampler if sampler else torch_em.data.MinForegroundSampler(min_fraction=0.01)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
        ds = RFSegmentationDataset(
            raw_paths, raw_key, label_paths, label_key, roi=rois, patch_shape=patch_shape_min, sampler=sampler, **kwargs
        )
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois)
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = RFSegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i],
                patch_shape=patch_shape_min, sampler=sampler, **kwargs
            )
            dset.patch_shape_min = patch_shape_min
            dset.patch_shape_max = patch_shape_max
            ds.append(dset)
        ds = torch_em.data.ConcatDataset(*ds)
    return ds


def _load_rf_image_collection_dataset(
    raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, roi, **kwargs
):
    """Build an `RFImageCollectionDataset` or a `ConcatDataset` of them.

    Three input layouts are handled: a single folder plus glob keys, explicit
    lists of image paths (`raw_key is None`), or a list of folders plus glob keys.
    A 3d patch shape is accepted only if its first entry is 1.
    """
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        # keep the pattern in its own variable: the error message needs it
        # after the matches have been collected
        raw_pattern = os.path.join(rpath, rkey)
        rpath = glob(raw_pattern)
        rpath.sort()
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")
        lpath = glob(os.path.join(lpath, lkey))
        lpath.sort()
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        # use the roi of this dataset, not the enclosing (possibly list-valued) roi
        if this_roi is not None:
            rpath, lpath = rpath[this_roi], lpath[this_roi]

        return rpath, lpath

    def _check_patch(patch_shape):
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2
        return patch_shape

    patch_shape_min = _check_patch(patch_shape_min)
    patch_shape_max = _check_patch(patch_shape_max)

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs)
        ds.patch_shape_min = patch_shape_min
        ds.patch_shape_max = patch_shape_max
    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = RFImageCollectionDataset(
                rpath, lpath, patch_shape=patch_shape_min, n_samples=samples_per_ds[i], **kwargs
            )
            dset.patch_shape_min = patch_shape_min
            dset.patch_shape_max = patch_shape_max
            ds.append(dset)
        ds = torch_em.data.ConcatDataset(*ds)
    return ds


def _get_filters(ndim, filters_and_sigmas):
    """Return the validated list of `(filter, sigma)` pairs, using defaults if `None` is given."""
    # subset of ilastik default features
    if filters_and_sigmas is None:
        filters = [filter_impl.gaussianSmoothing,
                   filter_impl.laplacianOfGaussian,
                   filter_impl.gaussianGradientMagnitude,
                   filter_impl.hessianOfGaussianEigenvalues,
                   filter_impl.structureTensorEigenvalues]
        sigmas = [0.7, 1.6, 3.5, 5.0]
        # the last filter (structure tensor) additionally needs an outer scale
        filters_and_sigmas = [
            (filt, sigma) if i != len(filters) - 1 else (partial(filt, outerScale=0.5*sigma), sigma)
            for i, filt in enumerate(filters) for sigma in sigmas
        ]
    # validate the filter config
    assert isinstance(filters_and_sigmas, (list, tuple))
    for filt_and_sig in filters_and_sigmas:
        filt, sig = filt_and_sig
        assert callable(filt) or (isinstance(filt, str) and hasattr(filter_impl, filt))
        assert isinstance(sig, (float, tuple))
        if isinstance(sig, tuple):
            assert ndim is not None and len(sig) == ndim
            assert all(isinstance(sigg, float) for sigg in sig)
    return filters_and_sigmas


def _calculate_response(raw, filter_, sigma):
    """Apply `filter_` (a callable or the name of a filter function) to `raw` with `sigma`."""
    if callable(filter_):
        return filter_(raw, sigma)

    # filter_ is still string, convert it to function
    # fastfilters does not support passing sigma as tuple
    func = getattr(vigra.filters, filter_) if isinstance(sigma, tuple) else getattr(filter_impl, filter_)

    # special case since additional argument outerScale
    # is needed for structureTensorEigenvalues functions
    if filter_ == "structureTensorEigenvalues":
        outerScale = tuple([s*2 for s in sigma]) if isinstance(sigma, tuple) else 2*sigma
        return func(raw, sigma, outerScale=outerScale)

    return func(raw, sigma)


def _apply_filters(raw, filters_and_sigmas):
    """Compute all filter responses and return them as an array with one column per response channel."""
    features = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        # responses with an extra trailing axis contribute one feature per channel
        if response.ndim > raw.ndim:
            for c in range(response.shape[-1]):
                features.append(response[..., c].flatten())
        else:
            features.append(response.flatten())
    features = np.concatenate([ff[:, None] for ff in features], axis=1)
    return features


def _apply_filters_with_mask(raw, filters_and_sigmas, mask):
    """Like `_apply_filters`, but only keep the responses at positions selected by `mask`."""
    features = []
    for filter_, sigma in filters_and_sigmas:
        response = _calculate_response(raw, filter_, sigma)
        if response.ndim > raw.ndim:
            for c in range(response.shape[-1]):
                features.append(response[..., c][mask])
        else:
            features.append(response[mask])
    features = np.concatenate([ff[:, None] for ff in features], axis=1)
    return features


def _balance_labels(labels, mask):
    """Remove random positions from `mask` (in place) until every class has as many entries as the rarest one.

    The class ids found under the mask must be exactly `0 .. n_classes - 1`.
    """
    class_ids, label_counts = np.unique(labels[mask], return_counts=True)
    n_classes = len(class_ids)
    assert class_ids.tolist() == list(range(n_classes))

    min_class = class_ids[np.argmin(label_counts)]
    n_labels = label_counts[min_class]

    for class_id in class_ids:
        if class_id == min_class:
            continue
        n_discard = label_counts[class_id] - n_labels
        # sample from the current class
        # shuffle the positions and only keep up to n_labels in the mask
        label_pos = np.where(labels == class_id)
        discard_ids = np.arange(len(label_pos[0]))
        np.random.shuffle(discard_ids)
        discard_ids = discard_ids[:n_discard]
        discard_mask = tuple(pos[discard_ids] for pos in label_pos)
        mask[discard_mask] = False

    assert mask.sum() == n_classes * n_labels
    return mask


def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False):
    """Compute per-pixel features and labels for all pixels whose label is not -1.

    Returns `(features, labels)` or, with `return_mask=True`, `(features, labels, mask)`.
    """
    # find the mask for where we compute filters and labels
    # by default we exclude everything that has label -1
    assert labels.shape == raw.shape
    mask = labels != -1
    if balance_labels:
        mask = _balance_labels(labels, mask)
    labels = labels[mask]
    assert labels.ndim == 1
    features = _apply_filters_with_mask(raw, filters_and_sigmas, mask)
    assert features.ndim == 2
    assert len(features) == len(labels)
    if return_mask:
        return features, labels, mask
    else:
        return features, labels


def _prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    ndim,
    raw_transform,
    label_transform,
    rois,
    is_seg_dataset,
    filter_config,
    sampler,
):
    """Validate the inputs and return the sampling dataset and the filter configuration."""
    assert len(patch_shape_min) == len(patch_shape_max)
    assert all(maxs >= mins for maxs, mins in zip(patch_shape_max, patch_shape_min))
    check_paths(raw_paths, label_paths)

    # get the correct dataset
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)
    if is_seg_dataset:
        ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
                                           patch_shape_min, patch_shape_max,
                                           raw_transform=raw_transform, label_transform=label_transform,
                                           rois=rois, n_samples=n_forests, sampler=sampler)
    else:
        ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
                                               patch_shape_min, patch_shape_max, roi=rois,
                                               raw_transform=raw_transform, label_transform=label_transform,
                                               n_samples=n_forests)

    assert len(ds) == n_forests, f"{len(ds)}, {n_forests}"
    filters_and_sigmas = _get_filters(ndim, filter_config)
    return ds, filters_and_sigmas


def _serialize_feature_config(filters_and_sigmas):
    """Replace filter callables by their names so that the config can be stored with the forest."""
    feature_config = [
        (filt if isinstance(filt, str)
         else (filt.func.__name__ if isinstance(filt, partial) else filt.__name__),
         sigma)
        for filt, sigma in filters_and_sigmas
    ]
    return feature_config


def prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    balance_labels=True,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Train `n_forests` random forests on random patches and pickle them to `output_folder`.

    Each forest is fitted on the filter responses of one patch drawn from the
    dataset and saved as `rf_XXXX.pkl`. `rf_kwargs` are passed to
    `RandomForestClassifier`.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    def _train_rf(rf_id):
        # sample random patch with dataset
        raw, labels = ds[rf_id]
        # cast to numpy and remove channel axis
        # need to update this to support multi-channel input data and/or multi class prediction
        raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"
        features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels)
        rf = RandomForestClassifier(**rf_kwargs)
        rf.fit(features, labels)
        # monkey patch these so that we know the feature config and dimensionality
        rf.feature_ndim = ndim
        rf.feature_config = serialized_feature_config
        out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(rf, f)

    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests))


def _score_based_points(
    score_function,
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples,
):
    """Select, per class, the samples with the highest `score_function(pred, labels)`.

    The prediction comes from the forest of the previous stage. With
    `accumulate_samples` the training data of that forest is prepended.
    """
    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    score = score_function(pred, labels)
    assert len(score) == len(features)

    # get training samples based on the label-prediction diff
    samples = []
    # iterate over the ids that are present instead of assuming 0 .. nc - 1
    class_ids = np.unique(labels)
    nc = len(class_ids)
    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // nc
    for class_id in class_ids:
        class_indices = np.where(labels == class_id)[0]
        this_samples = class_indices[np.argsort(score[class_indices])[::-1][:n_samples_class]]
        samples.append(this_samples)
    samples = np.concatenate(samples)

    # get the features and labels, add from previous rf if specified
    features, labels = features[samples], labels[samples]
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample the points where the previous forest's prediction differs most from the labels."""
    def score(pred, labels):
        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        return np.abs(onehot - pred).sum(axis=1)

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample points scored by the margin between the two largest predicted class probabilities.

    NOTE(review): the margin is large for confident predictions and
    `_score_based_points` keeps the highest scores, so this selects the points
    with the largest margin. Confirm whether the score should be negated.
    """
    def score(pred, labels):
        assert pred.ndim == 2
        channel_sorted = np.sort(pred, axis=1)
        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
        return uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """Sample points by `alpha * diff + (1 - alpha) * margin`, combining `worst_points` and `uncertain_points`."""
    def score(pred, labels):
        assert pred.ndim == 2

        # labels to one-hot encoding
        unique, inverse = np.unique(labels, return_inverse=True)
        onehot = np.eye(unique.shape[0])[inverse]
        # compute the difference between labels and prediction
        diff = np.abs(onehot - pred).sum(axis=1)

        channel_sorted = np.sort(pred, axis=1)
        uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2]
        return alpha * diff + (1.0 - alpha) * uncertainty

    return _score_based_points(
        score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples
    )


def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Draw a class-balanced random subset of the samples.

    `int(sample_fraction_per_stage * len(features))` samples are split evenly
    over the classes present in `labels`; a class with too few samples is drawn
    with replacement. From the second stage on, the training data of the
    previous stage's forest is prepended if `accumulate_samples` is set.
    """
    samples = []
    # iterate over the ids that are present instead of assuming 0 .. nc - 1,
    # otherwise a missing id leads to sampling from an empty index array
    class_ids = np.unique(labels)
    nc = len(class_ids)
    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // nc
    for class_id in class_ids:
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=[25, 25],
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """Sample tiles around local maxima of the (smoothed) label-prediction difference.

    For every class the difference image is smoothed, its local maxima are
    detected, and the pixels of that class inside tiles of `tile_shape` around
    the maxima are used as samples. If nothing can be sampled, the training data
    of the previous stage's forest is returned and a warning is issued.
    `tile_shape` is only read, never modified.
    """
    # check inputs
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape

    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # compute the difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # reshape diff to image shape
    # we need to also take into account the mask here, and if we apply any masking
    # because we can't directly reshape if we have it
    if mask.sum() != mask.size:
        # get the diff image
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        # inflate the features
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # inflate the labels (with -1 so this will not be sampled)
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # get the number of classes (not counting ignore label)
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # sample in a class balanced way
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    samples = []
    for class_id in range(nc):
        # smooth either with gaussian or 1-kernel
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant")
        else:
            kernel = np.ones(tile_shape)
            diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode="constant")

        # get training samples based on tiles around maxima of the label-prediction diff
        # do this in a class-specific way to ensure that each class is sampled
        # get maxima of the label-prediction diff (they seem to be sorted already)
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=tuple([s // 2 for s in tile_shape])
        )

        # get indices of tiles around maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(
                    center[d]-tile_shape[d]//2,
                    center[d]+tile_shape[d]//2 + 1,
                    None
                ) for d in range(ndim)
            )
            grid = np.mgrid[tile_slice]
            samples_in_tile = grid.reshape(ndim, -1)
            samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape)
            tiles.append(samples_in_tile)

        # (very rarely) no maxima are found and the tile list is empty. Since we usually
        # accumulate the features this doesn't hurt much and we can continue
        if not tiles:
            continue
        tiles = np.concatenate(tiles)
        # take samples that belong to the current class
        this_samples = tiles[labels[tiles] == class_id][:n_samples_class]
        samples.append(this_samples)

    if not samples:
        # nothing was sampled for any class: fall back to the previous training data
        warnings.warn(
            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
        )
        return last_forest.train_features, last_forest.train_labels

    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # get the features and labels, add from previous rf if specified
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Take as many random samples from every class as the rarest class has, and accumulate them."""
    samples = []
    # sample in a class balanced way
    # take all pixels from minority class
    # and choose same amount from other classes randomly
    # iterate over the ids that are present instead of assuming 0 .. nc - 1
    class_ids, class_counts = np.unique(labels, return_counts=True)
    n_samples_class = class_counts.min()
    for class_id in class_ids:
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # accumulate
    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


SAMPLING_STRATEGIES = {
    "random_points": random_points,
    "uncertain_points": uncertain_points,
    "uncertain_worst_points": uncertain_worst_points,
    "worst_points": worst_points,
    "worst_tiles": worst_tiles,
    "balanced_dense_accumulate": balanced_dense_accumulate,
}


def prepare_shallow2deep_advanced(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    forests_per_stage,
    sample_fraction_per_stage,
    sampling_strategy="worst_points",
    sampling_kwargs={},
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Advanced training of random forests for shallow2deep enhancer training.

    This function accepts the 'sampling_strategy' parameter, which allows implementing custom
    sampling strategies for the samples used for training the random forests.
    Training operates in stages: the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' which fraction of the samples is
    taken per stage. The random forests in stage 0 are trained from random balanced labels.
    For the other stages 'sampling_strategy' specifies the strategy; it is either the name of an
    entry in 'SAMPLING_STRATEGIES' or a function that is called as
    'sampling_strategy(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage,
    mask=mask, img_shape=..., **sampling_kwargs)' and returns the sampled features and labels.
    See the 'worst_points' function in this file for an example implementation.

    'sampling_kwargs' is deep-copied for every forest and never modified.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    forests = []
    n_stages = n_forests // forests_per_stage if n_forests % forests_per_stage == 0 else\
        n_forests // forests_per_stage + 1

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES,\
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # sample random patch with dataset
            raw, labels = ds[rf_id]

            # cast to numpy and remove channel axis
            # need to update this to support multi-channel input data and/or multi class prediction
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # monkey patch original shape to sampling_kwargs
            # deepcopy needed due to multithreading
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # only balance samples for the first (densely trained) rfs
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            if forests:  # we have forests: apply the sampling strategy
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # sample randomly
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # fit the random forest
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # monkey patch these so that we know the feature config and dimensionality
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # save the random forest, update pbar, return it
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # monkey patch the training data and labels so we can re-use it in later stages
            # (done after pickling, so the training data is not written to disk)
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            # clip the last stage so that exactly n_forests forests are trained
            # when n_forests is not a multiple of forests_per_stage
            first_id = forests_per_stage * stage
            last_id = min(forests_per_stage * (stage + 1), n_forests)
            with futures.ThreadPoolExecutor(n_threads) as tp:
                this_forests = list(tp.map(_train_rf, range(first_id, last_id)))
                forests.extend(this_forests)
class RFSegmentationDataset(torch_em.data.SegmentationDataset):
    """Segmentation dataset that samples patches of varying extent.

    The extent of every patch is drawn between `patch_shape_min` and
    `patch_shape_max`; both have to be assigned before a patch is sampled.
    """
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        """Smallest patch extent per axis."""
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        """Upper bound of the patch extent per axis."""
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self):
        """Draw a random patch extent and position and return them as a tuple of slices."""
        assert not (self._patch_shape_min is None or self._patch_shape_max is None)

        # draw the extent first; np.random.randint excludes its upper bound, so
        # the maximum is only used on axes where minimum and maximum coincide
        extents = []
        for lower, upper in zip(self._patch_shape_min, self._patch_shape_max):
            if lower == upper:
                extents.append(lower)
            else:
                extents.append(np.random.randint(lower, upper))

        # then draw the offset; axes without any slack start at 0
        offsets = []
        for full_size, extent in zip(self.shape, extents):
            slack = full_size - extent
            offsets.append(np.random.randint(0, slack) if slack > 0 else 0)

        bounding_box = [slice(begin, begin + extent) for begin, extent in zip(offsets, extents)]
        return tuple(bounding_box)
Inherited Members
class RFImageCollectionDataset(torch_em.data.ImageCollectionDataset):
    """Image collection dataset that samples patches of varying extent.

    The extent of every patch is drawn between `patch_shape_min` and
    `patch_shape_max`; both have to be assigned before a patch is sampled.
    """
    _patch_shape_min = None
    _patch_shape_max = None

    @property
    def patch_shape_min(self):
        return self._patch_shape_min

    @patch_shape_min.setter
    def patch_shape_min(self, value):
        self._patch_shape_min = value

    @property
    def patch_shape_max(self):
        return self._patch_shape_max

    @patch_shape_max.setter
    def patch_shape_max(self, value):
        self._patch_shape_max = value

    def _sample_bounding_box(self, shape):
        """Return a tuple of slices for a randomly sized patch inside an image of `shape`.

        Args:
            shape: Shape of the image the patch is sampled from.

        Raises:
            AssertionError: If `patch_shape_min` or `patch_shape_max` has not been set.
            NotImplementedError: If the image is smaller than `patch_shape_max` along any axis.
        """
        # validate the configuration before it is used: with the check after the
        # size comparison an unset maximum failed inside zip() with a TypeError
        assert self._patch_shape_min is not None and self._patch_shape_max is not None
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape_max)):
            raise NotImplementedError("Image padding is not supported yet.")
        # np.random.randint excludes its upper bound, so the maximum is only
        # used on axes where minimum and maximum coincide
        patch_shape = [
            pmin if pmin == pmax else np.random.randint(pmin, pmax)
            for pmin, pmax in zip(self._patch_shape_min, self._patch_shape_max)
        ]
        # axes where the patch fills the image completely start at 0
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))
An abstract class representing a `Dataset`.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite `__getitem__()`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite `__len__()`, which is expected to return the size of the dataset by many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader`. Subclasses could also optionally implement `__getitems__()` to speed up the loading of batched samples. This method accepts a list of indices of the samples in a batch and returns a list of samples.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
def prepare_shallow2deep(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    balance_labels=True,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Fit `n_forests` random forests on random patches and pickle them to `output_folder`.

    Every forest is trained on the filter responses of one patch taken from the
    dataset and written to `rf_XXXX.pkl`. The training runs in a thread pool
    with `n_threads` workers; `rf_kwargs` are forwarded to
    `RandomForestClassifier`.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    def _fit_and_save(forest_id):
        # indexing the dataset yields a patch and its labels
        patch, target = ds[forest_id]
        # convert to numpy and drop singleton axes (e.g. the channel axis);
        # multi-channel input or multi-class prediction is not supported here yet
        patch = patch.numpy().squeeze()
        target = target.numpy().astype("int8").squeeze()
        assert patch.ndim == target.ndim == ndim, f"{patch.ndim}, {target.ndim}, {ndim}"

        train_x, train_y = _get_features_and_labels(patch, target, filters_and_sigmas, balance_labels)
        forest = RandomForestClassifier(**rf_kwargs)
        forest.fit(train_x, train_y)

        # store the feature setup on the forest so that it is pickled with it
        forest.feature_ndim = ndim
        forest.feature_config = serialized_feature_config

        save_path = os.path.join(output_folder, f"rf_{forest_id:04d}.pkl")
        with open(save_path, "wb") as handle:
            pickle.dump(forest, handle)

    with futures.ThreadPoolExecutor(n_threads) as pool:
        # consume the iterator so that all forests are trained and errors surface
        for _ in tqdm(pool.map(_fit_and_save, range(n_forests)), desc="Train RFs", total=n_forests):
            pass
def worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sampling strategy scoring each sample by its label-prediction error.

    The score of a sample is the summed absolute difference between the one-hot
    encoded label and the predicted class probabilities. The sampling itself is
    delegated to '_score_based_points'. Extra keyword arguments are ignored.
    """
    def score(pred, labels):
        # labels to one-hot encoding (columns follow the sorted unique label values)
        class_values, class_index = np.unique(labels, return_inverse=True)
        n_classes = class_values.shape[0]
        target = np.eye(n_classes)[class_index]
        # per-sample L1 distance between the one-hot labels and the prediction
        error = np.abs(target - pred)
        return error.sum(axis=1)

    return _score_based_points(
        score,
        features, labels, rf_id,
        forests, forests_per_stage,
        sample_fraction_per_stage,
        accumulate_samples,
    )
def uncertain_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sampling strategy scoring each sample by the gap between its two largest class probabilities.

    The sampling itself is delegated to '_score_based_points'. Extra keyword
    arguments are ignored; the labels are not used by the score.

    NOTE(review): the returned value is the margin 'largest - second largest',
    which is large for confident predictions. Confirm against
    '_score_based_points' that the ranking direction matches the intent.
    """
    def score(pred, labels):
        assert pred.ndim == 2
        # ascending sort per row: the last two columns hold the two largest probabilities
        ranked = np.sort(pred, axis=1)
        largest, second_largest = ranked[:, -1], ranked[:, -2]
        return largest - second_largest

    return _score_based_points(
        score,
        features, labels, rf_id,
        forests, forests_per_stage,
        sample_fraction_per_stage,
        accumulate_samples,
    )
def uncertain_worst_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    alpha=0.5,
    **kwargs,
):
    """Sampling strategy blending the label-prediction error with the top-two probability margin.

    The score is 'alpha * error + (1 - alpha) * margin', where 'error' is the summed
    absolute difference between one-hot labels and prediction and 'margin' is the gap
    between the two largest predicted probabilities. The sampling itself is delegated
    to '_score_based_points'. Extra keyword arguments are ignored.
    """
    def score(pred, labels):
        assert pred.ndim == 2

        # labels to one-hot encoding (columns follow the sorted unique label values)
        class_values, class_index = np.unique(labels, return_inverse=True)
        target = np.eye(class_values.shape[0])[class_index]
        # per-sample L1 distance between the one-hot labels and the prediction
        error = np.abs(target - pred).sum(axis=1)

        # gap between the largest and second largest probability of each sample
        ranked = np.sort(pred, axis=1)
        margin = ranked[:, -1] - ranked[:, -2]

        return alpha * error + (1.0 - alpha) * margin

    return _score_based_points(
        score,
        features, labels, rf_id,
        forests, forests_per_stage,
        sample_fraction_per_stage,
        accumulate_samples,
    )
def random_points(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Randomly sample training data in a class balanced way.

    In total 'sample_fraction_per_stage * len(features)' samples are drawn, split
    evenly over the label values that are present in 'labels'. A class with fewer
    members than its share is sampled with replacement.

    If 'accumulate_samples' is set and 'rf_id >= forests_per_stage', the training
    data stored on the corresponding forest of the previous stage
    ('forests[rf_id - forests_per_stage]') is prepended to the new samples.
    Extra keyword arguments are ignored.

    Returns the sampled features and labels.
    """
    # Iterate over the label values that actually occur instead of 'range(n_classes)':
    # a patch may lack class 0 or contain non-consecutive label values, in which case
    # 'range' would select an empty class (np.random.choice fails on empty input)
    # and skip an existing one.
    class_ids = np.unique(labels)

    # sample in a class balanced way
    n_samples = int(sample_fraction_per_stage * len(features))
    n_samples_class = n_samples // len(class_ids)

    samples = []
    for class_id in class_ids:
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    mask,
    tile_shape=(25, 25),
    smoothing_sigma=None,
    accumulate_samples=True,
    **kwargs,
):
    """Sample training data from tiles around the largest errors of the previous stage's forest.

    The forest 'forests[rf_id - forests_per_stage]' predicts on 'features'. For each class
    the absolute difference between one-hot labels and prediction is arranged as an image,
    smoothed (gaussian with 'smoothing_sigma' if given, otherwise a sum over 'tile_shape'),
    and tiles of 'tile_shape' are placed around its local maxima. From these tiles, up to
    'sample_fraction_per_stage * n_pixels / n_classes' samples of the class are taken.

    'img_shape' is the shape of the image the features were computed from (2d or 3d) and
    'mask' a boolean array marking the pixels for which 'features' and 'labels' are given;
    when the mask excludes pixels these are filled with zero features and label -1.
    The class ids are used to index the channels of the prediction, so labels are assumed
    to be 0, ..., n_classes - 1 in the column order of 'predict_proba' -- TODO confirm.

    If 'accumulate_samples' is set, the training data stored on the previous forest is
    prepended. If no samples could be selected at all, the previous forest's training
    data is returned and a warning is issued. Extra keyword arguments are ignored.

    Returns the sampled features and labels.
    """
    # check inputs
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tile_shape) == ndim, tile_shape
    # the shape is concatenated with tuples below, so make sure a list works as well
    img_shape = tuple(img_shape)

    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # compute the difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # reshape diff to image shape
    # we need to also take into account the mask here, and if we apply any masking
    # because we can't directly reshape if we have it
    if mask.sum() != mask.size:
        # get the diff image
        diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype)
        diff_img[mask] = diff
        # inflate the features
        full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype)
        full_features[mask.ravel()] = features
        features = full_features
        # inflate the labels (with -1 so this will not be sampled)
        full_labels = np.full(mask.size, -1, dtype="int8")
        full_labels[mask.ravel()] = labels
        labels = full_labels
    else:
        diff_img = diff.reshape(img_shape + (-1,))

    # get the number of classes (not counting ignore label)
    class_ids = np.unique(labels)
    nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids)

    # sample in a class balanced way
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc

    # loop invariants: half tile extent and the summation kernel (only needed without gaussian)
    half_tile = tuple(s // 2 for s in tile_shape)
    kernel = None if smoothing_sigma else np.ones(tile_shape)

    samples = []
    for class_id in range(nc):
        class_diff = diff_img[..., class_id]
        # smooth either with gaussian or 1-kernel
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(class_diff, smoothing_sigma, mode="constant")
        else:
            diff_img_smooth = convolve(class_diff, kernel, mode="constant")

        # get training samples based on tiles around maxima of the label-prediction diff
        # do this in a class-specific way to ensure that each class is sampled
        # get maxima of the label-prediction diff (they seem to be sorted already)
        # excluding the border by half a tile keeps every tile inside the image
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tile_shape),
            exclude_border=half_tile,
        )

        # get flat indices of the pixels in the tiles around the maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple(
                slice(center[d] - half_tile[d], center[d] + half_tile[d] + 1) for d in range(ndim)
            )
            tile_coords = np.mgrid[tile_slice].reshape(ndim, -1)
            tiles.append(np.ravel_multi_index(tile_coords, img_shape))

        # (very rarely) no maxima are found. Since we usually accumulate
        # the features this doesn't hurt much and we can continue
        if not tiles:
            continue
        tiles = np.concatenate(tiles)
        # take samples that belong to the current class
        samples.append(tiles[labels[tiles] == class_id][:n_samples_class])

    if samples:
        samples = np.concatenate(samples)

    # nothing was selected for any class: fall back to the training data of the previous forest
    if len(samples) == 0:
        warnings.warn(
            f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}"
        )
        return last_forest.train_features, last_forest.train_labels

    features, labels = features[samples], labels[samples]

    # get the features and labels, add from previous rf if specified
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Sample all points of the minority class and equally many random points of every other class.

    If 'accumulate_samples' is set and 'rf_id >= forests_per_stage', the training
    data stored on the corresponding forest of the previous stage
    ('forests[rf_id - forests_per_stage]') is prepended to the new samples.
    'sample_fraction_per_stage' and extra keyword arguments are not used.

    Returns the sampled features and labels.
    """
    # Iterate over the label values that actually occur instead of 'range(n_classes)':
    # a patch may lack class 0 or contain non-consecutive label values, in which case
    # 'range' would select an empty class (np.random.choice fails on empty input)
    # and skip an existing one.
    class_ids, class_counts = np.unique(labels, return_counts=True)

    # sample in a class balanced way
    # take all pixels from minority class
    # and choose same amount from other classes randomly
    n_samples_class = class_counts.min()

    samples = []
    for class_id in class_ids:
        class_indices = np.where(labels == class_id)[0]
        this_samples = np.random.choice(
            class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
        )
        samples.append(this_samples)
    samples = np.concatenate(samples)
    features, labels = features[samples], labels[samples]

    # accumulate
    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels
def prepare_shallow2deep_advanced(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape_min,
    patch_shape_max,
    n_forests,
    n_threads,
    output_folder,
    ndim,
    forests_per_stage,
    sample_fraction_per_stage,
    sampling_strategy="worst_points",
    sampling_kwargs=None,
    raw_transform=None,
    label_transform=None,
    rois=None,
    is_seg_dataset=None,
    filter_config=None,
    sampler=None,
    **rf_kwargs,
):
    """Advanced training of random forests for shallow2deep enhancer training.

    This function accepts the 'sampling_strategy' parameter, which allows to implement custom
    sampling strategies for the samples used for training the random forests.
    Training operates in stages, the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' which fraction of the samples is
    taken per stage. The random forests in stage 0 are trained from random balanced labels.
    For the other stages 'sampling_strategy' enables specifying the strategy; it is either the
    name of a strategy registered in SAMPLING_STRATEGIES or a function that is called as
    'sampling_strategy(features, labels, rf_id, forests, forests_per_stage,
    sample_fraction_per_stage, mask=mask, img_shape=img_shape, **sampling_kwargs)'
    and returns the sampled features and labels. See the 'worst_points' function
    in this file for an example implementation.

    In total 'n_forests' forests are trained; if 'n_forests' is not a multiple of
    'forests_per_stage' the last stage contains fewer forests. Each forest is pickled to
    'output_folder' as 'rf_XXXX.pkl'. Additional keyword arguments ('rf_kwargs') are
    forwarded to the RandomForestClassifier constructor.
    """
    os.makedirs(output_folder, exist_ok=True)
    ds, filters_and_sigmas = _prepare_shallow2deep(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape_min, patch_shape_max, n_forests, ndim,
        raw_transform, label_transform, rois, is_seg_dataset,
        filter_config, sampler,
    )
    serialized_feature_config = _serialize_feature_config(filters_and_sigmas)

    if sampling_kwargs is None:
        sampling_kwargs = {}

    forests = []
    # number of stages, rounded up so that a trailing partial stage is included
    n_stages = (n_forests + forests_per_stage - 1) // forests_per_stage

    if isinstance(sampling_strategy, str):
        assert sampling_strategy in SAMPLING_STRATEGIES,\
            f"Invalid sampling strategy {sampling_strategy}, only support {list(SAMPLING_STRATEGIES.keys())}"
        sampling_strategy = SAMPLING_STRATEGIES[sampling_strategy]
    assert callable(sampling_strategy)

    with tqdm(total=n_forests) as pbar:

        def _train_rf(rf_id):
            # sample random patch with dataset
            raw, labels = ds[rf_id]

            # cast to numpy and remove channel axis
            # need to update this to support multi-channel input data and/or multi class prediction
            raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
            assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

            # monkey patch original shape to sampling_kwargs
            # deepcopy needed due to multithreading
            current_kwargs = copy.deepcopy(sampling_kwargs)
            current_kwargs["img_shape"] = raw.shape

            # only balance samples for the first (densely trained) rfs
            features, labels, mask = _get_features_and_labels(
                raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True
            )
            if forests:  # we have forests: apply the sampling strategy
                features, labels = sampling_strategy(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                    mask=mask,
                    **current_kwargs,
                )
            else:  # sample randomly
                features, labels = random_points(
                    features, labels, rf_id,
                    forests, forests_per_stage,
                    sample_fraction_per_stage,
                )

            # fit the random forest
            assert len(features) == len(labels)
            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(features, labels)
            # monkey patch these so that we know the feature config and dimensionality
            rf.feature_ndim = ndim
            rf.feature_config = serialized_feature_config

            # save the random forest, update pbar, return it
            out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl")
            with open(out_path, "wb") as f:
                pickle.dump(rf, f)

            # monkey patch the training data and labels so we can re-use it in later stages
            # (done after pickling, so the training data is not part of the saved forest)
            rf.train_features = features
            rf.train_labels = labels

            pbar.update(1)
            return rf

        for stage in range(n_stages):
            pbar.set_description(f"Train RFs for stage {stage}")
            first_id = forests_per_stage * stage
            # clip the last stage so that exactly n_forests forests are trained in total
            stop_id = min(first_id + forests_per_stage, n_forests)
            with futures.ThreadPoolExecutor(n_threads) as tp:
                this_forests = list(tp.map(_train_rf, range(first_id, stop_id)))
                # extend only after the stage has finished, so all forests of a stage
                # sample from the forests of the previous stages
                forests.extend(this_forests)
Advanced training of random forests for shallow2deep enhancer training.
This function accepts the 'sampling_strategy' parameter, which allows implementing custom sampling strategies for the samples used for training the random forests. Training operates in stages: the parameter 'forests_per_stage' determines how many forests are trained in each stage, and 'sample_fraction_per_stage' determines which fraction of the samples is taken per stage. The random forests in stage 0 are trained from random balanced labels. For the other stages 'sampling_strategy' enables specifying the strategy; it has to be a function with signature '(features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage)' that also accepts the keyword arguments 'mask' and 'img_shape' (plus any entries of 'sampling_kwargs'), and it must return the sampled features and labels. See the 'worst_points' function in this file for an example implementation.