elf.segmentation.features
import multiprocessing
from concurrent import futures
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import vigra
import nifty
import nifty.graph.rag as nrag
import nifty.ground_truth as ngt
try:
    import nifty.distributed as ndist
except ImportError:
    ndist = None

try:
    import fastfilters as ff
except ImportError:
    import vigra.filters as ff

from tqdm import tqdm
from .multicut import transform_probabilities_to_costs


#
# Region Adjacency Graph and Features
#

def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
    """Compute region adjacency graph of segmentation.

    Args:
        segmentation: The segmentation.
        n_labels: The number of labels in the segmentation. If None, will be computed from the data
            as the maximum label id + 1.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The region adjacency graph.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    n_labels = int(segmentation.max()) + 1 if n_labels is None else n_labels
    rag = nrag.gridRag(segmentation, numberOfLabels=n_labels, numberOfThreads=n_threads)
    return rag


def compute_boundary_features(
    rag, boundary_map: np.ndarray, min_value: float = 0.0, max_value: float = 1.0, n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from boundary map.

    Args:
        rag: The region adjacency graph.
        boundary_map: The boundary map. Must have the same shape as the rag.
        min_value: The minimum value used in accumulation.
        max_value: The maximum value used in accumulation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the shapes of the rag and the boundary map do not agree.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if tuple(rag.shape) != boundary_map.shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(boundary_map.shape)))
    features = nrag.accumulateEdgeStandartFeatures(
        rag, boundary_map, min_value, max_value, numberOfThreads=n_threads
    )
    return features


def compute_affinity_features(
    rag,
    affinity_map: np.ndarray,
    offsets: List[List[int]],
    min_value: float = 0.0,
    max_value: float = 1.0,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from affinity map.

    Args:
        rag: The region adjacency graph.
        affinity_map: The affinity map, channels first. The spatial shape must agree with the rag.
        offsets: The offsets corresponding to the affinity channels, one offset per channel.
        min_value: The minimum value used in accumulation.
        max_value: The maximum value used in accumulation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the spatial shape or the number of channels does not match.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if tuple(rag.shape) != affinity_map.shape[1:]:
        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(affinity_map.shape[1:])))
    if len(offsets) != affinity_map.shape[0]:
        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets),
                                                                                   affinity_map.shape[0]))
    features = nrag.accumulateAffinityStandartFeatures(
        rag, affinity_map, offsets, min_value, max_value, numberOfThreads=n_threads
    )
    return features


def compute_boundary_mean_and_length(rag, input_: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Compute mean value and length of boundaries.

    Args:
        rag: The region adjacency graph.
        input_: The input map. Must have the same shape as the rag.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the shapes of the rag and the input map do not agree.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if tuple(rag.shape) != input_.shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(input_.shape)))
    features = nrag.accumulateEdgeMeanAndLength(rag, input_, numberOfThreads=n_threads)
    return features


# TODO generalize and move to elf.features.parallel
def _filter_2d(input_, filter_name, sigma, n_threads):
    """Apply the filter `filter_name` slice by slice along axis 0, parallelized over slices.

    The result always has a leading z axis and a trailing channel axis.
    """
    filter_fu = getattr(ff, filter_name)

    def _fz(inp):
        response = filter_fu(inp, sigma)
        # we add a channel last axis for 2d filter responses
        if response.ndim == 2:
            response = response[None, ..., None]
        elif response.ndim == 3:
            response = response[None]
        else:
            raise RuntimeError("Invalid filter response")
        return response

    with futures.ThreadPoolExecutor(n_threads) as tp:
        tasks = [tp.submit(_fz, input_[z]) for z in range(input_.shape[0])]
        response = [t.result() for t in tasks]

    response = np.concatenate(response, axis=0)
    return response


# The default filter bank of `compute_boundary_features_with_filters`: filter name -> sigma values.
# Kept at module level (and never mutated) to avoid a mutable default argument.
_DEFAULT_FILTERS = {
    "gaussianSmoothing": [1.6, 4.2, 8.3],
    "laplacianOfGaussian": [1.6, 4.2, 8.3],
    "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3],
}


def compute_boundary_features_with_filters(
    rag,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Optional[Dict[str, List[float]]] = None,
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    Args:
        rag: The region adjacency graph.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            If None, gaussianSmoothing, laplacianOfGaussian and hessianOfGaussianEigenvalues
            are applied with the sigmas 1.6, 4.2 and 8.3.

    Returns:
        The edge features.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    filters = _DEFAULT_FILTERS if filters is None else filters
    filter_params = [(filter_name, sigma) for filter_name, sigmas in filters.items() for sigma in sigmas]

    def _accumulate_channels(response, feature_threads):
        # accumulate boundary features separately for each channel (last axis) of the filter response,
        # using the value range of the channel for the accumulation
        features = []
        for chan in range(response.shape[-1]):
            chan_data = response[..., chan]
            feats = compute_boundary_features(rag, chan_data, chan_data.min(), chan_data.max(),
                                              n_threads=feature_threads)
            features.append(feats)
        features = np.concatenate(features, axis=1)
        assert len(features) == rag.numberOfEdges, f"{len(features)}, {rag.numberOfEdges}"
        return features

    # apply 2d: we compute filters and derived features in parallel per filter
    if apply_2d:

        def _compute_2d(filter_name, sigma):
            response = _filter_2d(input_, filter_name, sigma, n_threads)
            assert response.ndim == 4
            return _accumulate_channels(response, n_threads)

        features = [_compute_2d(filter_name, sigma) for filter_name, sigma in filter_params]

    # apply 3d: we parallelize over the whole filter + feature computation
    # this can be very memory intensive, and it would be better to parallelize inside
    # of the loop, but 3d parallel filters in elf.parallel.filters are not working properly yet
    else:

        def _compute_3d(filter_name, sigma):
            filter_fu = getattr(ff, filter_name)
            response = filter_fu(input_, sigma)
            if response.ndim == 3:
                response = response[..., None]
            # single-threaded accumulation, because we already parallelize over the filters
            return _accumulate_channels(response, 1)

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, filter_name, sigma) for filter_name, sigma in filter_params]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.numberOfEdges
    return features


def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from an input map accumulated over segmentation and mapped to edges.

    Region statistics are combined per edge via min, max, absolute difference and sum;
    region center coordinates are combined via the squared difference.
    NaN values in the result are replaced by `np.nan_to_num`.

    Args:
        uv_ids: The edge uv ids.
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    # compute the node features
    stat_feature_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                          "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_feature_names = ["Weighted<RegionCenter>", "RegionCenter"]
    feature_names = stat_feature_names + coord_feature_names
    node_features = vigra.analysis.extractRegionFeatures(input_map, segmentation,
                                                         features=feature_names)

    # get the image statistics based features, that are combined via [min, max, sum, absdiff]
    stat_features = [node_features[fname] for fname in stat_feature_names]
    stat_features = np.concatenate([feat[:, None] if feat.ndim == 1 else feat
                                    for feat in stat_features], axis=1)

    # get the coordinate based features, that are combined via the squared difference
    coord_features = [node_features[fname] for fname in coord_feature_names]
    coord_features = np.concatenate([feat[:, None] if feat.ndim == 1 else feat
                                     for feat in coord_features], axis=1)

    u, v = uv_ids[:, 0], uv_ids[:, 1]

    # combine the stat features for all edges
    feats_u, feats_v = stat_features[u], stat_features[v]
    features = [np.minimum(feats_u, feats_v), np.maximum(feats_u, feats_v),
                np.abs(feats_u - feats_v), feats_u + feats_v]

    # combine the coord features for all edges
    feats_u, feats_v = coord_features[u], coord_features[v]
    features.append((feats_u - feats_v) ** 2)

    features = np.nan_to_num(np.concatenate(features, axis=1))
    assert len(features) == len(uv_ids)
    return features


#
# Grid Graph and Features
#

def compute_grid_graph(shape: Tuple[int, ...]):
    """Compute grid graph for the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The grid graph.
    """
    grid_graph = nifty.graph.undirectedGridGraph(shape)
    return grid_graph


def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features for image for the given grid_graph.

    Args:
        grid_graph: The grid graph
        image: The image, from which the features will be derived.
            Either without channels (same ndim as the grid graph) or with a leading channel axis.
        mode: Feature accumulation method.
        offsets: The offsets, which correspond to the affinity channels.
            If none are given, the affinites for the nearest neighbor transitions are used.
        strides: The strides used to subsample edges that are computed from offsets.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        NotImplementedError: If offsets are given for an image without channels.
        ValueError: If the mode or the image dimension is invalid.
    """
    gndim = len(grid_graph.shape)

    if image.ndim == gndim:
        if offsets is not None:
            raise NotImplementedError("Offsets are not supported for images without channels")
        modes = ("l1", "l2", "min", "max", "sum", "prod", "interpixel")
        if mode not in modes:
            raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")
        features = grid_graph.imageToEdgeMap(image, mode)
        edges = grid_graph.uvIds()

    elif image.ndim == gndim + 1:
        modes = ("l1", "l2", "cosine")
        if mode not in modes:
            raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")

        if offsets is None:
            features = grid_graph.imageWithChannelsToEdgeMap(image, mode)
            edges = grid_graph.uvIds()
        else:
            (n_edges,
             edges,
             features) = grid_graph.imageWithChannelsToEdgeMapWithOffsets(image, mode,
                                                                          offsets=offsets,
                                                                          strides=strides,
                                                                          randomize_strides=randomize_strides)
            # only the first n_edges entries are valid
            edges, features = edges[:n_edges], features[:n_edges]

    else:
        msg = f"Invalid image dimension {image.ndim}, expect one of {gndim} or {gndim + 1}"
        raise ValueError(msg)

    return edges, features


def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features from affinities for the given grid graph.

    Args:
        grid_graph: The grid graph
        affinities: The affinity map, with a leading channel axis.
        offsets: The offsets, which correspond to the affinity channels.
            If none are given, the affinites for the nearest neighbor transitions are used.
        strides: The strides used to subsample edges that are computed from offsets.
        mask: Mask to exclude from the edge and feature computation.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the affinities do not have one more dimension than the grid graph.
    """
    gndim = len(grid_graph.shape)
    if affinities.ndim != gndim + 1:
        raise ValueError(f"Invalid affinity dimension {affinities.ndim}, expect {gndim + 1}")

    if offsets is None:
        assert affinities.shape[0] == gndim
        assert strides is None
        assert mask is None
        features = grid_graph.affinitiesToEdgeMap(affinities)
        edges = grid_graph.uvIds()
    elif mask is not None:
        assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithMask(affinities,
                                                                          offsets=offsets,
                                                                          mask=mask)
        edges, features = edges[:n_edges], features[:n_edges]
    else:
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithOffsets(affinities,
                                                                             offsets=offsets,
                                                                             strides=strides,
                                                                             randomize_strides=randomize_strides)
        edges, features = edges[:n_edges], features[:n_edges]

    return edges, features


def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Mask edges in grid graph.

    Set the weights derived from a grid graph to a fixed value, for edges that connect masked nodes
    and edges that connect masked and unmasked nodes. Note that `weights` is modified in place.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        weights: The edge weights.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    node_ids = grid_graph.projectNodeIdsToPixels()
    assert node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{node_ids.shape}, {mask.shape}, {grid_graph.shape}"
    masked_ids = node_ids[~mask]

    edges = grid_graph.uvIds()
    assert len(edges) == len(weights)
    # number of masked nodes per edge: 2 -> fully masked, 1 -> transition edge
    edge_state = np.isin(edges, masked_ids).sum(axis=1)
    masked_edges = edge_state == 2
    transition_edges = edge_state == 1
    weights[masked_edges] = masked_edge_weight
    weights[transition_edges] = transition_edge_weight
    return weights


def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The edge uv-ids.
        The edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    node_ids = grid_graph.projectNodeIdsToPixels()
    assert node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{node_ids.shape}, {mask.shape}, {grid_graph.shape}"
    masked_ids = node_ids[~mask]

    # number of masked nodes per edge: 2 -> fully masked (removed), 1 -> transition edge
    edge_state = np.isin(edges, masked_ids).sum(axis=1)
    keep_edges = edge_state != 2

    edges, weights, edge_state = edges[keep_edges], weights[keep_edges], edge_state[keep_edges]
    transition_edges = edge_state == 1
    weights[transition_edges] = transition_edge_weight

    return edges, weights


#
# Lifted Features
#

def lifted_edges_from_graph_neighborhood(graph, max_graph_distance):
    """@private
    """
    if max_graph_distance < 2:
        raise ValueError(f"Graph distance must be greater equal 2, got {max_graph_distance}")
    if isinstance(graph, nifty.graph.UndirectedGraph):
        objective = nifty.graph.opt.lifted_multicut.liftedMulticutObjective(graph)
    else:
        graph_ = nifty.graph.undirectedGraph(graph.numberOfNodes)
        graph_.insertEdges(graph.uvIds())
        objective = nifty.graph.opt.lifted_multicut.liftedMulticutObjective(graph_)
    objective.insertLiftedEdgesBfs(max_graph_distance)
    lifted_uvs = objective.liftedUvIds()
    return lifted_uvs


def feats_to_costs_default(lifted_labels, lifted_features):
    """@private
    """
    # we assume that we only have different classes for a given lifted
    # edge here (mode = "different") and then set all edges to be repulsive

    # the higher the class probability, the more repulsive the edges should be,
    # so we just multiply both probabilities
    lifted_costs = lifted_features[:, 0] * lifted_features[:, 1]
    lifted_costs = transform_probabilities_to_costs(lifted_costs)
    return lifted_costs


def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: Callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from probability maps by mapping them to superpixels.

    Example: compute a lifted problem from two attributions (axon, dendrite) that induce
    repulsive edges between different attributions. The construction of lifted edges and
    features can be customized using the `feats_to_costs` and `mode` arguments.
    ```
    lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
        rag, superpixels,
        input_maps=[
            axon_probabilities,  # probability map for axon attribution
            dendrite_probabilities  # probability map for dendrite attribution
        ],
        assignment_threshold=0.6,  # probability threshold to assign superpixels to a class
        graph_depth=10,  # the max. graph depth along which lifted edges are introduced
    )
    ```

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watersheds
            and each map is treated as the probability to correspond to a different class.
        assignment_threshold: Minimal expression level to assign a class to a graph node (= watershed segment).
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
            The input to the function are `lifted_labels`, which stores the two classes assigned to a lifted edge,
            and `lifted_features`, which stores the two assignment probabilities.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).
    """
    assert ndist is not None, "Need nifty.distributed package"

    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    # validate inputs
    assert isinstance(input_maps, (list, tuple))
    assert all(isinstance(inp, np.ndarray) for inp in input_maps)
    shape = watershed.shape
    assert all(inp.shape == shape for inp in input_maps)

    # map the probability maps to superpixels - we only map to superpixels which
    # have a larger mean expression than `assignment_threshold`

    # TODO handle the dtype conversion for vigra gracefully somehow ...
    # think about supporting uint8 input and normalizing

    # TODO how do we handle cases where the same superpixel is mapped to
    # more than one class ?

    n_nodes = int(watershed.max()) + 1
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")
    # TODO we could allow for more features that could then be used for the cost estimation
    # NOTE(review): class ids start at 0, which is also the value of unassigned nodes and the
    # `ignoreLabel` passed below, so nodes assigned to the first input map look unassigned. Confirm intent.
    for class_id, inp in enumerate(input_maps):
        mean_prob = vigra.analysis.extractRegionFeatures(inp, watershed, features=["mean"])["mean"]
        # we can in principle map multiple classes here, and right now will just override
        class_mask = mean_prob > assignment_threshold
        node_labels[class_mask] = class_id
        node_features[class_mask] = mean_prob[class_mask]

    # find all lifted edges up to the graph depth between mapped nodes
    # NOTE we need to convert to the different graph type for now, but
    # it would be nice to support all nifty graphs at some type
    uv_ids = rag.uvIds()
    g_temp = ndist.Graph(uv_ids)

    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(g_temp, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)
    lifted_labels = node_labels[lifted_uvs]
    lifted_features = node_features[lifted_uvs]

    lifted_costs = feats_to_costs(lifted_labels, lifted_features)
    return lifted_uvs, lifted_costs


# TODO support setting costs proportional to overlaps
def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Each superpixel is attributed to the input segment with the largest normalized overlap,
    or left unattributed (label 0) if this overlap is below `overlap_threshold`.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).
    """
    assert ndist is not None, "Need nifty.distributed package"
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    # compute the overlaps
    ovlp_comp = ngt.overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.numberOfNodes, "%i, %i" % (n_labels, rag.numberOfNodes)

    # initialise the arrays for node labels, to be
    # dense in the watershed id space (even if some ws-ids are not present)
    node_labels = np.zeros(n_labels, dtype="uint64")

    # extract the node labels from the overlap computation results:
    # the overlap arrays are requested unsorted, so we pick the segment with the
    # maximal overlap explicitly instead of relying on the order of the arrays
    node_label_vals = np.zeros(len(ws_ids), dtype="uint64")
    for ii, ws_id in enumerate(ws_ids):
        ovlp_ids, ovlp_values = ovlp_comp.overlapArraysNormalized(ws_id, sorted=False)
        max_pos = int(np.argmax(ovlp_values))
        if ovlp_values[max_pos] >= overlap_threshold:
            node_label_vals[ii] = ovlp_ids[max_pos]
    node_labels[ws_ids] = node_label_vals

    # find all lifted edges up to the graph depth between mapped nodes
    # NOTE we need to convert to the different graph type for now, but
    # it would be nice to support all nifty graphs at some type
    uv_ids = rag.uvIds()
    g_temp = ndist.Graph(uv_ids)

    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(g_temp, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)
    # make sure that the lifted uv ids are in range of the node labels
    assert lifted_uvs.max() < rag.numberOfNodes, "%i, %i" % (int(lifted_uvs.max()),
                                                              rag.numberOfNodes)
    lifted_labels = node_labels[lifted_uvs]
    lifted_costs = np.zeros(len(lifted_labels), dtype="float64")

    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs[same_mask] = same_segment_cost
    lifted_costs[~same_mask] = different_segment_cost

    return lifted_uvs, lifted_costs


#
# Misc
#

def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation.
        verbose: Whether to be verbose.

    Returns:
        The edge mask indicating edges between blocks. All False if there are no block boundaries
        (e.g. if the blocking consists of a single block).
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = nifty.tools.blocking([0] * ndim, seg.shape, block_shape)

    def find_stitch_edges(block_id):
        # collect the rag edges across the faces of this block that border a lower neighbor block
        stitch_edges = []
        block = blocking.getBlock(block_id)
        for axis in range(ndim):
            if blocking.getNeighborId(block_id, axis, True) == -1:
                continue
            # the first plane of this block and the last plane of the neighboring block along `axis`
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )

            labels_a = seg[face_a].ravel()
            labels_b = seg[face_b].ravel()

            uv_ids = np.concatenate(
                [labels_a[:, None], labels_b[:, None]],
                axis=1
            )
            uv_ids = np.unique(uv_ids, axis=0)

            edge_ids = rag.findEdges(uv_ids)
            # -1 marks pairs that are not an edge in the rag (e.g. identical labels)
            edge_ids = edge_ids[edge_ids != -1]
            stitch_edges.append(edge_ids)

        if stitch_edges:
            stitch_edges = np.concatenate(stitch_edges)
            stitch_edges = np.unique(stitch_edges)
        else:
            stitch_edges = None
        return stitch_edges

    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(find_stitch_edges, range(blocking.numberOfBlocks))
        if verbose:
            results = tqdm(results, total=blocking.numberOfBlocks)
        stitch_edges = [st for st in results if st is not None]

    full_edges = np.zeros(rag.numberOfEdges, dtype="bool")
    # guard against the case without any block boundaries, where np.concatenate would fail
    if stitch_edges:
        full_edges[np.unique(np.concatenate(stitch_edges))] = True
    return full_edges


def project_node_labels_to_pixels(rag, node_labels: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Project label values for graph nodes back to pixels to obtain segmentation.

    Args:
        rag: The region adjacency graph.
        node_labels: The array with node labels.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The segmentation.

    Raises:
        ValueError: If the number of node labels does not match the number of rag nodes.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if len(node_labels) != rag.numberOfNodes:
        raise ValueError("Incompatible number of node labels: %i, %i" % (len(node_labels), rag.numberOfNodes))
    seg = nrag.projectScalarNodeDataToPixels(rag, node_labels, numberOfThreads=n_threads)
    return seg


def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Compute edge mask of in-between plane edges for flat superpixels.

    Flat superpixels are volumetric superpixels that are independent across slices.
    This function does not check whether the input watersheds are actually flat.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    node_z_coords = np.zeros(rag.numberOfNodes, dtype="uint32")
    for z in range(watershed.shape[0]):
        node_z_coords[watershed[z]] = z
    uv_ids = rag.uvIds()
    z_edge_mask = node_z_coords[uv_ids[:, 0]] != node_z_coords[uv_ids[:, 1]]
    return z_edge_mask
def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
    """Build the region adjacency graph of a label image.

    Args:
        segmentation: The label image whose segments become the graph nodes.
        n_labels: The number of labels. If None, it is computed as the maximum label id + 1.
        n_threads: The number of threads to use; defaults to the cpu count.

    Returns:
        The region adjacency graph.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    if n_labels is None:
        n_labels = int(segmentation.max()) + 1
    return nrag.gridRag(segmentation, numberOfLabels=n_labels, numberOfThreads=n_threads)
Compute region adjacency graph of segmentation.
Arguments:
- segmentation: The segmentation.
- n_labels: The number of labels in the segmentation. If None, will be computed from the data.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The region adjacency graph.
def compute_boundary_features(
    rag, boundary_map: np.ndarray, min_value: float = 0.0, max_value: float = 1.0, n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from boundary map.

    Args:
        rag: The region adjacency graph.
        boundary_map: The boundary map. Must have the same shape as the rag.
        min_value: The minimum value used in accumulation.
        max_value: The maximum value used in accumulation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the shapes of the rag and the boundary map do not agree.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    # validate before handing the data to nifty
    if tuple(rag.shape) != boundary_map.shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(boundary_map.shape)))
    features = nrag.accumulateEdgeStandartFeatures(
        rag, boundary_map, min_value, max_value, numberOfThreads=n_threads
    )
    return features
Compute edge features from boundary map.
Arguments:
- rag: The region adjacency graph.
- boundary_map: The boundary map.
- min_value: The minimum value used in accumulation.
- max_value: The maximum value used in accumulation.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The edge features.
def compute_affinity_features(
    rag,
    affinity_map: np.ndarray,
    offsets: List[List[int]],
    min_value: float = 0.0,
    max_value: float = 1.0,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from affinity map.

    Args:
        rag: The region adjacency graph.
        affinity_map: The affinity map, channels first. The spatial shape must agree with the rag.
        offsets: The offsets corresponding to the affinity channels, one offset per channel.
        min_value: The minimum value used in accumulation.
        max_value: The maximum value used in accumulation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the spatial shape or the number of channels does not match.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    # axis 0 of the affinity map holds the channels, the remaining axes must match the rag
    if tuple(rag.shape) != affinity_map.shape[1:]:
        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(affinity_map.shape[1:])))
    if len(offsets) != affinity_map.shape[0]:
        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets),
                                                                                   affinity_map.shape[0]))
    features = nrag.accumulateAffinityStandartFeatures(
        rag, affinity_map, offsets, min_value, max_value, numberOfThreads=n_threads
    )
    return features
Compute edge features from affinity map.
Arguments:
- rag: The region adjacency graph.
- affinity_map: The affinity map.
- offsets: The offsets corresponding to the affinity channels, one offset per channel.
- min_value: The minimum value used in accumulation.
- max_value: The maximum value used in accumulation.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The edge features.
def compute_boundary_mean_and_length(rag, input_: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Accumulate the mean input value and the length (size) of every boundary of the rag.

    Args:
        rag: The region adjacency graph.
        input_: The input map, must have the same shape as the rag.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features, as returned by `nrag.accumulateEdgeMeanAndLength`.

    Raises:
        ValueError: If the shape of the input map does not match the shape of the rag.
    """
    shapes_match = tuple(rag.shape) == input_.shape
    if not shapes_match:
        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(input_.shape)))
    threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    return nrag.accumulateEdgeMeanAndLength(rag, input_, numberOfThreads=threads)
Compute mean value and length of boundaries.
Arguments:
- rag: The region adjacency graph.
- input_: The input map.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The edge features.
def compute_boundary_features_with_filters(
    rag,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Optional[Dict[str, List[float]]] = None,
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    Each filter is applied for each of its sigma values. The edge features of every channel of
    every filter response are accumulated with `compute_boundary_features`, using the channel's
    minimum and maximum as the accumulation range. All features are concatenated along axis 1.

    Args:
        rag: The region adjacency graph.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            If None, "gaussianSmoothing", "laplacianOfGaussian" and "hessianOfGaussianEigenvalues"
            are applied with sigmas 1.6, 4.2 and 8.3.

    Returns:
        The edge features.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    if filters is None:
        # build the default per call instead of sharing a mutable default argument between calls
        filters = {"gaussianSmoothing": [1.6, 4.2, 8.3],
                   "laplacianOfGaussian": [1.6, 4.2, 8.3],
                   "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3]}
    filter_params = [(filter_name, sigma) for filter_name, sigmas in filters.items() for sigma in sigmas]

    def _accumulate_channels(response, threads_per_channel):
        # Accumulate the edge features for every channel of a channels-last filter response.
        channel_features = []
        for chan in range(response.shape[-1]):
            chan_data = response[..., chan]
            channel_features.append(
                compute_boundary_features(rag, chan_data, chan_data.min(), chan_data.max(),
                                          n_threads=threads_per_channel)
            )
        channel_features = np.concatenate(channel_features, axis=1)
        assert len(channel_features) == rag.numberOfEdges, f"{len(channel_features)}, {rag.numberOfEdges}"
        return channel_features

    if apply_2d:
        # apply 2d: the filters are computed one after the other via `_filter_2d`, which gets
        # `n_threads`; the per-channel feature accumulation also uses `n_threads`
        def _compute_2d(filter_name, sigma):
            response = _filter_2d(input_, filter_name, sigma, n_threads)
            assert response.ndim == 4
            return _accumulate_channels(response, n_threads)

        features = [_compute_2d(filter_name, sigma) for filter_name, sigma in filter_params]

    # apply 3d: we parallelize over the whole filter + feature computation
    # this can be very memory intensive, and it would be better to parallelize inside
    # of the loop, but 3d parallel filters in elf.parallel.filters are not working properly yet
    else:

        def _compute_3d(filter_name, sigma):
            filter_fu = getattr(ff, filter_name)
            response = filter_fu(input_, sigma)
            # scalar filters return a response with the same dimensionality as the input, multi-channel
            # filters (e.g. hessian eigenvalues) one extra trailing axis; comparing against the input
            # dimensionality (instead of a hard-coded 3) keeps this correct for 2d input as well
            if response.ndim == input_.ndim:
                response = response[..., None]
            return _accumulate_channels(response, 1)

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, filter_name, sigma) for filter_name, sigma in filter_params]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.numberOfEdges
    return features
Compute boundary features accumulated over filter responses on input.
Arguments:
- rag: The region adjacency graph.
- input_: The input data.
- apply_2d: Whether to apply the filters in 2d for 3d input data.
- n_threads: The number of threads.
- filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
Returns:
The edge features.
def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from an input map accumulated over segmentation and mapped to edges.

    Region statistics of the input map are computed per segment with vigra. These are Count, Kurtosis,
    Maximum, Minimum, Quantiles, RegionRadii, Skewness, Sum and Variance. For each edge, the statistics
    of its two segments are combined via element-wise minimum, maximum, absolute difference and sum.
    The weighted and unweighted region centers are combined via the squared difference per coordinate.
    NaN and infinite values in the result are replaced by finite numbers via `np.nan_to_num`.

    Args:
        uv_ids: The edge uv ids, i.e. the pairs of segment ids connected by each edge (n_edges x 2).
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: Currently not used; accepted for consistency with the other feature functions.

    Returns:
        The edge features, one row per edge in `uv_ids`.
    """
    # compute the node features
    stat_feature_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                          "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_feature_names = ["Weighted<RegionCenter>", "RegionCenter"]
    feature_names = stat_feature_names + coord_feature_names
    node_features = vigra.analysis.extractRegionFeatures(input_map, segmentation,
                                                         features=feature_names)

    # get the image statistics based features, that are combined via [min, max, absdiff, sum]
    stat_features = [node_features[fname] for fname in stat_feature_names]
    stat_features = np.concatenate([feat[:, None] if feat.ndim == 1 else feat
                                    for feat in stat_features], axis=1)

    # get the coordinate based features, that are combined via the squared difference per coordinate
    coord_features = [node_features[fname] for fname in coord_feature_names]
    coord_features = np.concatenate([feat[:, None] if feat.ndim == 1 else feat
                                     for feat in coord_features], axis=1)

    u, v = uv_ids[:, 0], uv_ids[:, 1]

    # combine the stat features for all edges
    feats_u, feats_v = stat_features[u], stat_features[v]
    features = [np.minimum(feats_u, feats_v), np.maximum(feats_u, feats_v),
                np.abs(feats_u - feats_v), feats_u + feats_v]

    # combine the coord features for all edges
    feats_u, feats_v = coord_features[u], coord_features[v]
    features.append((feats_u - feats_v) ** 2)

    features = np.nan_to_num(np.concatenate(features, axis=1))
    assert len(features) == len(uv_ids)
    return features
Compute edge features from an input map accumulated over segmentation and mapped to edges.
Arguments:
- uv_ids: The edge uv ids.
- input_map: The input data.
- segmentation: The segmentation.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The edge features.
def compute_grid_graph(shape: Tuple[int, ...]):
    """Build an undirected grid graph for a regular grid of the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The grid graph, as created by `nifty.graph.undirectedGridGraph`.
    """
    return nifty.graph.undirectedGridGraph(shape)
Compute grid graph for the given shape.
Arguments:
- shape: The shape of the data.
Returns:
The grid graph.
def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features for image for the given grid_graph.

    If the image has the same dimensionality as the grid graph, it is treated as a scalar image.
    In that case `mode` must be one of "l1", "l2", "min", "max", "sum", "prod" or "interpixel",
    and offsets are not supported. If the image has one additional (channel) dimension, `mode`
    must be one of "l1", "l2" or "cosine".

    Args:
        grid_graph: The grid graph.
        image: The image, from which the features will be derived.
        mode: Feature accumulation method.
        offsets: The offsets, which correspond to the affinity channels.
            If none are given, the affinities for the nearest neighbor transitions are used.
        strides: The strides used to subsample edges that are computed from offsets.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        NotImplementedError: If offsets are given for an image without a channel dimension.
        ValueError: If the mode or the image dimensionality is invalid.
    """
    gndim = len(grid_graph.shape)

    if image.ndim == gndim:
        if offsets is not None:
            raise NotImplementedError(
                "Offsets are only supported for images with a channel dimension "
                f"(image.ndim == {gndim + 1}), got image.ndim == {image.ndim}"
            )
        modes = ("l1", "l2", "min", "max", "sum", "prod", "interpixel")
        if mode not in modes:
            raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")
        features = grid_graph.imageToEdgeMap(image, mode)
        edges = grid_graph.uvIds()

    elif image.ndim == gndim + 1:
        modes = ("l1", "l2", "cosine")
        if mode not in modes:
            raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")

        if offsets is None:
            features = grid_graph.imageWithChannelsToEdgeMap(image, mode)
            edges = grid_graph.uvIds()
        else:
            n_edges, edges, features = grid_graph.imageWithChannelsToEdgeMapWithOffsets(
                image, mode, offsets=offsets, strides=strides, randomize_strides=randomize_strides
            )
            # the returned arrays are truncated to the first n_edges entries
            edges, features = edges[:n_edges], features[:n_edges]

    else:
        msg = f"Invalid image dimension {image.ndim}, expect one of {gndim} or {gndim + 1}"
        raise ValueError(msg)

    return edges, features
Compute edge features for image for the given grid_graph.
Arguments:
- grid_graph: The grid graph
- image: The image, from which the features will be derived.
- mode: Feature accumulation method.
- offsets: The offsets, which correspond to the affinity channels. If none are given, the affinities for the nearest neighbor transitions are used.
- strides: The strides used to subsample edges that are computed from offsets.
- randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:
The uv ids of the edges. The edge features.
def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features from affinities for the given grid graph.

    Args:
        grid_graph: The grid graph.
        affinities: The affinity map, with one more dimension than the grid graph. The first axis
            holds the affinity channels. Without offsets, the number of channels must equal the
            grid dimension.
        offsets: The offsets, which correspond to the affinity channels.
            If none are given, the affinities for the nearest neighbor transitions are used.
        strides: The strides used to subsample edges that are computed from offsets.
        mask: Mask to exclude from the edge and feature computation. Requires offsets
            and cannot be combined with strides.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the affinities do not have exactly one more dimension than the grid graph.
    """
    gndim = len(grid_graph.shape)
    if affinities.ndim != gndim + 1:
        raise ValueError(f"Invalid affinity dimension {affinities.ndim}, expect {gndim + 1}")

    if offsets is None:
        # nearest neighbor affinities: edges are exactly the grid graph edges
        assert affinities.shape[0] == gndim
        assert strides is None
        assert mask is None
        features = grid_graph.affinitiesToEdgeMap(affinities)
        edges = grid_graph.uvIds()
    elif mask is not None:
        assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithMask(affinities,
                                                                          offsets=offsets,
                                                                          mask=mask)
        # the returned arrays are truncated to the first n_edges entries
        edges, features = edges[:n_edges], features[:n_edges]
    else:
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithOffsets(affinities,
                                                                            offsets=offsets,
                                                                            strides=strides,
                                                                            randomize_strides=randomize_strides)
        edges, features = edges[:n_edges], features[:n_edges]

    return edges, features
Compute edge features from affinities for the given grid graph.
Arguments:
- grid_graph: The grid graph
- affinities: The affinity map.
- offsets: The offsets, which correspond to the affinity channels. If none are given, the affinities for the nearest neighbor transitions are used.
- strides: The strides used to subsample edges that are computed from offsets.
- mask: Mask to exclude from the edge and feature computation.
- randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:
The uv ids of the edges. The edge features.
def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Mask edges in grid graph.

    Overwrite the weights of edges between two masked nodes with `masked_edge_weight`.
    Overwrite the weights of edges between a masked and a non-masked node with
    `transition_edge_weight`. Note that `weights` is modified in place and also returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        weights: The edge weights, one per edge of the grid graph.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    pixel_node_ids = grid_graph.projectNodeIdsToPixels()
    assert pixel_node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{pixel_node_ids.shape}, {mask.shape}, {grid_graph.shape}"
    background_nodes = pixel_node_ids[np.logical_not(mask)]

    uv_ids = grid_graph.uvIds()
    assert len(uv_ids) == len(weights)
    # number of masked endpoints per edge: 0 = both foreground, 1 = transition, 2 = both masked
    n_masked_endpoints = np.isin(uv_ids, background_nodes).sum(axis=1)
    weights[n_masked_endpoints == 2] = masked_edge_weight
    weights[n_masked_endpoints == 1] = transition_edge_weight
    return weights
Mask edges in grid graph.
Set the weights derived from a grid graph to a fixed value, for edges that connect masked nodes and edges that connect masked and unmasked nodes.
Arguments:
- grid_graph: The grid graph.
- mask: The binary mask, foreground (=non-masked) is True.
- weights: The edge weights.
- masked_edge_weight: The value for edges that connect two masked nodes.
- transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:
The masked edge weights.
def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.

    The input arrays are not modified; filtered copies are returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The edge uv-ids.
        The edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    pixel_node_ids = grid_graph.projectNodeIdsToPixels()
    assert pixel_node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{pixel_node_ids.shape}, {mask.shape}, {grid_graph.shape}"
    background_nodes = pixel_node_ids[np.logical_not(mask)]

    # number of masked endpoints per edge; edges with two masked endpoints are dropped
    n_masked_endpoints = np.isin(edges, background_nodes).sum(axis=1)
    keep = n_masked_endpoints != 2

    # boolean indexing returns copies, so the caller's arrays stay untouched
    kept_edges = edges[keep]
    kept_weights = weights[keep]
    kept_weights[n_masked_endpoints[keep] == 1] = transition_edge_weight
    return kept_edges, kept_weights
Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.
Arguments:
- grid_graph: The grid graph.
- mask: The binary mask, foreground (=non-masked) is True.
- edges: The edges (uv-ids).
- weights: The edge weights.
- transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:
The edge uv-ids. The edge weights.
def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from probability maps by mapping them to superpixels.

    Example: compute a lifted problem from two attributions (axon, dendrite) that induce
    repulsive edges between different attributions. The construction of lifted edges and
    features can be customized using the `feats_to_costs` and `mode` arguments.
    ```
    lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
        rag, superpixels,
        input_maps=[
            axon_probabilities,  # probability map for axon attribution
            dendrite_probabilities  # probability map for dendrite attribution
        ],
        assignment_threshold=0.6,  # probability threshold to assign superpixels to a class
        graph_depth=10,  # the max. graph depth along which lifted edges are introduced
    )
    ```

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watersheds
            and each map is treated as the probability to correspond to a different class.
        assignment_threshold: Minimal expression level to assign a class to a graph node (= watershed segment).
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
            The inputs to the function are `lifted_labels`, which stores the two classes assigned to a
            lifted edge (as indices into `input_maps`), and `lifted_features`, which stores the two
            assignment probabilities.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).
    """
    assert ndist is not None, "Need nifty.distributed package"

    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    # validate inputs
    assert isinstance(input_maps, (list, tuple))
    assert all(isinstance(inp, np.ndarray) for inp in input_maps)
    shape = watershed.shape
    assert all(inp.shape == shape for inp in input_maps)

    # map the probability maps to superpixels - we only map to superpixels which
    # have a larger mean expression than `assignment_threshold`

    # TODO handle the dtype conversion for vigra gracefully somehow ...
    # think about supporting uint8 input and normalizing

    # TODO how do we handle cases where the same superpixel is mapped to
    # more than one class ?

    n_nodes = int(watershed.max()) + 1
    # Node label 0 means "no class assigned" and is passed as the ignore label to the lifted
    # neighborhood computation below. The class of input_maps[i] is therefore stored as label i + 1;
    # storing it as i would make the nodes of the first class indistinguishable from unassigned nodes.
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")
    # TODO we could allow for more features that could then be used for the cost estimation
    for class_label, inp in enumerate(input_maps, start=1):
        mean_prob = vigra.analysis.extractRegionFeatures(inp, watershed, features=["mean"])["mean"]
        # we can in principle map multiple classes here, and right now later classes override earlier ones
        class_mask = mean_prob > assignment_threshold
        node_labels[class_mask] = class_label
        node_features[class_mask] = mean_prob[class_mask]

    # find all lifted edges up to the graph depth between mapped nodes
    # NOTE we need to convert to the different graph type for now, but
    # it would be nice to support all nifty graphs at some type
    uv_ids = rag.uvIds()
    g_temp = ndist.Graph(uv_ids)

    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(g_temp, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)
    lifted_labels = node_labels[lifted_uvs]
    # nodes with the ignore label 0 are expected to be excluded from the lifted edges
    assert (lifted_labels > 0).all(), "Lifted edges must only connect nodes with a class assignment"
    # shift back to 0-based class indices into `input_maps` before handing them to `feats_to_costs`
    lifted_labels = lifted_labels - 1
    lifted_features = node_features[lifted_uvs]

    lifted_costs = feats_to_costs(lifted_labels, lifted_features)
    return lifted_uvs, lifted_costs
Compute lifted problem from probability maps by mapping them to superpixels.
Example: compute a lifted problem from two attributions (axon, dendrite) that induce
repulsive edges between different attributions. The construction of lifted edges and
features can be customized using the `feats_to_costs` and `mode` arguments.
lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
rag, superpixels,
input_maps=[
axon_probabilities, # probability map for axon attribution
dendrite_probabilities # probability map for dendrite attribution
],
assignment_threshold=0.6, # probability threshold to assign superpixels to a class
graph_depth=10, # the max. graph depth along which lifted edges are introduced
)
Arguments:
- rag: The region adjacency graph.
- watershed: The watershed over-segmentation.
- input_maps: List of probability maps. Each map must have the same shape as the watersheds and each map is treated as the probability to correspond to a different class.
- assignment_threshold: Minimal expression level to assign a class to a graph node (= watershed segment).
- graph_depth: Maximal graph depth up to which lifted edges will be included.
- feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
The inputs to the function are `lifted_labels`, which stores the two classes assigned to a lifted edge, and `lifted_features`, which stores the two assignment probabilities.
- mode: The mode for insertion of lifted edges. One of: "all" - lifted edges will be inserted in between all nodes with attribution. "different" - lifted edges will only be inserted in between nodes attributed to different classes. "same" - lifted edges will only be inserted in between nodes attributed to the same class.
- n_threads: The number of threads used for the calculation.
Returns:
The lifted uv ids (= superpixel ids connected by the lifted edge). The lifted costs (= cost associated with each lifted edge).
def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Each watershed segment is attributed to the segment of `input_segmentation` with which it
    overlaps most, if that normalized overlap is at least `overlap_threshold`. Otherwise it is left
    unattributed (label 0) and ignored for the lifted edges.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included.
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).
    """
    assert ndist is not None, "Need nifty.distributed package"

    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    # compute the overlaps
    ovlp_comp = ngt.overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.numberOfNodes, "%i, %i" % (n_labels, rag.numberOfNodes)

    # initialise the arrays for node labels, to be
    # dense in the watershed id space (even if some ws-ids are not present)
    node_labels = np.zeros(n_labels, dtype="uint64")

    # extract the overlap values and node labels from the overlap computation results.
    # The overlaps are requested sorted by decreasing overlap, so that entry 0 is the segment with
    # the largest overlap; the unsorted order does not guarantee this.
    overlaps = [ovlp_comp.overlapArraysNormalized(ws_id, sorted=True)
                for ws_id in ws_ids]
    node_label_vals = np.array([ovlp[0][0] for ovlp in overlaps])
    overlap_values = np.array([ovlp[1][0] for ovlp in overlaps])
    # label 0 = no attribution, these nodes are excluded via the ignore label below
    node_label_vals[overlap_values < overlap_threshold] = 0
    assert len(node_label_vals) == len(ws_ids)
    node_labels[ws_ids] = node_label_vals

    # find all lifted edges up to the graph depth between mapped nodes
    # NOTE we need to convert to the different graph type for now, but
    # it would be nice to support all nifty graphs at some type
    uv_ids = rag.uvIds()
    g_temp = ndist.Graph(uv_ids)

    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(g_temp, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)
    # make sure that the lifted uv ids are in range of the node labels
    assert lifted_uvs.max() < rag.numberOfNodes, "%i, %i" % (int(lifted_uvs.max()),
                                                             rag.numberOfNodes)
    lifted_labels = node_labels[lifted_uvs]
    lifted_costs = np.zeros(len(lifted_labels), dtype="float64")

    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs[same_mask] = same_segment_cost
    lifted_costs[~same_mask] = different_segment_cost

    return lifted_uvs, lifted_costs
Compute lifted problem from segmentation by mapping segments to superpixels.
Arguments:
- rag: The region adjacency graph.
- watershed: The watershed over-segmentation.
- input_segmentation: The segmentation used to determine node attribution.
- overlap_threshold: The minimal overlap to assign a segment id to node.
- graph_depth: The maximal graph depth up to which lifted edges will be included
- same_segment_cost: The cost for edges between nodes with same segment id attribution.
- different_segment_cost: The cost for edges between nodes with different segment id attribution.
- mode: The mode for insertion of lifted edges. One of: "all" - lifted edges will be inserted in between all nodes with attribution. "different" - lifted edges will only be inserted in between nodes attributed to different classes. "same" - lifted edges will only be inserted in between nodes attributed to the same class.
- n_threads: The number of threads used for the calculation.
Returns:
The lifted uv ids (= superpixel ids connected by the lifted edge). The lifted costs (= cost associated with each lifted edge).
def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    An edge is a stitch edge if it connects the segments on the two sides of a face shared by
    two neighboring blocks of the blocking of `seg` with `block_shape`.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation.
        verbose: Whether to be verbose.

    Returns:
        The edge mask indicating edges between blocks (boolean, one entry per rag edge).
        All False if no block has a neighbor, e.g. when `seg` fits into a single block.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = nifty.tools.blocking([0] * ndim, seg.shape, block_shape)

    def find_stitch_edges(block_id):
        # Collect the ids of the rag edges crossing the faces at the begin of this block.
        stitch_edges = []
        block = blocking.getBlock(block_id)
        for axis in range(ndim):
            # skip axes without a neighbor block at the lower side (presumably what the `True` flag
            # selects, consistent with comparing slice `beg` against `beg - 1` below)
            if blocking.getNeighborId(block_id, axis, True) == -1:
                continue
            # first slice of this block along `axis` and the adjacent slice before it
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )

            labels_a = seg[face_a].ravel()
            labels_b = seg[face_b].ravel()
            uv_ids = np.unique(np.stack([labels_a, labels_b], axis=1), axis=0)

            # label pairs that are not connected by a rag edge are reported as -1 and dropped
            edge_ids = rag.findEdges(uv_ids)
            stitch_edges.append(edge_ids[edge_ids != -1])

        if not stitch_edges:
            return None
        return np.unique(np.concatenate(stitch_edges))

    with futures.ThreadPoolExecutor(n_threads) as tp:
        block_results = tp.map(find_stitch_edges, range(blocking.numberOfBlocks))
        if verbose:
            block_results = tqdm(block_results, total=blocking.numberOfBlocks)
        block_results = [res for res in block_results if res is not None]

    full_edges = np.zeros(rag.numberOfEdges, dtype="bool")
    # without any shared block faces there is nothing to concatenate (np.concatenate([]) raises)
    if block_results:
        full_edges[np.concatenate(block_results)] = True
    return full_edges
Get the edges between blocks.
Arguments:
- rag: The region adjacency graph.
- seg: The segmentation underlying the rag.
- block_shape: The shape of the blocking.
- n_threads: The number of threads used for the calculation.
- verbose: Whether to be verbose.
Returns:
The edge mask indicating edges between blocks.
def project_node_labels_to_pixels(rag, node_labels: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Write the label of every graph node to all pixels of that node, yielding a segmentation.

    Args:
        rag: The region adjacency graph.
        node_labels: The array with node labels, one entry per rag node.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The segmentation.

    Raises:
        ValueError: If the number of node labels differs from the number of rag nodes.
    """
    n_given, n_nodes = len(node_labels), rag.numberOfNodes
    if n_given != n_nodes:
        raise ValueError("Incompatible number of node labels: %i, %i" % (n_given, n_nodes))
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    return nrag.projectScalarNodeDataToPixels(rag, node_labels, numberOfThreads=n_threads)
Project label values for graph nodes back to pixels to obtain segmentation.
Arguments:
- rag: The region adjacency graph.
- node_labels: The array with node labels.
- n_threads: The number of threads used, set to cpu count by default.
Returns:
The segmentation.
def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Compute edge mask of in-between plane edges for flat superpixels.

    Flat superpixels are volumetric superpixels that are independent across slices.
    This function does not check whether the input watersheds are actually flat: every node is
    assigned the index of the last slice (along the first axis) in which it occurs, and an edge is
    marked if its two nodes are assigned different slices.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    z_of_node = np.zeros(rag.numberOfNodes, dtype="uint32")
    for z, plane in enumerate(watershed):
        z_of_node[plane] = z
    edges = rag.uvIds()
    return z_of_node[edges[:, 0]] != z_of_node[edges[:, 1]]
Compute edge mask of in-between plane edges for flat superpixels.
Flat superpixels are volumetric superpixels that are independent across slices. This function does not check whether the input watersheds are actually flat.
Arguments:
- rag: The region adjacency graph.
- watershed: The underlying watershed over-segmentation (superpixels).
Returns:
The edge mask indicating in-between slice edges.