elf.segmentation.features
import multiprocessing
from concurrent import futures
from typing import Dict, List, Optional, Tuple

import bioimage_cpp as bic
import numpy as np
from scipy.stats import kurtosis, skew
from skimage.measure import regionprops_table
from tqdm import tqdm

from .multicut import transform_probabilities_to_costs


# Map fastfilters/vigra filter names to bic.filters callables.
_BIC_FILTERS = {
    "gaussianSmoothing": bic.filters.gaussian_smoothing,
    "gaussianGradientMagnitude": bic.filters.gaussian_gradient_magnitude,
    "laplacianOfGaussian": bic.filters.laplacian_of_gaussian,
    "hessianOfGaussianEigenvalues": bic.filters.hessian_of_gaussian_eigenvalues,
    "structureTensorEigenvalues": bic.filters.structure_tensor_eigenvalues,
    "gaussianDerivative": bic.filters.gaussian_derivative,
}


def _apply_filter(filter_name, image, sigma):
    """@private

    Apply the filter registered as `filter_name` in `_BIC_FILTERS` with scale `sigma`.
    Inputs whose dtype is not float32, float64, uint8 or uint16 are cast to float32 first.
    """
    fu = _BIC_FILTERS[filter_name]
    if image.dtype not in (np.float32, np.float64, np.uint8, np.uint16):
        image = image.astype("float32")
    return fu(image, sigma)


#
# Region Adjacency Graph and Features
#

def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
    """Compute region adjacency graph of segmentation.

    Args:
        segmentation: The segmentation. Dtypes other than uint32, uint64, int32 and int64 are cast to uint32.
        n_labels: Deprecated; ignored. Kept for backwards-compatibility.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The region adjacency graph (`bioimage_cpp.graph.RegionAdjacencyGraph`).
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if segmentation.dtype not in (np.uint32, np.uint64, np.int32, np.int64):
        segmentation = segmentation.astype("uint32")
    rag = bic.graph.region_adjacency_graph(segmentation, number_of_threads=n_threads)
    return rag


def compute_boundary_features(
    rag,
    segmentation: np.ndarray,
    boundary_map: np.ndarray,
    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
    n_threads: Optional[int] = None,
) -> np.ndarray:
    """Compute edge features from boundary map.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        boundary_map: The boundary map.
        min_value: Deprecated; ignored.
        max_value: Deprecated; ignored.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features. Output has 12 columns
        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).

    Raises:
        ValueError: If segmentation and boundary map have different shapes.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if segmentation.shape != boundary_map.shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(segmentation.shape), str(boundary_map.shape)))
    features = bic.graph.features.edge_map_features_complex(
        rag, segmentation, boundary_map, number_of_threads=n_threads,
    )
    return features


def compute_affinity_features(
    rag,
    segmentation: np.ndarray,
    affinity_map: np.ndarray,
    offsets: List[List[int]],
    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
    n_threads: Optional[int] = None,
) -> np.ndarray:
    """Compute edge features from affinity map.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        affinity_map: The affinity map, channels first (one channel per offset).
        offsets: The offsets corresponding to the affinity channels.
        min_value: Deprecated; ignored.
        max_value: Deprecated; ignored.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features. Output has 12 columns
        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).

    Raises:
        ValueError: If the spatial shapes or the number of channels and offsets do not match.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if segmentation.shape != affinity_map.shape[1:]:
        raise ValueError("Incompatible shapes: %s, %s" % (str(segmentation.shape), str(affinity_map.shape[1:])))
    if len(offsets) != affinity_map.shape[0]:
        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets),
                                                                                  affinity_map.shape[0]))
    features = bic.graph.features.affinity_features_complex(
        rag, segmentation, affinity_map, offsets, number_of_threads=n_threads,
    )
    return features


def compute_boundary_mean_and_length(
    rag, segmentation: np.ndarray, input_: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Compute mean value and length of boundaries.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input map.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features with two columns (mean, size).

    Raises:
        ValueError: If segmentation and input map have different shapes.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if segmentation.shape != input_.shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(segmentation.shape), str(input_.shape)))
    features = bic.graph.features.edge_map_features(
        rag, segmentation, input_, number_of_threads=n_threads,
    )
    return features


# TODO generalize and move to elf.features.parallel
def _filter_2d(input_, filter_name, sigma, n_threads):
    """@private

    Apply a filter slice-by-slice along the first axis of `input_`, parallelized over the slices.
    The per-slice responses are stacked along axis 0 and always carry a trailing channel axis
    (of size 1 for scalar filter responses).
    """
    def _fz(inp):
        response = _apply_filter(filter_name, inp, sigma)
        # add the slice axis in front and, for scalar responses, a channel axis at the end
        if response.ndim == 2:
            response = response[None, ..., None]
        elif response.ndim == 3:
            response = response[None]
        else:
            raise RuntimeError("Invalid filter response")
        return response

    with futures.ThreadPoolExecutor(n_threads) as tp:
        tasks = [tp.submit(_fz, input_[z]) for z in range(input_.shape[0])]
        response = [t.result() for t in tasks]

    response = np.concatenate(response, axis=0)
    return response


def compute_boundary_features_with_filters(
    rag,
    segmentation: np.ndarray,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Optional[Dict[str, List[float]]] = None,
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            By default, gaussianSmoothing, laplacianOfGaussian and hessianOfGaussianEigenvalues
            are applied with the sigmas 1.6, 4.2 and 8.3 each.

    Returns:
        The edge features: the boundary features of every filter channel, concatenated along axis 1.

    Raises:
        ValueError: If no filter with at least one sigma value is given.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if filters is None:
        # built per call to avoid sharing a mutable default argument between calls
        filters = {
            "gaussianSmoothing": [1.6, 4.2, 8.3],
            "laplacianOfGaussian": [1.6, 4.2, 8.3],
            "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3],
        }
    filter_configs = [(filter_name, sigma) for filter_name, sigmas in filters.items() for sigma in sigmas]
    if not filter_configs:
        raise ValueError("Expected at least one filter with at least one sigma value.")

    if apply_2d:
        # apply 2d: each filter is applied slice-wise in parallel (see `_filter_2d`);
        # the filter configurations themselves are processed one after the other.
        def _compute_2d(filter_name, sigma):
            response = _filter_2d(input_, filter_name, sigma, n_threads)
            assert response.ndim == 4
            feats = [
                compute_boundary_features(rag, segmentation, response[..., chan], n_threads=n_threads)
                for chan in range(response.shape[-1])
            ]
            out = np.concatenate(feats, axis=1)
            assert len(out) == rag.number_of_edges
            return out

        features = [_compute_2d(filter_name, sigma) for filter_name, sigma in filter_configs]

    # apply 3d: we parallelize over the whole filter + feature computation
    # this can be very memory intensive, and it would be better to parallelize inside
    # of the loop, but 3d parallel filters in elf.parallel.filters are not working properly yet
    else:

        def _compute_3d(filter_name, sigma):
            response = _apply_filter(filter_name, input_, sigma)
            if response.ndim == input_.ndim:
                response = response[..., None]
            feats = [
                compute_boundary_features(rag, segmentation, response[..., chan], n_threads=1)
                for chan in range(response.shape[-1])
            ]
            out = np.concatenate(feats, axis=1)
            assert len(out) == rag.number_of_edges, f"{len(out)}, {rag.number_of_edges}"
            return out

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, filter_name, sigma) for filter_name, sigma in filter_configs]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.number_of_edges
    return features

not provide natively. 242# Each callback receives the region's cropped (regionmask, intensity_image); see 243# `_region_features`. The function names double as the keys in the regionprops table. 244def _quantiles(regionmask, intensity_image): 245 """@private""" 246 return np.percentile(intensity_image[regionmask], [0, 10, 25, 50, 75, 90, 100]) 247 248 249def _kurtosis(regionmask, intensity_image): 250 """@private""" 251 values = intensity_image[regionmask] 252 if values.size < 2 or values.min() == values.max(): 253 return 0.0 254 return kurtosis(values) 255 256 257def _skewness(regionmask, intensity_image): 258 """@private""" 259 values = intensity_image[regionmask] 260 if values.size < 2 or values.min() == values.max(): 261 return 0.0 262 return skew(values) 263 264 265def _variance(regionmask, intensity_image): 266 """@private""" 267 return np.var(intensity_image[regionmask]) 268 269 270def _sum(regionmask, intensity_image): 271 """@private""" 272 return intensity_image[regionmask].sum() 273 274 275# Map vigra `extractRegionFeatures` names to their source in a skimage regionprops table. 276# Names starting with "_" are computed via the extra-property callbacks above; the rest are 277# native regionprops properties (array-valued ones are expanded into "<name>-<i>" columns). 
_REGION_FEATURE_KEYS = {
    "Count": "num_pixels",
    "Maximum": "intensity_max",
    "Minimum": "intensity_min",
    "mean": "intensity_mean",
    "RegionCenter": "centroid",
    "Weighted<RegionCenter>": "centroid_weighted",
    "RegionRadii": "inertia_tensor_eigvals",
    "Quantiles": "_quantiles",
    "Kurtosis": "_kurtosis",
    "Skewness": "_skewness",
    "Variance": "_variance",
    "Sum": "_sum",
}
# Extra-property callbacks passed to `regionprops_table`, keyed by the table column they produce.
_REGION_FEATURE_EXTRA = {
    "_quantiles": _quantiles,
    "_kurtosis": _kurtosis,
    "_skewness": _skewness,
    "_variance": _variance,
    "_sum": _sum,
}


def _radii_from_inertia_eigvals(eigvals: np.ndarray) -> np.ndarray:
    """@private

    Convert per-region eigenvalues of skimage's inertia tensor into vigra-style region radii,
    i.e. square roots of the eigenvalues of the coordinate covariance matrix ``C``.

    skimage defines the inertia tensor as ``trace(C) * I - C``, so its eigenvalues are
    ``mu_i = trace(C) - lambda_i`` and ``sum(mu) = (ndim - 1) * trace(C)``. In 2d the sets
    ``{mu_i}`` and ``{lambda_i}`` coincide, so the eigenvalues are used as they are; in higher
    dimensions the covariance eigenvalues are recovered and sorted in descending order.

    Args:
        eigvals: Array of shape (n_regions, ndim) with the inertia tensor eigenvalues.

    Returns:
        The radii (float32), computed from the eigenvalues clipped to be non-negative.
    """
    eigvals = np.asarray(eigvals, dtype="float64")
    ndim = eigvals.shape[1]
    if ndim > 2:
        trace = eigvals.sum(axis=1, keepdims=True) / (ndim - 1)
        eigvals = np.sort(trace - eigvals, axis=1)[:, ::-1]
    return np.sqrt(np.maximum(eigvals, 0.0)).astype("float32")


def _region_features(input_map: np.ndarray, segmentation: np.ndarray, feature_names: List[str]) -> Dict:
    """@private

    Replacement for ``vigra.analysis.extractRegionFeatures`` based on
    ``skimage.measure.regionprops``. Returns a dict mapping each requested feature name to a
    dense array indexed by label id (``0 .. segmentation.max()``); scalar features are 1D and
    coordinate/quantile/radii features are 2D, matching the vigra layout. Missing label ids
    (gaps) stay zero.
    """
    if segmentation.dtype.kind not in "iu":
        segmentation = segmentation.astype("int64")
    max_label = int(segmentation.max())
    # skimage treats label 0 as background, so all labels are shifted by one below to keep the
    # original label 0. Upcast first if that shift would overflow (e.g. label 255 in uint8 data).
    if max_label >= np.iinfo(segmentation.dtype).max:
        segmentation = segmentation.astype("int64")

    keys = [_REGION_FEATURE_KEYS[name] for name in feature_names]
    native = tuple(dict.fromkeys(key for key in keys if not key.startswith("_")))
    extra = tuple(dict.fromkeys(_REGION_FEATURE_EXTRA[key] for key in keys if key.startswith("_")))

    table = regionprops_table(
        segmentation + 1, intensity_image=input_map.astype("float32", copy=False),
        properties=("label",) + native, extra_properties=(extra or None),
    )
    labels = np.asarray(table["label"]) - 1
    n_nodes = max_label + 1

    def _gather(base):
        # scalar properties are stored under their own name, array-valued ones as "<base>-<i>"
        if base in table:
            return np.asarray(table[base], dtype="float32")[:, None]
        cols, i = [], 0
        while f"{base}-{i}" in table:
            cols.append(np.asarray(table[f"{base}-{i}"], dtype="float32"))
            i += 1
        return np.stack(cols, axis=1)

    result = {}
    for name, base in zip(feature_names, keys):
        cols = _gather(base)
        if name == "RegionRadii":  # vigra returns radii = sqrt of the coordinate-covariance eigenvalues
            cols = _radii_from_inertia_eigvals(cols)
        dense = np.zeros((n_nodes, cols.shape[1]), dtype="float32")
        dense[labels] = cols
        result[name] = dense[:, 0] if dense.shape[1] == 1 else dense
    return result


def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from an input map accumulated over segmentation and mapped to edges.

    Per-node intensity statistics (count, kurtosis, max, min, quantiles, region radii, skewness,
    sum, variance) are combined for each edge (u, v) via their element-wise minimum, maximum,
    absolute difference and sum. The node centers (intensity-weighted and unweighted) are combined
    via their squared per-axis difference. NaN/inf values are replaced via `np.nan_to_num`.

    Args:
        uv_ids: The edge uv ids.
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: Currently unused; kept for backwards-compatibility.

    Returns:
        The edge features.
    """
    # compute the node features
    stat_feature_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                          "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_feature_names = ["Weighted<RegionCenter>", "RegionCenter"]
    feature_names = stat_feature_names + coord_feature_names
    node_features = _region_features(input_map, segmentation, feature_names)

    def _stack(names):
        return np.concatenate([node_features[name][:, None] if node_features[name].ndim == 1
                               else node_features[name] for name in names], axis=1)

    # image statistics based features, combined via [min, max, absdiff, sum]
    stat_features = _stack(stat_feature_names)
    # coordinate based features, combined via squared per-axis differences
    coord_features = _stack(coord_feature_names)

    u, v = uv_ids[:, 0], uv_ids[:, 1]

    # combine the stat features for all edges
    feats_u, feats_v = stat_features[u], stat_features[v]
    features = [np.minimum(feats_u, feats_v), np.maximum(feats_u, feats_v),
                np.abs(feats_u - feats_v), feats_u + feats_v]

    # combine the coord features for all edges
    feats_u, feats_v = coord_features[u], coord_features[v]
    features.append((feats_u - feats_v) ** 2)

    features = np.nan_to_num(np.concatenate(features, axis=1))
    assert len(features) == len(uv_ids)
    return features


#
# Grid Graph and Features
#

def compute_grid_graph(shape: Tuple[int, ...]):
    """Compute grid graph for the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The grid graph.
    """
    return bic.graph.grid_graph(shape)


def _nn_offsets(ndim):
    """@private

    Nearest-neighbor offsets: one offset per axis with -1 on that axis and 0 elsewhere.
    """
    return [[-1 if i == d else 0 for i in range(ndim)] for d in range(ndim)]


def _apply_strides(edges, weights, strides, randomize_strides):
    """Subsample (edges, weights) along the spatial periodicity defined by `strides`.

    Mirrors the behaviour of nifty's strides/randomize_strides parameter without
    spatial information: we simply keep one out of every `prod(strides)` entries
    (or a random subset of the same size if `randomize_strides` is True).
    Empty inputs are returned unchanged.
    """
    if strides is None:
        return edges, weights
    keep = int(np.prod(strides))
    n = len(edges)
    if keep <= 1 or n == 0:
        # nothing to subsample; also avoids sampling 1 element out of 0 in the random branch
        return edges, weights
    if randomize_strides:
        idx = np.random.choice(n, size=max(1, n // keep), replace=False)
        idx.sort()
    else:
        idx = np.arange(0, n, keep)
    return edges[idx], weights[idx]


def _count_masked_endpoints(edges: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """@private

    Count, for each edge (row of `edges`), how many of its two endpoints are masked nodes.
    Nodes are identified by their C-order flat index into `mask`; a node is masked where `mask` is False.
    """
    node_ids = np.arange(np.prod(mask.shape), dtype="uint64").reshape(mask.shape)
    masked_ids = node_ids[~mask]
    return np.isin(edges, masked_ids).sum(axis=1)


def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features for image for the given grid_graph.

    Args:
        grid_graph: The grid graph.
        image: The image, from which the features will be derived.
        mode: Feature accumulation method. For multi-channel images, one of
            "l1", "l2", "cosine". For scalar images (without channels) only
            grid-boundary averaging is supported (any mode value is accepted).
        offsets: The offsets, which correspond to the affinity channels.
        strides: The strides used to subsample edges that are computed from offsets.
            Only used when `offsets` is given.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        NotImplementedError: If offsets are given for a scalar image.
        ValueError: If the image dimension or the mode is invalid.
    """
    gndim = len(grid_graph.shape)

    if image.ndim == gndim:
        if offsets is not None:
            raise NotImplementedError("Offsets with scalar images are not supported.")
        weights = bic.graph.features.grid_boundary_features(grid_graph, image.astype("float32"))
        edges = grid_graph.uv_ids()
        return edges, weights

    if image.ndim != gndim + 1:
        raise ValueError(f"Invalid image dimension {image.ndim}, expected {gndim} or {gndim + 1}")

    modes = ("l1", "l2", "cosine")
    if mode not in modes:
        raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")

    if offsets is None:
        # Compute affinities between adjacent pixels using nearest-neighbor offsets.
        nn_offs = _nn_offsets(gndim)
        affs = bic.affinities.compute_embedding_distances(
            image.astype("float32"), nn_offs, norm=mode,
        )
        weights, _valid = bic.graph.features.grid_affinity_features(grid_graph, affs, nn_offs)
        edges = grid_graph.uv_ids()
        return edges, weights

    # General path with arbitrary offsets: compute affinities then use _with_lifted.
    affs = bic.affinities.compute_embedding_distances(
        image.astype("float32"), offsets, norm=mode,
    )
    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, affs, offsets,
    )
    edges = np.concatenate([grid_graph.uv_ids()[local_valid], lifted_uvs], axis=0)
    weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)
    return _apply_strides(edges, weights, strides, randomize_strides)


def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features from affinities for the given grid graph.

    Args:
        grid_graph: The grid graph.
        affinities: The affinity map.
        offsets: The offsets, which correspond to the affinity channels.
        strides: The strides used to subsample edges that are computed from offsets.
        mask: Per-pixel mask of shape `grid_graph.shape`; pixels where it is False (zero) are masked.
            Edges between two masked pixels are removed, edges touching one masked pixel are kept.
            Only supported together with `offsets`.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the affinities do not have a leading channel axis on top of the grid shape.
    """
    gndim = len(grid_graph.shape)
    if affinities.ndim != gndim + 1:
        raise ValueError("affinities must have shape (channels, *grid_graph.shape)")

    if offsets is None:
        assert affinities.shape[0] == gndim
        assert strides is None
        assert mask is None
        nn_offs = _nn_offsets(gndim)
        weights, _valid = bic.graph.features.grid_affinity_features(grid_graph, affinities, nn_offs)
        edges = grid_graph.uv_ids()
        return edges, weights

    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, affinities, offsets,
    )
    edges = np.concatenate([grid_graph.uv_ids()[local_valid], lifted_uvs], axis=0)
    weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)

    if mask is not None:
        assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
        shape = tuple(grid_graph.shape)
        assert mask.shape == shape, (
            "compute_grid_graph_affinity_features with a per-pixel mask expects mask.shape == grid_graph.shape; "
            "per-channel edge masks are only supported on legacy nifty grid graphs."
        )
        # `~mask` is only a logical negation for boolean arrays; integer masks are interpreted as nonzero == True
        mask = np.asarray(mask, dtype=bool)
        keep = _count_masked_endpoints(edges, mask) != 2
        return edges[keep], weights[keep]

    return _apply_strides(edges, weights, strides, randomize_strides)


def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Mask edges in grid graph.

    Set the weights derived from a grid graph to a fixed value, for edges that connect masked nodes
    and edges that connect masked and unmasked nodes. `weights` is modified in place and returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        weights: The edge weights.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    shape = tuple(grid_graph.shape)
    assert mask.shape == shape, f"{mask.shape}, {shape}"

    edges = grid_graph.uv_ids()
    assert len(edges) == len(weights)
    edge_state = _count_masked_endpoints(edges, mask)
    weights[edge_state == 2] = masked_edge_weight
    weights[edge_state == 1] = transition_edge_weight
    return weights


def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.

    The inputs are not modified; new arrays are returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The edge uv-ids.
        The edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    shape = tuple(grid_graph.shape)
    assert mask.shape == shape, f"{mask.shape}, {shape}"

    edge_state = _count_masked_endpoints(edges, mask)
    keep_edges = edge_state != 2

    # boolean indexing copies, so the caller's arrays stay untouched
    edges, weights, edge_state = edges[keep_edges], weights[keep_edges], edge_state[keep_edges]
    weights[edge_state == 1] = transition_edge_weight

    return edges, weights


#
# Lifted Features
#

def lifted_edges_from_graph_neighborhood(graph, max_graph_distance):
    """@private

    Return all node pairs with a graph distance in [2, max_graph_distance] as lifted uv ids.
    """
    if max_graph_distance < 2:
        raise ValueError(f"Graph distance must be greater equal 2, got {max_graph_distance}")
    # With all-zero node_labels and mode='all', every node pair within the BFS hop window
    # [2, max_graph_distance] is returned (base-graph edges excluded).
    node_labels = np.zeros(graph.number_of_nodes, dtype="uint64")
    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        graph, node_labels, graph_depth=max_graph_distance, mode="all",
    )
    return lifted_uvs


def feats_to_costs_default(lifted_labels, lifted_features):
    """@private

    Default lifted costs: product of the two node class probabilities, transformed to costs.
    """
    # we assume that we only have different classes for a given lifted
    # edge here (mode = "different") and then set all edges to be repulsive

    # the higher the class probability, the more repulsive the edges should be,
    # so we just multiply both probabilities
    lifted_costs = lifted_features[:, 0] * lifted_features[:, 1]
    lifted_costs = transform_probabilities_to_costs(lifted_costs)
    return lifted_costs


def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from probability maps by mapping them to superpixels.

    Each node gets the class of the input map whose mean probability over the node exceeds
    `assignment_threshold`; if several maps qualify, the later one in `input_maps` wins.
    Class labels are 1-based (the i-th map has label i + 1); label 0 marks unassigned nodes,
    which are ignored when inserting lifted edges.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watersheds.
        assignment_threshold: Minimal expression level to assign a class to a graph node.
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids.
        The lifted costs.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert isinstance(input_maps, (list, tuple))
    assert all(isinstance(inp, np.ndarray) for inp in input_maps)
    shape = watershed.shape
    assert all(inp.shape == shape for inp in input_maps)

    n_nodes = int(watershed.max()) + 1
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")
    # class ids start at 1 because 0 is both the "unassigned" value and the ignore label below
    for class_id, inp in enumerate(input_maps, start=1):
        mean_prob = _region_features(inp, watershed, ["mean"])["mean"]
        class_mask = mean_prob > assignment_threshold
        node_labels[class_mask] = class_id
        node_features[class_mask] = mean_prob[class_mask]

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag, node_labels, graph_depth=graph_depth, mode=mode,
        ignore_label=0, number_of_threads=n_threads,
    )
    lifted_labels = node_labels[lifted_uvs]
    lifted_features = node_features[lifted_uvs]

    lifted_costs = feats_to_costs(lifted_labels, lifted_features)
    return lifted_uvs, lifted_costs


def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Each watershed node gets the segment id with the largest overlap; nodes whose best overlap
    fraction is below `overlap_threshold` get label 0, which is ignored for lifted edges.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included.
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids.
        The lifted costs.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    ovlp = bic.utils.segmentation_overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.number_of_nodes, "%i, %i" % (n_labels, rag.number_of_nodes)

    node_labels = np.zeros(n_labels, dtype="uint64")
    node_label_vals = np.zeros(len(ws_ids), dtype="uint64")
    overlap_values = np.zeros(len(ws_ids), dtype="float64")
    for i, ws_id in enumerate(ws_ids):
        best = ovlp.best_overlap_for_label_a(int(ws_id), ignore_zero=False)
        node_label_vals[i] = best.label
        overlap_values[i] = best.fraction
    node_label_vals[overlap_values < overlap_threshold] = 0
    node_labels[ws_ids] = node_label_vals

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag, node_labels, graph_depth=graph_depth, mode=mode,
        ignore_label=0, number_of_threads=n_threads,
    )
    # `.max()` is undefined for an empty array, so only check when lifted edges were found
    if len(lifted_uvs) > 0:
        assert lifted_uvs.max() < rag.number_of_nodes, "%i, %i" % (int(lifted_uvs.max()), rag.number_of_nodes)
    lifted_labels = node_labels[lifted_uvs]
    lifted_costs = np.zeros(len(lifted_labels), dtype="float64")

    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs[same_mask] = same_segment_cost
    lifted_costs[~same_mask] = different_segment_cost

    return lifted_uvs, lifted_costs


#
# Misc
#

def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    For every block, the labels on its first slice along each axis are paired with the labels on the
    last slice of the preceding block; the rag edges for these label pairs are the stitch edges.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation.
        verbose: Whether to be verbose.

    Returns:
        The edge mask indicating edges between blocks (all False if there are no block faces).
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = bic.utils.Blocking([0] * ndim, list(seg.shape), list(block_shape))

    def find_stitch_edges(block_id):
        block = blocking.get_block(block_id)
        block_edges = []
        for axis in range(ndim):
            if blocking.get_neighbor_id(block_id, axis, True) == -1:
                continue
            # face_a: first slice of this block along `axis`; face_b: the slice right before it
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, (beg, end) in enumerate(zip(block.begin, block.end))
            )
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, (beg, end) in enumerate(zip(block.begin, block.end))
            )
            uv_ids = np.stack([seg[face_a].ravel(), seg[face_b].ravel()], axis=1)
            # identical labels on both sides are not edges; sort each pair so the lookup is order-independent
            uv_ids = uv_ids[uv_ids[:, 0] != uv_ids[:, 1]]
            if len(uv_ids) == 0:
                continue
            uv_ids = np.unique(np.sort(uv_ids, axis=1), axis=0)

            edge_ids = rag.find_edges(uv_ids)
            block_edges.append(edge_ids[edge_ids != -1])

        return np.unique(np.concatenate(block_edges)) if block_edges else None

    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(find_stitch_edges, range(blocking.number_of_blocks))
        if verbose:
            results = tqdm(results, total=blocking.number_of_blocks)
        per_block = [edges for edges in results if edges is not None]

    full_edges = np.zeros(rag.number_of_edges, dtype="bool")
    # a single block (or no shared faces) yields no stitch edges at all
    if per_block:
        full_edges[np.unique(np.concatenate(per_block))] = True
    return full_edges


def project_node_labels_to_pixels(
    rag, segmentation: np.ndarray, node_labels: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Project label values for graph nodes back to pixels to obtain segmentation.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        node_labels: The array with node labels.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The segmentation.

    Raises:
        ValueError: If the number of node labels does not match the number of rag nodes.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if len(node_labels) != rag.number_of_nodes:
        raise ValueError("Incompatible number of node labels: %i, %i" % (len(node_labels), rag.number_of_nodes))
    # bic.graph.project_node_labels_to_pixels requires integer dtypes for both arrays.
    if segmentation.dtype not in (np.uint32, np.uint64, np.int32, np.int64):
        segmentation = segmentation.astype("uint64")
    if node_labels.dtype not in (np.uint32, np.uint64, np.int32, np.int64):
        node_labels = node_labels.astype("uint64")
    seg = bic.graph.project_node_labels_to_pixels(rag, segmentation, node_labels, number_of_threads=n_threads)
    return seg


def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Compute edge mask of in-between plane edges for flat superpixels.

    Each node is assigned the z index (first axis) of the last slice it occurs in, so this expects
    every superpixel to be confined to a single slice.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    node_z_coords = np.zeros(rag.number_of_nodes, dtype="uint32")
    for z in range(watershed.shape[0]):
        node_z_coords[watershed[z]] = z
    uv_ids = rag.uv_ids()
    z_edge_mask = node_z_coords[uv_ids[:, 0]] != node_z_coords[uv_ids[:, 1]]
    return z_edge_mask
def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
    """Build the region adjacency graph of a label image.

    Args:
        segmentation: The segmentation.
        n_labels: Deprecated; ignored. Kept for backwards-compatibility.
        n_threads: Number of worker threads; defaults to the CPU count.

    Returns:
        The region adjacency graph (`bioimage_cpp.graph.RegionAdjacencyGraph`).
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    # The backend only takes these integer label dtypes; anything else is cast to uint32.
    supported_dtypes = (np.uint32, np.uint64, np.int32, np.int64)
    labels = segmentation if segmentation.dtype in supported_dtypes else segmentation.astype("uint32")
    return bic.graph.region_adjacency_graph(labels, number_of_threads=n_threads)
Compute region adjacency graph of segmentation.
Arguments:
- segmentation: The segmentation.
- n_labels: Deprecated; ignored. Kept for backwards-compatibility.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The region adjacency graph (`bioimage_cpp.graph.RegionAdjacencyGraph`).
def compute_boundary_features(
    rag,
    segmentation: np.ndarray,
    boundary_map: np.ndarray,
    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
    n_threads: Optional[int] = None,
) -> np.ndarray:
    """Accumulate edge features of a boundary map over the RAG edges.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        boundary_map: The boundary map, same shape as `segmentation`.
        min_value: Deprecated; ignored.
        max_value: Deprecated; ignored.
        n_threads: Number of worker threads; defaults to the CPU count.

    Returns:
        The edge features. Output has 12 columns
        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).

    Raises:
        ValueError: If `segmentation` and `boundary_map` differ in shape.
    """
    seg_shape, map_shape = segmentation.shape, boundary_map.shape
    if seg_shape != map_shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(seg_shape), str(map_shape)))
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    return bic.graph.features.edge_map_features_complex(
        rag, segmentation, boundary_map, number_of_threads=n_threads,
    )
Compute edge features from boundary map.
Arguments:
- rag: The region adjacency graph.
- segmentation: The over-segmentation used to construct the RAG.
- boundary_map: The boundary map.
- min_value: Deprecated; ignored.
- max_value: Deprecated; ignored.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The edge features. Output has 12 columns (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).
def compute_affinity_features(
    rag,
    segmentation: np.ndarray,
    affinity_map: np.ndarray,
    offsets: List[List[int]],
    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
    n_threads: Optional[int] = None,
) -> np.ndarray:
    """Accumulate edge features of an affinity map over the RAG edges.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        affinity_map: The affinity map, channels first; its spatial shape must match `segmentation`.
        offsets: The offsets corresponding to the affinity channels (one per channel).
        min_value: Deprecated; ignored.
        max_value: Deprecated; ignored.
        n_threads: Number of worker threads; defaults to the CPU count.

    Returns:
        The edge features. Output has 12 columns
        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).

    Raises:
        ValueError: If the spatial shapes differ or the number of offsets does not match the channels.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    spatial_shape = affinity_map.shape[1:]
    if segmentation.shape != spatial_shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(segmentation.shape), str(spatial_shape)))

    n_channels = affinity_map.shape[0]
    if len(offsets) != n_channels:
        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets), n_channels))

    return bic.graph.features.affinity_features_complex(
        rag, segmentation, affinity_map, offsets, number_of_threads=n_threads,
    )
Compute edge features from affinity map.
Arguments:
- rag: The region adjacency graph.
- segmentation: The over-segmentation used to construct the RAG.
- affinity_map: The affinity map.
- offsets: The offsets corresponding to the affinity channels.
- min_value: Deprecated; ignored.
- max_value: Deprecated; ignored.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The edge features. Output has 12 columns (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).
def compute_boundary_mean_and_length(
    rag, segmentation: np.ndarray, input_: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Accumulate the mean input value and the size of every RAG boundary.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input map, same shape as `segmentation`.
        n_threads: Number of worker threads; defaults to the CPU count.

    Returns:
        The edge features with two columns (mean, size).

    Raises:
        ValueError: If `segmentation` and `input_` differ in shape.
    """
    seg_shape, input_shape = segmentation.shape, input_.shape
    if seg_shape != input_shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(seg_shape), str(input_shape)))
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    return bic.graph.features.edge_map_features(
        rag, segmentation, input_, number_of_threads=n_threads,
    )
Compute mean value and length of boundaries.
Arguments:
- rag: The region adjacency graph.
- segmentation: The over-segmentation used to construct the RAG.
- input_: The input map.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The edge features with two columns (mean, size).
def compute_boundary_features_with_filters(
    rag,
    segmentation: np.ndarray,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Optional[Dict[str, List[float]]] = None,
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    For every (filter name, sigma) pair in `filters` the filter response of `input_` is computed,
    and `compute_boundary_features` is evaluated on each response channel. All per-channel features
    are concatenated column-wise in the iteration order of `filters` and its sigma lists.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads; defaults to the CPU count.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            Defaults to "gaussianSmoothing", "laplacianOfGaussian" and "hessianOfGaussianEigenvalues",
            each with sigmas [1.6, 4.2, 8.3].

    Returns:
        The edge features, one row per edge of the rag.

    Raises:
        ValueError: If `filters` does not contain any (filter name, sigma) pair.
    """
    if filters is None:
        # Built per call instead of as a default argument, so it cannot be shared/mutated across calls.
        filters = {
            "gaussianSmoothing": [1.6, 4.2, 8.3],
            "laplacianOfGaussian": [1.6, 4.2, 8.3],
            "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3],
        }
    filter_specs = [(name, sigma) for name, sigmas in filters.items() for sigma in sigmas]
    if not filter_specs:
        raise ValueError("At least one (filter name, sigma) pair is required.")
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    def _channel_features(response, feature_threads):
        # Boundary features for every channel of a channel-last filter response, stacked column-wise.
        feats = [
            compute_boundary_features(rag, segmentation, response[..., chan], n_threads=feature_threads)
            for chan in range(response.shape[-1])
        ]
        out = np.concatenate(feats, axis=1)
        assert len(out) == rag.number_of_edges, f"{len(out)}, {rag.number_of_edges}"
        return out

    if apply_2d:
        # 2d: the (filter, sigma) pairs are processed one after another; _filter_2d and the
        # feature computation each receive all n_threads.
        def _compute_2d(filter_name, sigma):
            response = _filter_2d(input_, filter_name, sigma, n_threads)
            assert response.ndim == 4
            return _channel_features(response, n_threads)

        features = [_compute_2d(name, sigma) for name, sigma in filter_specs]

    else:
        # 3d: we parallelize over the whole filter + feature computation (features single-threaded).
        # This can be very memory intensive; parallelizing inside each filter would be better,
        # but 3d parallel filters in elf.parallel.filters are not working properly yet.
        def _compute_3d(filter_name, sigma):
            response = _apply_filter(filter_name, input_, sigma)
            if response.ndim == input_.ndim:
                # Scalar filter response: add a trailing channel axis.
                response = response[..., None]
            return _channel_features(response, 1)

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, name, sigma) for name, sigma in filter_specs]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.number_of_edges
    return features
Compute boundary features accumulated over filter responses on input.
Arguments:
- rag: The region adjacency graph.
- segmentation: The over-segmentation used to construct the RAG.
- input_: The input data.
- apply_2d: Whether to apply the filters in 2d for 3d input data.
- n_threads: The number of threads.
- filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
Returns:
The edge features.
def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from an input map accumulated over segmentation and mapped to edges.

    Per-segment statistics of `input_map` are computed with `_region_features` and combined for
    the two nodes u, v of every edge:
    - statistics features: element-wise minimum, maximum, absolute difference and sum;
    - center coordinates (weighted and unweighted): squared per-coordinate difference.
    NaN and infinite values in the result are replaced via `np.nan_to_num`.

    Args:
        uv_ids: The edge uv ids; column 0 holds the u node ids, column 1 the v node ids.
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: Currently unused; kept for backwards-compatibility.

    Returns:
        The edge features, one row per edge in `uv_ids`.
    """
    # Compute the node features.
    stat_feature_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                          "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_feature_names = ["Weighted<RegionCenter>", "RegionCenter"]
    node_features = _region_features(input_map, segmentation, stat_feature_names + coord_feature_names)

    def _stack(names):
        # Stack per-node features column-wise; 1d (scalar) features become a single column.
        feats = [node_features[name] for name in names]
        return np.concatenate([feat[:, None] if feat.ndim == 1 else feat for feat in feats], axis=1)

    stat_features = _stack(stat_feature_names)
    coord_features = _stack(coord_feature_names)

    u, v = uv_ids[:, 0], uv_ids[:, 1]

    # Combine the statistics features for all edges via [min, max, absdiff, sum].
    feats_u, feats_v = stat_features[u], stat_features[v]
    features = [np.minimum(feats_u, feats_v), np.maximum(feats_u, feats_v),
                np.abs(feats_u - feats_v), feats_u + feats_v]

    # Combine the coordinate features via squared per-coordinate differences (not summed).
    coord_u, coord_v = coord_features[u], coord_features[v]
    features.append((coord_u - coord_v) ** 2)

    features = np.nan_to_num(np.concatenate(features, axis=1))
    assert len(features) == len(uv_ids)
    return features
Compute edge features from an input map accumulated over segmentation and mapped to edges.
Arguments:
- uv_ids: The edge uv ids.
- input_map: The input map.
- segmentation: The segmentation.
- n_threads: Currently unused; kept for backwards-compatibility.
Returns:
The edge features.
def compute_grid_graph(shape: Tuple[int, ...]):
    """Build a grid graph for data of the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The grid graph.
    """
    grid_graph = bic.graph.grid_graph(shape)
    return grid_graph
Compute grid graph for the given shape.
Arguments:
- shape: The shape of the data.
Returns:
The grid graph.
def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive grid graph edge features from an image.

    Three paths are taken:
    - scalar image (same ndim as the grid): boundary averaging over the grid edges (`mode` unused,
      `offsets` must be None);
    - multi-channel image without offsets: embedding distances between nearest neighbors;
    - multi-channel image with offsets: embedding distances for the offsets, giving valid local
      grid edges plus lifted edges, subsampled via `strides` / `randomize_strides`.
    `strides` and `randomize_strides` only have an effect on the last path.

    Args:
        grid_graph: The grid graph.
        image: The image, from which the features will be derived.
        mode: Feature accumulation method. For multi-channel images, one of
            "l1", "l2", "cosine". For scalar images (without channels) only
            grid-boundary averaging is supported (any mode value is accepted).
        offsets: The offsets, which correspond to the affinity channels.
        strides: The strides used to subsample edges that are computed from offsets.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        NotImplementedError: For a scalar image combined with offsets.
        ValueError: For an image of invalid dimensionality or an unknown mode.
    """
    n_spatial = len(grid_graph.shape)

    if image.ndim == n_spatial:
        if offsets is not None:
            raise NotImplementedError("Offsets with scalar images are not supported.")
        scalar_weights = bic.graph.features.grid_boundary_features(grid_graph, image.astype("float32"))
        return grid_graph.uv_ids(), scalar_weights

    if image.ndim != n_spatial + 1:
        raise ValueError(f"Invalid image dimension {image.ndim}, expected {n_spatial} or {n_spatial + 1}")

    valid_modes = ("l1", "l2", "cosine")
    if mode not in valid_modes:
        raise ValueError(f"Invalid feature mode {mode}, expect one of {valid_modes}")

    embedding = image.astype("float32")

    if offsets is None:
        # Affinities between adjacent pixels from nearest-neighbor offsets.
        nn_offsets = _nn_offsets(n_spatial)
        nn_affs = bic.affinities.compute_embedding_distances(embedding, nn_offsets, norm=mode)
        nn_weights, _ = bic.graph.features.grid_affinity_features(grid_graph, nn_affs, nn_offsets)
        return grid_graph.uv_ids(), nn_weights

    # Arbitrary offsets: local grid edges flagged valid plus lifted edges.
    offset_affs = bic.affinities.compute_embedding_distances(embedding, offsets, norm=mode)
    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, offset_affs, offsets,
    )
    all_edges = np.concatenate([grid_graph.uv_ids()[local_valid], lifted_uvs], axis=0)
    all_weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)
    return _apply_strides(all_edges, all_weights, strides, randomize_strides)
Compute edge features for image for the given grid_graph.
Arguments:
- grid_graph: The grid graph.
- image: The image, from which the features will be derived.
- mode: Feature accumulation method. For multi-channel images, one of "l1", "l2", "cosine". For scalar images (without channels) only grid-boundary averaging is supported (any mode value is accepted).
- offsets: The offsets, which correspond to the affinity channels.
- strides: The strides used to subsample edges that are computed from offsets.
- randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:
The uv ids of the edges. The edge features.
def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features from affinities for the given grid graph.

    Without `offsets` the affinities must have one channel per spatial axis and the returned
    edges are exactly the grid graph edges. With `offsets` the result consists of the local grid
    edges flagged valid by the backend plus the lifted edges derived from the offsets; these are
    then either filtered by `mask` or subsampled via `strides` / `randomize_strides`.

    Args:
        grid_graph: The grid graph.
        affinities: The affinity map, shape (channels, *grid_graph.shape).
        offsets: The offsets, which correspond to the affinity channels.
        strides: The strides used to subsample edges that are computed from offsets.
        mask: Binary foreground mask with shape grid_graph.shape (True = keep). Edges whose two
            nodes both lie outside the mask are removed; non-boolean masks are interpreted by
            their truth value. Only supported together with `offsets` and without strides.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If `affinities` does not have exactly one more dimension than the grid graph.
    """
    n_spatial = len(grid_graph.shape)
    if affinities.ndim != n_spatial + 1:
        raise ValueError("affinities must have shape (channels, *grid_graph.shape)")

    if offsets is None:
        # Nearest-neighbor affinities: one channel per axis; neither strides nor mask apply.
        assert affinities.shape[0] == n_spatial
        assert strides is None
        assert mask is None
        nn_offs = _nn_offsets(n_spatial)
        weights, _valid = bic.graph.features.grid_affinity_features(grid_graph, affinities, nn_offs)
        edges = grid_graph.uv_ids()
        return edges, weights

    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, affinities, offsets,
    )
    edges = np.concatenate([grid_graph.uv_ids()[local_valid], lifted_uvs], axis=0)
    weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)

    if mask is None:
        return _apply_strides(edges, weights, strides, randomize_strides)

    assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
    # Convert to bool: `~` on an integer mask is a bitwise NOT, which would turn the indexing
    # below into (wrong) integer fancy indexing instead of boolean selection.
    mask = np.asarray(mask, dtype=bool)
    shape = tuple(grid_graph.shape)
    assert mask.shape == shape, (
        "compute_grid_graph_affinity_features with a per-pixel mask expects mask.shape == grid_graph.shape; "
        "per-channel edge masks are only supported on legacy nifty grid graphs."
    )
    # Assumes grid node ids enumerate the pixels in C order.
    node_ids = np.arange(np.prod(shape), dtype="uint64").reshape(shape)
    masked_ids = node_ids[~mask]
    # Number of endpoints outside the mask per edge; drop edges lying completely outside.
    edge_state = np.isin(edges, masked_ids).sum(axis=1)
    keep = edge_state != 2
    return edges[keep], weights[keep]
Compute edge features from affinities for the given grid graph.
Arguments:
- grid_graph: The grid graph.
- affinities: The affinity map.
- offsets: The offsets, which correspond to the affinity channels.
- strides: The strides used to subsample edges that are computed from offsets.
- mask: Binary foreground mask (True = keep) with the shape of the grid graph; edges whose two nodes both lie outside the mask are excluded from the result.
- randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:
The uv ids of the edges. The edge features.
def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Mask edges in grid graph.

    Edges whose two nodes are both outside the mask get `masked_edge_weight`; edges with exactly
    one node outside the mask get `transition_edge_weight`. All other weights are left untouched.
    Note that `weights` is modified in place and the same array is returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True. Must have shape grid_graph.shape.
        weights: The edge weights, one per edge of `grid_graph.uv_ids()`.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights (the input `weights` array).
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    grid_shape = tuple(grid_graph.shape)
    assert mask.shape == grid_shape, f"{mask.shape}, {grid_shape}"

    # Assumes the grid node ids enumerate the pixels in C order; collect those outside the mask.
    all_node_ids = np.arange(np.prod(grid_shape), dtype="uint64").reshape(grid_shape)
    outside_ids = all_node_ids[~mask]

    uv = grid_graph.uv_ids()
    assert len(uv) == len(weights)
    # How many endpoints of each edge lie outside the mask: 0, 1 or 2.
    n_outside = np.isin(uv, outside_ids).sum(axis=1)
    weights[n_outside == 2] = masked_edge_weight
    weights[n_outside == 1] = transition_edge_weight
    return weights
Mask edges in grid graph.
Set the weights derived from a grid graph to a fixed value, for edges that connect masked nodes and edges that connect masked and unmasked nodes.
Arguments:
- grid_graph: The grid graph.
- mask: The binary mask, foreground (=non-masked) is True.
- weights: The edge weights.
- masked_edge_weight: The value for edges that connect two masked nodes.
- transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:
The masked edge weights.
def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.

    The inputs are not modified: the kept edges and weights are returned as new arrays.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True. Must have shape grid_graph.shape.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The edge uv-ids.
        The edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    grid_shape = tuple(grid_graph.shape)
    assert mask.shape == grid_shape, f"{mask.shape}, {grid_shape}"
    # Assumes grid node ids enumerate the pixels in C order.
    outside_ids = np.arange(np.prod(grid_shape), dtype="uint64").reshape(grid_shape)[~mask]

    # Endpoints outside the mask per edge; edges with both endpoints outside are dropped.
    n_outside = np.isin(edges, outside_ids).sum(axis=1)
    keep = n_outside != 2

    kept_edges = edges[keep]
    kept_weights = weights[keep]  # boolean indexing copies, so `weights` stays untouched
    kept_weights[n_outside[keep] == 1] = transition_edge_weight
    return kept_edges, kept_weights
Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.
Arguments:
- grid_graph: The grid graph.
- mask: The binary mask, foreground (=non-masked) is True.
- edges: The edges (uv-ids).
- weights: The edge weights.
- transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:
The edge uv-ids. The edge weights.
def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive a lifted problem from per-class probability maps mapped onto superpixels.

    Each input map is averaged per watershed node via `_region_features`. A node whose mean value
    for map `i` exceeds `assignment_threshold` gets class label `i` and that mean as its feature;
    if several maps pass the threshold, the later map wins. Lifted edges are inserted from the node
    labels (label 0 is passed as `ignore_label`), and `feats_to_costs` maps the endpoint labels and
    features of each lifted edge to its cost.

    NOTE(review): class ids come from `enumerate` and start at 0, which is also the fill value for
    unassigned nodes and the `ignore_label`, so nodes assigned to the first map look unassigned.
    Confirm this is intended.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watershed.
        assignment_threshold: Minimal expression level to assign a class to a graph node.
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation; defaults to the CPU count.

    Returns:
        The lifted uv ids.
        The lifted costs.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    assert isinstance(input_maps, (list, tuple))
    assert all(isinstance(prob_map, np.ndarray) for prob_map in input_maps)
    expected_shape = watershed.shape
    assert all(prob_map.shape == expected_shape for prob_map in input_maps)

    n_nodes = int(watershed.max()) + 1
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")
    for class_id, prob_map in enumerate(input_maps):
        mean_prob = _region_features(prob_map, watershed, ["mean"])["mean"]
        above = mean_prob > assignment_threshold
        node_labels[above] = class_id
        node_features[above] = mean_prob[above]

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag, node_labels, graph_depth=graph_depth, mode=mode,
        ignore_label=0, number_of_threads=n_threads,
    )
    lifted_costs = feats_to_costs(node_labels[lifted_uvs], node_features[lifted_uvs])
    return lifted_uvs, lifted_costs
Compute lifted problem from probability maps by mapping them to superpixels.
Arguments:
- rag: The region adjacency graph.
- watershed: The watershed over-segmentation.
- input_maps: List of probability maps. Each map must have the same shape as the watershed over-segmentation.
- assignment_threshold: Minimal expression level to assign a class to a graph node.
- graph_depth: Maximal graph depth up to which lifted edges will be included.
- feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
- mode: The mode for insertion of lifted edges. One of "all", "different", "same".
- n_threads: The number of threads used for the calculation.
Returns:
The lifted uv ids. The lifted costs.
def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Every watershed node is labeled with the input segment it overlaps most. Nodes whose overlap
    fraction is below `overlap_threshold` get label 0, which is passed as `ignore_label` when the
    lifted edges are inserted. Lifted edges between equally labeled nodes get `same_segment_cost`,
    all others `different_segment_cost`.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included.
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation; defaults to the CPU count.

    Returns:
        The lifted uv ids.
        The lifted costs (float64). Both are empty if no lifted edges were found.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    ovlp = bic.utils.segmentation_overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.number_of_nodes, "%i, %i" % (n_labels, rag.number_of_nodes)

    # Best-overlapping input segment (and its overlap fraction) for every watershed id.
    node_label_vals = np.zeros(len(ws_ids), dtype="uint64")
    overlap_values = np.zeros(len(ws_ids), dtype="float64")
    for i, ws_id in enumerate(ws_ids):
        best = ovlp.best_overlap_for_label_a(int(ws_id), ignore_zero=False)
        node_label_vals[i] = best.label
        overlap_values[i] = best.fraction
    # Insufficient overlap -> label 0, the ignore_label used below.
    node_label_vals[overlap_values < overlap_threshold] = 0
    node_labels = np.zeros(n_labels, dtype="uint64")
    node_labels[ws_ids] = node_label_vals

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag, node_labels, graph_depth=graph_depth, mode=mode,
        ignore_label=0, number_of_threads=n_threads,
    )
    # No lifted edges is a valid outcome, and max() of an empty array raises ValueError.
    if len(lifted_uvs) > 0:
        assert lifted_uvs.max() < rag.number_of_nodes, "%i, %i" % (int(lifted_uvs.max()), rag.number_of_nodes)
    lifted_labels = node_labels[lifted_uvs]
    lifted_costs = np.zeros(len(lifted_labels), dtype="float64")

    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs[same_mask] = same_segment_cost
    lifted_costs[~same_mask] = different_segment_cost

    return lifted_uvs, lifted_costs
Compute lifted problem from segmentation by mapping segments to superpixels.
Arguments:
- rag: The region adjacency graph.
- watershed: The watershed over-segmentation.
- input_segmentation: The segmentation used to determine node attribution.
- overlap_threshold: The minimal overlap to assign a segment id to node.
- graph_depth: The maximal graph depth up to which lifted edges will be included.
- same_segment_cost: The cost for edges between nodes with same segment id attribution.
- different_segment_cost: The cost for edges between nodes with different segment id attribution.
- mode: The mode for insertion of lifted edges. One of "all", "different", "same".
- n_threads: The number of threads used for the calculation.
Returns:
The lifted uv ids. The lifted costs.
def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    `seg` is divided into blocks of `block_shape`. For every block and every axis on which the
    block has a lower neighbor, the label pairs across the shared face (first slice of the block
    vs. last slice of the neighbor) are looked up in the rag, and the edges found are marked.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation; defaults to the CPU count.
        verbose: Whether to show a progress bar.

    Returns:
        Boolean mask over the rag edges, True for edges between blocks. All False if there are
        no such edges, e.g. when the data fits into a single block.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = bic.utils.Blocking([0] * ndim, list(seg.shape), list(block_shape))

    def find_stitch_edges(block_id):
        # Unique rag edge ids across this block's lower faces, or None if it has no lower neighbor.
        stitch_edges = []
        block = blocking.get_block(block_id)
        for axis in range(ndim):
            if blocking.get_neighbor_id(block_id, axis, True) == -1:
                continue
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )

            labels_a = seg[face_a].ravel()
            labels_b = seg[face_b].ravel()

            uv_ids = np.concatenate([labels_a[:, None], labels_b[:, None]], axis=1)
            uv_ids = np.unique(uv_ids, axis=0)

            # Pairs without a rag edge (e.g. identical labels) are reported as -1.
            edge_ids = rag.find_edges(uv_ids)
            edge_ids = edge_ids[edge_ids != -1]
            stitch_edges.append(edge_ids)

        if not stitch_edges:
            return None
        return np.unique(np.concatenate(stitch_edges))

    n_blocks = blocking.number_of_blocks
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(find_stitch_edges, range(n_blocks))
        if verbose:
            results = tqdm(results, total=n_blocks)
        block_edges = [edges for edges in results if edges is not None]

    full_edges = np.zeros(rag.number_of_edges, dtype="bool")
    # Guard: np.concatenate raises on an empty list (no block has a lower neighbor).
    if block_edges:
        full_edges[np.unique(np.concatenate(block_edges))] = True
    return full_edges
Get the edges between blocks.
Arguments:
- rag: The region adjacency graph.
- seg: The segmentation underlying the rag.
- block_shape: The shape of the blocking.
- n_threads: The number of threads used for the calculation.
- verbose: Whether to be verbose.
Returns:
The edge mask indicating edges between blocks.
def project_node_labels_to_pixels(
    rag, segmentation: np.ndarray, node_labels: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Map per-node labels back onto the pixels of the over-segmentation.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        node_labels: The array with node labels, one entry per rag node.
        n_threads: Number of worker threads; defaults to the CPU count.

    Returns:
        The segmentation.

    Raises:
        ValueError: If the number of node labels differs from the number of rag nodes.
    """
    n_nodes = rag.number_of_nodes
    if len(node_labels) != n_nodes:
        raise ValueError("Incompatible number of node labels: %i, %i" % (len(node_labels), n_nodes))
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    index_dtypes = (np.uint32, np.uint64, np.int32, np.int64)

    def _as_index_array(arr):
        # bic.graph.project_node_labels_to_pixels requires integer dtypes; others are cast to uint64.
        return arr if arr.dtype in index_dtypes else arr.astype("uint64")

    return bic.graph.project_node_labels_to_pixels(
        rag, _as_index_array(segmentation), _as_index_array(node_labels), number_of_threads=n_threads,
    )
Project label values for graph nodes back to pixels to obtain segmentation.
Arguments:
- rag: The region adjacency graph.
- segmentation: The over-segmentation used to construct the RAG.
- node_labels: The array with node labels.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The segmentation.
def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Mark the edges that connect superpixels of different z-planes.

    Every node is assigned the index of the plane along axis 0 of `watershed` in which it occurs;
    if a node occurs in several planes, the last one wins. An edge is an in-between plane edge
    when its two nodes were assigned different planes.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    plane_of_node = np.zeros(rag.number_of_nodes, dtype="uint32")
    for z, plane in enumerate(watershed):
        plane_of_node[plane] = z
    uv = rag.uv_ids()
    u_ids, v_ids = uv[:, 0], uv[:, 1]
    return plane_of_node[u_ids] != plane_of_node[v_ids]
Compute edge mask of in-between plane edges for flat superpixels.
Arguments:
- rag: The region adjacency graph.
- watershed: The underlying watershed over-segmentation (superpixels).
Returns:
The edge mask indicating in-between slice edges.