elf.segmentation.features

  1import multiprocessing
  2from concurrent import futures
  3from typing import Dict, List, Optional, Tuple
  4
  5import numpy as np
  6import vigra
  7import nifty
  8import nifty.graph.rag as nrag
  9import nifty.ground_truth as ngt
 10try:
 11    import nifty.distributed as ndist
 12except ImportError:
 13    ndist = None
 14
 15try:
 16    import fastfilters as ff
 17except ImportError:
 18    import vigra.filters as ff
 19
 20from tqdm import tqdm
 21from .multicut import transform_probabilities_to_costs
 22
 23
 24#
 25# Region Adjacency Graph and Features
 26#
 27
 28def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
 29    """Compute region adjacency graph of segmentation.
 30
 31    Args:
 32        segmentation: The segmentation.
 33        n_labels: The number of labels in the segmentation. If None, will be computed from the data.
 34        n_threads: The number of threads used, set to cpu count by default.
 35
 36    Returns:
 37        The region adjacency graph.
 38    """
 39    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
 40    n_labels = int(segmentation.max()) + 1 if n_labels is None else n_labels
 41    rag = nrag.gridRag(segmentation, numberOfLabels=n_labels, numberOfThreads=n_threads)
 42    return rag
 43
 44
 45def compute_boundary_features(
 46    rag, boundary_map: np.ndarray, min_value: float = 0.0, max_value: float = 1.0, n_threads: Optional[int] = None
 47) -> np.ndarray:
 48    """Compute edge features from boundary map.
 49
 50    Args:
 51        rag: The region adjacency graph.
 52        boundary_map:The boundary map.
 53        min_value: The minimum value used in accumulation.
 54        max_value: The maximum value used in accumulation.
 55        n_threads: The number of threads used, set to cpu count by default.
 56
 57    Returns:
 58        The edge features.
 59    """
 60    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
 61    if tuple(rag.shape) != boundary_map.shape:
 62        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(boundary_map.shape)))
 63    features = nrag.accumulateEdgeStandartFeatures(
 64        rag, boundary_map, min_value, max_value, numberOfThreads=n_threads
 65    )
 66    return features
 67
 68
 69def compute_affinity_features(
 70    rag,
 71    affinity_map: np.ndarray,
 72    offsets: List[List[int]],
 73    min_value: float = 0.0,
 74    max_value: float = 1.0,
 75    n_threads: Optional[int] = None
 76) -> np.ndarray:
 77    """Compute edge features from affinity map.
 78
 79    Args:
 80        rag: The region adjacency graph.
 81        affinity_map: The affinity map.
 82        min_value: The minimum value used in accumulation.
 83        max_value: The maximum value used in accumulation.
 84        n_threads: The umber of threads used, set to cpu count by default.
 85
 86    Returns:
 87        The edge features.
 88    """
 89    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
 90    if tuple(rag.shape) != affinity_map.shape[1:]:
 91        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(affinity_map.shape[1:])))
 92    if len(offsets) != affinity_map.shape[0]:
 93        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets),
 94                                                                                  affinity_map.shape[0]))
 95    features = nrag.accumulateAffinityStandartFeatures(
 96        rag, affinity_map, offsets, min_value, max_value, numberOfThreads=n_threads
 97    )
 98    return features
 99
100
def compute_boundary_mean_and_length(rag, input_: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Accumulate the mean value and the length of each boundary (edge) of a RAG.

    Args:
        rag: The region adjacency graph.
        input_: The input map. Its shape must equal `rag.shape`.
        n_threads: The number of threads used. If None, the cpu count is used.

    Returns:
        The edge features computed by `nifty.graph.rag.accumulateEdgeMeanAndLength`.

    Raises:
        ValueError: If the shapes of the rag and the input differ.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    shapes_match = tuple(rag.shape) == input_.shape
    if not shapes_match:
        raise ValueError("Incompatible shapes: %s, %s" % (str(rag.shape), str(input_.shape)))

    return nrag.accumulateEdgeMeanAndLength(rag, input_, numberOfThreads=n_threads)
117
118
119# TODO generalize and move to elf.features.parallel
def _filter_2d(input_, filter_name, sigma, n_threads):
    """Apply a filter independently to every slice along the first axis of the input.

    Args:
        input_: The input data; it is indexed along its first axis to obtain the slices.
        filter_name: The name of the filter function, looked up on the filter module `ff`.
        sigma: The sigma value passed to the filter.
        n_threads: The number of worker threads.

    Returns:
        The stacked responses with layout (slice, ..., channel).
    """
    apply_filter = getattr(ff, filter_name)

    def _filter_slice(slice_data):
        out = apply_filter(slice_data, sigma)
        if out.ndim not in (2, 3):
            raise RuntimeError("Invalid filter response")
        # responses without a channel axis get a trailing singleton channel
        if out.ndim == 2:
            out = out[..., None]
        # leading singleton axis, so that the slices can be concatenated below
        return out[None]

    with futures.ThreadPoolExecutor(n_threads) as pool:
        pending = [pool.submit(_filter_slice, input_[z]) for z in range(input_.shape[0])]
        per_slice = [job.result() for job in pending]

    return np.concatenate(per_slice, axis=0)
140
141
def compute_boundary_features_with_filters(
    rag,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Optional[Dict[str, List[float]]] = None,
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    Args:
        rag: The region adjacency graph.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads. If None, the cpu count is used.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            If None, gaussianSmoothing, laplacianOfGaussian and hessianOfGaussianEigenvalues
            are applied with the sigma values 1.6, 4.2 and 8.3.

    Returns:
        The edge features, concatenated along the feature axis over all filters, sigmas and channels.

    Raises:
        ValueError: If `filters` does not contain any (filter name, sigma) combination.
    """
    if filters is None:
        # built per call, so that the default can never be shared or mutated between calls
        filters = {
            "gaussianSmoothing": [1.6, 4.2, 8.3],
            "laplacianOfGaussian": [1.6, 4.2, 8.3],
            "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3],
        }
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    filter_jobs = [(filter_name, sigma) for filter_name, sigmas in filters.items() for sigma in sigmas]
    if not filter_jobs:
        # without this check the final np.concatenate fails with an unrelated looking error
        raise ValueError("No filters given: `filters` must map at least one filter name to at least one sigma.")

    def _accumulate_channels(response, accumulation_threads):
        # accumulate edge features for every response channel, using the value range of the channel
        per_channel = []
        for chan in range(response.shape[-1]):
            chan_data = response[..., chan]
            per_channel.append(
                compute_boundary_features(rag, chan_data, chan_data.min(), chan_data.max(), accumulation_threads)
            )
        channel_features = np.concatenate(per_channel, axis=1)
        assert len(channel_features) == rag.numberOfEdges, f"{len(channel_features)}, {rag.numberOfEdges}"
        return channel_features

    def _compute_2d(filter_name, sigma):
        response = _filter_2d(input_, filter_name, sigma, n_threads)
        assert response.ndim == 4
        return _accumulate_channels(response, n_threads)

    def _compute_3d(filter_name, sigma):
        response = getattr(ff, filter_name)(input_, sigma)
        if response.ndim == 3:
            response = response[..., None]
        # the parallelization happens over the filters, so each accumulation is single threaded
        return _accumulate_channels(response, 1)

    # apply 2d: the filters run one after another, each one is parallelized internally
    if apply_2d:
        features = [_compute_2d(filter_name, sigma) for filter_name, sigma in filter_jobs]

    # apply 3d: we parallelize over the whole filter + feature computation
    # this can be very memory intensive, and it would be better to parallelize inside
    # of the loop, but 3d parallel filters in elf.parallel.filters are not working properly yet
    else:
        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, filter_name, sigma) for filter_name, sigma in filter_jobs]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.numberOfEdges
    return features
219
220
def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from region statistics of an input map.

    The statistics are extracted per segment with vigra and then combined
    for the two segments connected by each edge.

    Args:
        uv_ids: The edge uv ids, one row with the two node ids per edge.
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: The number of threads, set to cpu count by default.
            The value is resolved but not passed on to any call in this function.

    Returns:
        The edge features, with nan values replaced via `np.nan_to_num`.
    """
    # kept for a uniform interface; nothing below consumes it
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    stat_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                  "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_names = ["Weighted<RegionCenter>", "RegionCenter"]
    node_features = vigra.analysis.extractRegionFeatures(input_map, segmentation,
                                                         features=stat_names + coord_names)

    def _as_columns(names):
        # one row per node; scalar features become a single column
        columns = []
        for name in names:
            feat = node_features[name]
            columns.append(feat[:, None] if feat.ndim == 1 else feat)
        return np.concatenate(columns, axis=1)

    stat_features = _as_columns(stat_names)
    coord_features = _as_columns(coord_names)

    node_u, node_v = uv_ids[:, 0], uv_ids[:, 1]

    # statistics are combined via [min, max, absdiff, sum] of the two nodes
    stats_u, stats_v = stat_features[node_u], stat_features[node_v]
    combined = [
        np.minimum(stats_u, stats_v),
        np.maximum(stats_u, stats_v),
        np.abs(stats_u - stats_v),
        stats_u + stats_v,
    ]

    # coordinates are combined via the squared difference per coordinate column
    coords_u, coords_v = coord_features[node_u], coord_features[node_v]
    combined.append((coords_u - coords_v) ** 2)

    features = np.nan_to_num(np.concatenate(combined, axis=1))
    assert len(features) == len(uv_ids)
    return features
272
273
274#
275# Grid Graph and Features
276#
277
def compute_grid_graph(shape: Tuple[int, ...]):
    """Create the undirected grid graph for data of the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The graph returned by `nifty.graph.undirectedGridGraph`.
    """
    return nifty.graph.undirectedGridGraph(shape)
289
290
def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive edge features for a grid graph from an image.

    The image either has the same number of dimensions as the grid graph,
    or one additional dimension for channels.

    Args:
        grid_graph: The grid graph.
        image: The image, from which the features will be derived.
        mode: Feature accumulation method. One of "l1", "l2", "min", "max", "sum", "prod",
            "interpixel" for images without channels and one of "l1", "l2", "cosine"
            for images with channels.
        offsets: The offsets used to define the edges. Only supported for images with channels.
            If None, the edges of the grid graph are used.
        strides: The strides used to subsample edges that are computed from offsets.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the image dimension or the mode is invalid.
        NotImplementedError: If offsets are passed for an image without channels.
    """
    gndim = len(grid_graph.shape)
    without_channels = image.ndim == gndim
    with_channels = image.ndim == gndim + 1

    if not (without_channels or with_channels):
        msg = f"Invalid image dimension {image.ndim}, expect one of {gndim} or {gndim + 1}"
        raise ValueError(msg)

    if without_channels:
        if offsets is not None:
            raise NotImplementedError
        modes = ("l1", "l2", "min", "max", "sum", "prod", "interpixel")
    else:
        modes = ("l1", "l2", "cosine")
    if mode not in modes:
        raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")

    if without_channels:
        features = grid_graph.imageToEdgeMap(image, mode)
        return grid_graph.uvIds(), features

    if offsets is None:
        features = grid_graph.imageWithChannelsToEdgeMap(image, mode)
        return grid_graph.uvIds(), features

    n_edges, edges, features = grid_graph.imageWithChannelsToEdgeMapWithOffsets(
        image, mode, offsets=offsets, strides=strides, randomize_strides=randomize_strides
    )
    # only the first n_edges entries of the returned arrays are kept
    return edges[:n_edges], features[:n_edges]
347
348
def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features from affinities for the given grid graph.

    Args:
        grid_graph: The grid graph.
        affinities: The affinity map. It must have one more dimension than the grid graph.
        offsets: The offsets, which correspond to the affinity channels.
            If none are given, the affinities for the nearest neighbor transitions are used.
            In this case the first axis of the affinities must have one entry per grid dimension.
        strides: The strides used to subsample edges that are computed from offsets.
        mask: Mask to exclude from the edge and feature computation.
            Can only be used together with offsets and not together with strides.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the affinities do not have one more dimension than the grid graph.
    """
    gndim = len(grid_graph.shape)
    if affinities.ndim != gndim + 1:
        raise ValueError(f"Invalid affinity dimension {affinities.ndim}, expected {gndim + 1}")

    if offsets is None:
        assert affinities.shape[0] == gndim, f"{affinities.shape[0]}, {gndim}"
        assert strides is None, "Strides can only be used together with offsets"
        assert mask is None, "A mask can only be used together with offsets"
        features = grid_graph.affinitiesToEdgeMap(affinities)
        edges = grid_graph.uvIds()
    elif mask is not None:
        assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithMask(affinities,
                                                                          offsets=offsets,
                                                                          mask=mask)
        # only the first n_edges entries of the returned arrays are kept
        edges, features = edges[:n_edges], features[:n_edges]
    else:
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithOffsets(affinities,
                                                                             offsets=offsets,
                                                                             strides=strides,
                                                                             randomize_strides=randomize_strides)
        # only the first n_edges entries of the returned arrays are kept
        edges, features = edges[:n_edges], features[:n_edges]

    return edges, features
396
397
def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Overwrite the weights of grid graph edges that touch masked nodes.

    Edges that connect two masked nodes are set to `masked_edge_weight`, edges that
    connect a masked with a non-masked node are set to `transition_edge_weight`.
    The `weights` array is modified in place and also returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        weights: The edge weights, one per edge of the grid graph.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    node_ids = grid_graph.projectNodeIdsToPixels()
    assert node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{node_ids.shape}, {mask.shape}, {grid_graph.shape}"

    uv_ids = grid_graph.uvIds()
    assert len(uv_ids) == len(weights)

    # number of masked end points per edge: 0, 1 or 2
    n_masked_nodes = np.isin(uv_ids, node_ids[~mask]).sum(axis=1)
    weights[n_masked_nodes == 2] = masked_edge_weight
    weights[n_masked_nodes == 1] = transition_edge_weight
    return weights
434
435
def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Drop edges between masked nodes and fix the weight of edges between masked and non-masked nodes.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The uv-ids of the remaining edges.
        The weights of the remaining edges.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    node_ids = grid_graph.projectNodeIdsToPixels()
    assert node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{node_ids.shape}, {mask.shape}, {grid_graph.shape}"

    # number of masked end points per edge: 0, 1 or 2
    n_masked_nodes = np.isin(edges, node_ids[~mask]).sum(axis=1)

    # edges with two masked end points are removed
    keep = n_masked_nodes != 2
    kept_edges = edges[keep]
    kept_weights = weights[keep]

    # edges with exactly one masked end point get the transition weight
    kept_weights[n_masked_nodes[keep] == 1] = transition_edge_weight
    return kept_edges, kept_weights
466
467
468#
469# Lifted Features
470#
471
def lifted_edges_from_graph_neighborhood(graph, max_graph_distance):
    """@private

    Insert lifted edges via breadth first search up to `max_graph_distance`
    and return the uv ids of the lifted edges.
    """
    if max_graph_distance < 2:
        raise ValueError(f"Graph distance must be greater equal 2, got {max_graph_distance}")

    # graphs of any other type are copied into a plain undirected graph first
    if not isinstance(graph, nifty.graph.UndirectedGraph):
        plain_graph = nifty.graph.undirectedGraph(graph.numberOfNodes)
        plain_graph.insertEdges(graph.uvIds())
        graph = plain_graph

    objective = nifty.graph.opt.lifted_multicut.liftedMulticutObjective(graph)
    objective.insertLiftedEdgesBfs(max_graph_distance)
    return objective.liftedUvIds()
486
487
def feats_to_costs_default(lifted_labels, lifted_features):
    """@private

    Default mapping from the two class probabilities of a lifted edge to its cost.
    """
    # This assumes that the two nodes of each lifted edge belong to different
    # classes (mode = "different"), so that all lifted edges should be repulsive.
    # `lifted_labels` is part of the callback signature, but not needed here.
    probs_u = lifted_features[:, 0]
    probs_v = lifted_features[:, 1]

    # the higher both class probabilities, the more repulsive the edge should be,
    # hence the product of the two is converted to costs
    return transform_probabilities_to_costs(probs_u * probs_v)
499
500
def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from probability maps by mapping them to superpixels.

    Example: compute a lifted problem from two attributions (axon, dendrite) that induce
    repulsive edges between different attributions. The construction of lifted edges and
    features can be customized using the `feats_to_costs` and `mode` arguments.
    ```
    lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
       rag, superpixels,
       input_maps=[
         axon_probabilities,  # probability map for axon attribution
         dendrite_probabilities  # probability map for dendrite attribution
       ],
       assignment_threshold=0.6,  # probability threshold to assign superpixels to a class
       graph_depth=10,  # the max. graph depth along which lifted edges are introduced
    )
    ```

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watersheds
            and each map is treated as the probability to correspond to a different class.
        assignment_threshold: Minimal expression level to assign a class to a graph node (= watershed segment).
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
            The input to the function are `lifted_labels`, which stores the two classes assigned to a lifted edge
            (as indices into `input_maps`), and `lifted_features`, which stores the two assignment probabilities.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).
    """
    assert ndist is not None, "Need nifty.distributed package"

    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    # validate inputs
    assert isinstance(input_maps, (list, tuple))
    assert all(isinstance(inp, np.ndarray) for inp in input_maps)
    shape = watershed.shape
    assert all(inp.shape == shape for inp in input_maps)

    # map the probability maps to superpixels - we only map to superpixels which
    # have a larger mean expression than `assignment_threshold`

    # TODO handle the dtype conversion for vigra gracefully somehow ...
    # think about supporting uint8 input and normalizing

    # TODO how do we handle cases where the same superpixel is mapped to
    # more than one class ?

    n_nodes = int(watershed.max()) + 1
    # label 0 means "no class assigned", it is the ignore label of the neighborhood search below
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")
    # TODO we could allow for more features that could then be used for the cost estimation
    for class_id, inp in enumerate(input_maps):
        mean_prob = vigra.analysis.extractRegionFeatures(inp, watershed, features=["mean"])["mean"]
        # we can in principle map multiple classes here, and right now will just override
        class_mask = mean_prob > assignment_threshold
        # classes are stored with an offset of one: storing the first class as 0 would
        # make it indistinguishable from unassigned nodes and exclude it from the lifted edges
        node_labels[class_mask] = class_id + 1
        node_features[class_mask] = mean_prob[class_mask]

    # find all lifted edges up to the graph depth between mapped nodes
    # NOTE we need to convert to the different graph type for now, but
    # it would be nice to support all nifty graphs at some type
    uv_ids = rag.uvIds()
    g_temp = ndist.Graph(uv_ids)

    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(g_temp, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)
    # undo the offset, so that `feats_to_costs` sees the class ids as indices into `input_maps`;
    # this relies on unassigned nodes (label 0) not being part of any lifted edge
    lifted_labels = node_labels[lifted_uvs] - np.uint64(1)
    lifted_features = node_features[lifted_uvs]

    lifted_costs = feats_to_costs(lifted_labels, lifted_features)
    return lifted_uvs, lifted_costs
590
591
592# TODO support setting costs proportional to overlaps
def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Each superpixel is attributed to the segment it overlaps with the most, provided
    that the normalized overlap reaches `overlap_threshold`. Segment id 0 is treated
    as "no attribution".

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included.
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).
    """
    assert ndist is not None, "Need nifty.distributed package"

    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    # compute the overlaps
    ovlp_comp = ngt.overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.numberOfNodes, "%i, %i" % (n_labels, rag.numberOfNodes)

    # initialise the arrays for node labels, to be
    # dense in the watershed id space (even if some ws-ids are not present)
    node_labels = np.zeros(n_labels, dtype="uint64")

    # attribute each watershed id to the segment with the largest overlap; the maximum is
    # selected explicitly, because the overlaps are requested without sorting
    for ws_id in ws_ids:
        overlap_result = ovlp_comp.overlapArraysNormalized(ws_id, sorted=False)
        segment_ids, overlap_fractions = overlap_result[0], overlap_result[1]
        best = int(np.argmax(overlap_fractions))
        if overlap_fractions[best] >= overlap_threshold:
            node_labels[ws_id] = segment_ids[best]

    # find all lifted edges up to the graph depth between mapped nodes
    # NOTE we need to convert to the different graph type for now, but
    # it would be nice to support all nifty graphs at some type
    uv_ids = rag.uvIds()
    g_temp = ndist.Graph(uv_ids)

    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(g_temp, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)
    # make sure that the lifted uv ids are in range of the node labels
    # (skipped if there are no lifted edges, because max is undefined for empty arrays)
    if len(lifted_uvs) > 0:
        assert lifted_uvs.max() < rag.numberOfNodes, "%i, %i" % (int(lifted_uvs.max()),
                                                                 rag.numberOfNodes)
    lifted_labels = node_labels[lifted_uvs]
    lifted_costs = np.zeros(len(lifted_labels), dtype="float64")

    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs[same_mask] = same_segment_cost
    lifted_costs[~same_mask] = different_segment_cost

    return lifted_uvs, lifted_costs
666
667
668#
669# Misc
670#
671
def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation.
        verbose: Whether to be verbose.

    Returns:
        The edge mask indicating edges between blocks. The mask is all False
        if no block has a neighboring block, e.g. if there is only a single block.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = nifty.tools.blocking([0] * ndim, seg.shape, block_shape)

    def find_stitch_edges(block_id):
        stitch_edges = []
        block = blocking.getBlock(block_id)
        for axis in range(ndim):
            # skip axes without a neighboring block
            # (presumably the flag selects the lower neighbor, matching the `beg - 1` face below)
            if blocking.getNeighborId(block_id, axis, True) == -1:
                continue
            # first plane of this block along the axis ...
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )
            # ... and the directly adjacent plane just outside of the block
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )

            labels_a = seg[face_a].ravel()
            labels_b = seg[face_b].ravel()

            uv_ids = np.concatenate(
                [labels_a[:, None], labels_b[:, None]],
                axis=1
            )
            uv_ids = np.unique(uv_ids, axis=0)

            # label pairs that are not an edge of the rag are reported as -1 and dropped
            edge_ids = rag.findEdges(uv_ids)
            edge_ids = edge_ids[edge_ids != -1]
            stitch_edges.append(edge_ids)

        if stitch_edges:
            stitch_edges = np.concatenate(stitch_edges)
            stitch_edges = np.unique(stitch_edges)
        else:
            stitch_edges = None
        return stitch_edges

    with futures.ThreadPoolExecutor(n_threads) as tp:
        block_results = tp.map(find_stitch_edges, range(blocking.numberOfBlocks))
        if verbose:
            block_results = tqdm(block_results, total=blocking.numberOfBlocks)
        # collect the results while the executor is still open
        per_block_edges = [res for res in block_results if res is not None]

    full_edges = np.zeros(rag.numberOfEdges, dtype="bool")
    # np.concatenate fails for an empty list, which happens if no block has a neighbor
    if per_block_edges:
        stitch_edges = np.unique(np.concatenate(per_block_edges))
        full_edges[stitch_edges] = True
    return full_edges
744
745
def project_node_labels_to_pixels(rag, node_labels: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Map per-node label values back onto the pixels to obtain a segmentation.

    Args:
        rag: The region adjacency graph.
        node_labels: The array with node labels, one entry per node of the rag.
        n_threads: The number of threads used. If None, the cpu count is used.

    Returns:
        The segmentation computed by `nifty.graph.rag.projectScalarNodeDataToPixels`.

    Raises:
        ValueError: If the number of node labels differs from the number of nodes.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    n_given, n_expected = len(node_labels), rag.numberOfNodes
    if n_given != n_expected:
        raise ValueError("Incompatible number of node labels: %i, %i" % (len(node_labels), rag.numberOfNodes))

    return nrag.projectScalarNodeDataToPixels(rag, node_labels, numberOfThreads=n_threads)
762
763
def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Compute the mask of edges that connect superpixels from different slices.

    This is meant for flat superpixels, i.e. volumetric superpixels that are
    independent across slices. Whether the watershed is actually flat is not checked;
    a node that occurs in several slices is attributed to the last one.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    # record the slice index for every node id that occurs in the slice
    z_of_node = np.zeros(rag.numberOfNodes, dtype="uint32")
    for z, plane in enumerate(watershed):
        z_of_node[plane] = z

    # an edge is a z-edge if its two nodes were recorded in different slices
    edges = rag.uvIds()
    z_u = z_of_node[edges[:, 0]]
    z_v = z_of_node[edges[:, 1]]
    return z_u != z_v
def compute_rag( segmentation: numpy.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
    """Build the region adjacency graph for a segmentation.

    Args:
        segmentation: The segmentation.
        n_labels: The number of labels in the segmentation. If None, it is derived
            from the data as the maximum label id plus one.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The region adjacency graph.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    if n_labels is None:
        # derive the label count from the largest id present in the data
        n_labels = int(segmentation.max()) + 1
    return nrag.gridRag(segmentation, numberOfLabels=n_labels, numberOfThreads=n_threads)

Compute region adjacency graph of segmentation.

Arguments:
  • segmentation: The segmentation.
  • n_labels: The number of labels in the segmentation. If None, will be computed from the data.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The region adjacency graph.

def compute_boundary_features( rag, boundary_map: numpy.ndarray, min_value: float = 0.0, max_value: float = 1.0, n_threads: Optional[int] = None) -> numpy.ndarray:
def compute_boundary_features(
    rag, boundary_map: np.ndarray, min_value: float = 0.0, max_value: float = 1.0, n_threads: Optional[int] = None
) -> np.ndarray:
    """Accumulate edge features over a boundary map.

    Args:
        rag: The region adjacency graph.
        boundary_map: The boundary map. Must have the same shape as the graph.
        min_value: The minimum value used in accumulation.
        max_value: The maximum value used in accumulation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the shape of the boundary map does not match the shape of the graph.
    """
    graph_shape = rag.shape
    if tuple(graph_shape) != boundary_map.shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(graph_shape), str(boundary_map.shape)))
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    return nrag.accumulateEdgeStandartFeatures(
        rag, boundary_map, min_value, max_value, numberOfThreads=n_threads
    )

Compute edge features from boundary map.

Arguments:
  • rag: The region adjacency graph.
  • boundary_map: The boundary map.
  • min_value: The minimum value used in accumulation.
  • max_value: The maximum value used in accumulation.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features.

def compute_affinity_features( rag, affinity_map: numpy.ndarray, offsets: List[List[int]], min_value: float = 0.0, max_value: float = 1.0, n_threads: Optional[int] = None) -> numpy.ndarray:
def compute_affinity_features(
    rag,
    affinity_map: np.ndarray,
    offsets: List[List[int]],
    min_value: float = 0.0,
    max_value: float = 1.0,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Accumulate edge features over an affinity map.

    Args:
        rag: The region adjacency graph.
        affinity_map: The affinity map, with the channels in the first axis.
            The remaining axes must match the shape of the graph.
        offsets: The offsets, one per channel of the affinity map.
        min_value: The minimum value used in accumulation.
        max_value: The maximum value used in accumulation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the spatial shape of the affinity map does not match the graph
            or if the number of offsets differs from the number of channels.
    """
    n_channels, spatial_shape = affinity_map.shape[0], affinity_map.shape[1:]
    graph_shape = rag.shape
    if tuple(graph_shape) != spatial_shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(graph_shape), str(spatial_shape)))
    if len(offsets) != n_channels:
        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets), n_channels))
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    return nrag.accumulateAffinityStandartFeatures(
        rag, affinity_map, offsets, min_value, max_value, numberOfThreads=n_threads
    )

Compute edge features from affinity map.

Arguments:
  • rag: The region adjacency graph.
  • affinity_map: The affinity map.
  • min_value: The minimum value used in accumulation.
  • max_value: The maximum value used in accumulation.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features.

def compute_boundary_mean_and_length( rag, input_: numpy.ndarray, n_threads: Optional[int] = None) -> numpy.ndarray:
def compute_boundary_mean_and_length(rag, input_: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Accumulate the mean value and the length of each boundary.

    Args:
        rag: The region adjacency graph.
        input_: The input map. Must have the same shape as the graph.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.

    Raises:
        ValueError: If the shape of the input map does not match the shape of the graph.
    """
    graph_shape = rag.shape
    if tuple(graph_shape) != input_.shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(graph_shape), str(input_.shape)))
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    return nrag.accumulateEdgeMeanAndLength(rag, input_, numberOfThreads=n_threads)

Compute mean value and length of boundaries.

Arguments:
  • rag: The region adjacency graph.
  • input_: The input map.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features.

def compute_boundary_features_with_filters( rag, input_: numpy.ndarray, apply_2d: bool = False, n_threads: Optional[int] = None, filters: Dict[str, List[float]] = {'gaussianSmoothing': [1.6, 4.2, 8.3], 'laplacianOfGaussian': [1.6, 4.2, 8.3], 'hessianOfGaussianEigenvalues': [1.6, 4.2, 8.3]}) -> numpy.ndarray:
def compute_boundary_features_with_filters(
    rag,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Optional[Dict[str, List[float]]] = None,
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    Args:
        rag: The region adjacency graph.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads, set to cpu count by default.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            If None, "gaussianSmoothing", "laplacianOfGaussian" and "hessianOfGaussianEigenvalues"
            are applied, each with the sigma values 1.6, 4.2 and 8.3.

    Returns:
        The edge features, concatenated along the feature axis over all filters, sigmas and response channels.

    Raises:
        ValueError: If `filters` does not contain any filter / sigma combination.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    if filters is None:
        # the default is created per call, so that it is never shared between calls
        filters = {
            "gaussianSmoothing": [1.6, 4.2, 8.3],
            "laplacianOfGaussian": [1.6, 4.2, 8.3],
            "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3],
        }

    jobs = [(filter_name, sigma) for filter_name, sigmas in filters.items() for sigma in sigmas]
    if not jobs:
        raise ValueError("Need at least one filter and sigma value to compute features")

    def _accumulate_channels(response, accumulation_threads):
        # Accumulate the edge features for each channel of a filter response,
        # using the min and max of the respective channel as accumulation range.
        channel_features = []
        for chan in range(response.shape[-1]):
            chan_data = response[..., chan]
            feats = compute_boundary_features(rag, chan_data,
                                              chan_data.min(), chan_data.max(),
                                              n_threads=accumulation_threads)
            channel_features.append(feats)
        channel_features = np.concatenate(channel_features, axis=1)
        assert len(channel_features) == rag.numberOfEdges, f"{len(channel_features)}, {rag.numberOfEdges}"
        return channel_features

    # apply 2d: the filters are processed one after the other and the threads
    # are used inside of the filter and feature computation
    if apply_2d:

        def _compute_2d(filter_name, sigma):
            response = _filter_2d(input_, filter_name, sigma, n_threads)
            assert response.ndim == 4
            return _accumulate_channels(response, n_threads)

        features = [_compute_2d(filter_name, sigma) for filter_name, sigma in jobs]

    # apply 3d: we parallelize over the whole filter + feature computation
    # this can be very memory intensive, and it would be better to parallelize inside
    # of the loop, but 3d parallel filters in elf.parallel.filters are not working properly yet
    else:

        def _compute_3d(filter_name, sigma):
            filter_fu = getattr(ff, filter_name)
            response = filter_fu(input_, sigma)
            # single channel responses get a trailing channel axis
            if response.ndim == 3:
                response = response[..., None]
            # each task runs single threaded, the parallelization happens over the tasks
            return _accumulate_channels(response, 1)

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, filter_name, sigma) for filter_name, sigma in jobs]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.numberOfEdges, f"{len(features)}, {rag.numberOfEdges}"
    return features

Compute boundary features accumulated over filter responses on input.

Arguments:
  • rag: The region adjacency graph.
  • input_: The input data.
  • apply_2d: Whether to apply the filters in 2d for 3d input data.
  • n_threads: The number of threads.
  • filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
Returns:

The edge features.

def compute_region_features( uv_ids: numpy.ndarray, input_map: numpy.ndarray, segmentation: numpy.ndarray, n_threads: Optional[int] = None) -> numpy.ndarray:
def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from an input map accumulated over segmentation and mapped to edges.

    Args:
        uv_ids: The edge uv ids.
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: The number of threads. Currently not used by this function.

    Returns:
        The edge features.
    """
    # compute the node features
    stat_feature_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                          "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_feature_names = ["Weighted<RegionCenter>", "RegionCenter"]
    node_features = vigra.analysis.extractRegionFeatures(
        input_map, segmentation, features=stat_feature_names + coord_feature_names
    )

    def _stack(names):
        # Gather the named node features as columns of a single 2d array,
        # 1d features are turned into a single column.
        columns = [node_features[name] for name in names]
        return np.concatenate([col[:, None] if col.ndim == 1 else col for col in columns], axis=1)

    stat_features = _stack(stat_feature_names)
    coord_features = _stack(coord_feature_names)

    u, v = uv_ids[:, 0], uv_ids[:, 1]

    # the image statistics based features are combined via [min, max, absdiff, sum]
    stats_u, stats_v = stat_features[u], stat_features[v]
    features = [np.minimum(stats_u, stats_v), np.maximum(stats_u, stats_v),
                np.abs(stats_u - stats_v), stats_u + stats_v]

    # the coordinate based features are combined via the squared difference per coordinate
    coords_u, coords_v = coord_features[u], coord_features[v]
    features.append((coords_u - coords_v) ** 2)

    # replace nan values in the features by finite numbers
    features = np.nan_to_num(np.concatenate(features, axis=1))
    assert len(features) == len(uv_ids)
    return features

Compute edge features from an input map accumulated over segmentation and mapped to edges.

Arguments:
  • uv_ids: The edge uv ids.
  • input_map: The input data.
  • segmentation: The segmentation.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features.

def compute_grid_graph(shape: Tuple[int, ...]):
def compute_grid_graph(shape: Tuple[int, ...]):
    """Create the undirected grid graph for data of the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The grid graph.
    """
    return nifty.graph.undirectedGridGraph(shape)

Compute grid graph for the given shape.

Arguments:
  • shape: The shape of the data.
Returns:

The grid graph.

def compute_grid_graph_image_features( grid_graph, image: numpy.ndarray, mode: str, offsets: Optional[List[List[int]]] = None, strides: Optional[List[int]] = None, randomize_strides: bool = False) -> Tuple[numpy.ndarray, numpy.ndarray]:
def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features for image for the given grid_graph.

    Args:
        grid_graph: The grid graph
        image: The image, from which the features will be derived. It must either have the same
            number of dimensions as the grid graph or one additional dimension for channels.
        mode: Feature accumulation method. One of "l1", "l2", "min", "max", "sum", "prod", "interpixel"
            for images without channels and one of "l1", "l2", "cosine" for images with channels.
        offsets: The offsets that define the edges. Only supported for images with channels.
            If none are given, the edges of the grid graph are used.
        strides: The strides used to subsample edges that are computed from offsets.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        NotImplementedError: If offsets are passed for an image without channels.
        ValueError: If the mode or the number of image dimensions is invalid.
    """
    gndim = len(grid_graph.shape)

    if image.ndim == gndim:
        if offsets is not None:
            raise NotImplementedError("Offsets are not supported for images without a channel axis")
        modes = ("l1", "l2", "min", "max", "sum", "prod", "interpixel")
        if mode not in modes:
            raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")
        features = grid_graph.imageToEdgeMap(image, mode)
        edges = grid_graph.uvIds()

    elif image.ndim == gndim + 1:
        modes = ("l1", "l2", "cosine")
        if mode not in modes:
            raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")

        if offsets is None:
            features = grid_graph.imageWithChannelsToEdgeMap(image, mode)
            edges = grid_graph.uvIds()
        else:
            n_edges, edges, features = grid_graph.imageWithChannelsToEdgeMapWithOffsets(
                image, mode, offsets=offsets, strides=strides, randomize_strides=randomize_strides
            )
            # only the first n_edges entries of the returned arrays are used
            edges, features = edges[:n_edges], features[:n_edges]

    else:
        msg = f"Invalid image dimension {image.ndim}, expect one of {gndim} or {gndim + 1}"
        raise ValueError(msg)

    return edges, features

Compute edge features for image for the given grid_graph.

Arguments:
  • grid_graph: The grid graph
  • image: The image, from which the features will be derived.
  • mode: Feature accumulation method.
  • offsets: The offsets, which correspond to the affinity channels. If none are given, the affinities for the nearest neighbor transitions are used.
  • strides: The strides used to subsample edges that are computed from offsets.
  • randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:

The uv ids of the edges. The edge features.

def compute_grid_graph_affinity_features( grid_graph, affinities: numpy.ndarray, offsets: Optional[List[List[int]]] = None, strides: Optional[List[int]] = None, mask: Optional[numpy.ndarray] = None, randomize_strides: bool = False) -> Tuple[numpy.ndarray, numpy.ndarray]:
def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute edge features from affinities for the given grid graph.

    Args:
        grid_graph: The grid graph
        affinities: The affinity map. It must have one more dimension than the grid graph.
        offsets: The offsets, which correspond to the affinity channels.
            If none are given, the affinities for the nearest neighbor transitions are used.
            In this case the number of channels must match the number of graph dimensions.
        strides: The strides used to subsample edges that are computed from offsets.
            Can only be used together with offsets and not together with a mask.
        mask: Mask to exclude from the edge and feature computation.
            Can only be used together with offsets.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the number of dimensions of the affinities does not fit the grid graph.
    """
    gndim = len(grid_graph.shape)
    if affinities.ndim != gndim + 1:
        raise ValueError(f"Invalid affinity dimension {affinities.ndim}, expect {gndim + 1}")

    if offsets is None:
        assert affinities.shape[0] == gndim, f"Expect {gndim} affinity channels, got {affinities.shape[0]}"
        assert strides is None, "Strides can only be used together with offsets"
        assert mask is None, "A mask can only be used together with offsets"
        features = grid_graph.affinitiesToEdgeMap(affinities)
        edges = grid_graph.uvIds()
        return edges, features

    if mask is not None:
        assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithMask(
            affinities, offsets=offsets, mask=mask
        )
    else:
        n_edges, edges, features = grid_graph.affinitiesToEdgeMapWithOffsets(
            affinities, offsets=offsets, strides=strides, randomize_strides=randomize_strides
        )

    # only the first n_edges entries of the returned arrays are used
    return edges[:n_edges], features[:n_edges]

Compute edge features from affinities for the given grid graph.

Arguments:
  • grid_graph: The grid graph
  • affinities: The affinity map.
  • offsets: The offsets, which correspond to the affinity channels. If none are given, the affinities for the nearest neighbor transitions are used.
  • strides: The strides used to subsample edges that are computed from offsets.
  • mask: Mask to exclude from the edge and feature computation.
  • randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:

The uv ids of the edges. The edge features.

def apply_mask_to_grid_graph_weights( grid_graph, mask: numpy.ndarray, weights: numpy.ndarray, masked_edge_weight: float = 0.0, transition_edge_weight: float = 1.0) -> numpy.ndarray:
def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Overwrite the weights of grid graph edges that touch masked nodes.

    Edges that connect two masked nodes and edges that connect a masked with a non-masked node
    are each set to a fixed value. The weights are modified in place and the same array is returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        weights: The edge weights, one per edge of the grid graph.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    node_ids = grid_graph.projectNodeIdsToPixels()
    assert node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{node_ids.shape}, {mask.shape}, {grid_graph.shape}"
    # the ids of all nodes that lie in the masked (= False) part of the mask
    masked_ids = node_ids[~mask]

    uv_ids = grid_graph.uvIds()
    assert len(uv_ids) == len(weights)

    # count for each edge how many of its end points are masked
    n_masked_endpoints = np.isin(uv_ids, masked_ids).sum(axis=1)
    weights[n_masked_endpoints == 2] = masked_edge_weight
    weights[n_masked_endpoints == 1] = transition_edge_weight
    return weights

Mask edges in grid graph.

Set the weights derived from a grid graph to a fixed value, for edges that connect masked nodes and edges that connect masked and unmasked nodes.

Arguments:
  • grid_graph: The grid graph.
  • mask: The binary mask, foreground (=non-masked) is True.
  • weights: The edge weights.
  • masked_edge_weight: The value for edges that connect two masked nodes.
  • transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:

The masked edge weights.

def apply_mask_to_grid_graph_edges_and_weights( grid_graph, mask: numpy.ndarray, edges: numpy.ndarray, weights: numpy.ndarray, transition_edge_weight: float = 1.0) -> Tuple[numpy.ndarray, numpy.ndarray]:
def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Drop edges between masked nodes and fix the weight of edges between masked and non-masked nodes.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The edge uv-ids, without the edges that connect two masked nodes.
        The edge weights for the remaining edges.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    node_ids = grid_graph.projectNodeIdsToPixels()
    assert node_ids.shape == mask.shape == tuple(grid_graph.shape), \
        f"{node_ids.shape}, {mask.shape}, {grid_graph.shape}"
    # the ids of all nodes that lie in the masked (= False) part of the mask
    masked_ids = node_ids[~mask]

    # count for each edge how many of its end points are masked
    n_masked_endpoints = np.isin(edges, masked_ids).sum(axis=1)

    # remove the edges for which both end points are masked
    keep = n_masked_endpoints != 2
    kept_edges, kept_weights = edges[keep], weights[keep]
    kept_weights[n_masked_endpoints[keep] == 1] = transition_edge_weight

    return kept_edges, kept_weights

Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.

Arguments:
  • grid_graph: The grid graph.
  • mask: The binary mask, foreground (=non-masked) is True.
  • edges: The edges (uv-ids).
  • weights: The edge weights.
  • transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:

The edge uv-ids. The edge weights.

def lifted_problem_from_probabilities( rag, watershed: numpy.ndarray, input_maps: List[numpy.ndarray], assignment_threshold: float, graph_depth: int, feats_to_costs: <built-in function callable> = <function feats_to_costs_default>, mode: str = 'different', n_threads: Optional[int] = None) -> Tuple[numpy.ndarray, numpy.ndarray]:
def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from probability maps by mapping them to superpixels.

    Example: compute a lifted problem from two attributions (axon, dendrite) that induce
    repulsive edges between different attributions. The construction of lifted edges and
    features can be customized using the `feats_to_costs` and `mode` arguments.
    ```
    lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
       rag, superpixels,
       input_maps=[
         axon_probabilities,  # probability map for axon attribution
         dendrite_probabilities  # probability map for dendrite attribution
       ],
       assignment_threshold=0.6,  # probability threshold to assign superpixels to a class
       graph_depth=10,  # the max. graph depth along which lifted edges are introduced
    )
    ```

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watersheds
            and each map is treated as the probability to correspond to a different class.
        assignment_threshold: Minimal expression level to assign a class to a graph node (= watershed segment).
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
            The input to the function are `lifted_labels`, which stores the two classes assigned to a lifted edge,
            and `lifted_features`, which stores the two assignment probabilities.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation, set to cpu count by default.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).
    """
    assert ndist is not None, "Need nifty.distributed package"

    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    # validate inputs: we need a list of arrays that all match the shape of the watershed
    assert isinstance(input_maps, (list, tuple))
    assert all(isinstance(prob_map, np.ndarray) for prob_map in input_maps)
    expected_shape = watershed.shape
    assert all(prob_map.shape == expected_shape for prob_map in input_maps)

    # Map the probability maps to superpixels: a superpixel is assigned to a class
    # if its mean probability for that class exceeds `assignment_threshold`.
    # TODO handle the dtype conversion for vigra gracefully, think about supporting
    # uint8 input and normalizing
    # TODO we could allow for more features that could then be used for the cost estimation
    n_nodes = int(watershed.max()) + 1
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")
    # NOTE(review): the class ids start at 0, which is also the initial value of `node_labels`
    # and the `ignoreLabel` passed below, so nodes assigned to the first input map have the
    # same label as unassigned nodes - confirm that this is intended.
    for class_id, prob_map in enumerate(input_maps):
        mean_prob = vigra.analysis.extractRegionFeatures(prob_map, watershed, features=["mean"])["mean"]
        # a superpixel that passes the threshold for several classes keeps the last one
        assigned = mean_prob > assignment_threshold
        node_labels[assigned] = class_id
        node_features[assigned] = mean_prob[assigned]

    # Find all lifted edges up to the graph depth between mapped nodes.
    # For this the graph is rebuilt from the edge list with the graph type of nifty.distributed.
    graph = ndist.Graph(rag.uvIds())
    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(graph, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)

    # look up the classes and assignment probabilities for both nodes of each lifted edge
    lifted_costs = feats_to_costs(node_labels[lifted_uvs], node_features[lifted_uvs])
    return lifted_uvs, lifted_costs

Compute lifted problem from probability maps by mapping them to superpixels.

Example: compute a lifted problem from two attributions (axon, dendrite) that induce repulsive edges between different attributions. The construction of lifted edges and features can be customized using the feats_to_costs and mode arguments.

lifted_uvs, lifted_costs = lifted_problem_from_probabilities(
   rag, superpixels,
   input_maps=[
     axon_probabilities,  # probability map for axon attribution
     dendrite_probabilities  # probability map for dendrite attribution
   ],
   assignment_threshold=0.6,  # probability threshold to assign superpixels to a class
   graph_depth=10,  # the max. graph depth along which lifted edges are introduced
)
Arguments:
  • rag: The region adjacency graph.
  • watershed: The watershed over-segmentation.
  • input_maps: List of probability maps. Each map must have the same shape as the watersheds and each map is treated as the probability to correspond to a different class.
  • assignment_threshold: Minimal expression level to assign a class to a graph node (= watershed segment).
  • graph_depth: Maximal graph depth up to which lifted edges will be included.
  • feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities. The input to the function are lifted_labels, which stores the two classes assigned to a lifted edge, and lifted_features, which stores the two assignment probabilities.
  • mode: The mode for insertion of lifted edges. One of: "all" - lifted edges will be inserted in between all nodes with attribution. "different" - lifted edges will only be inserted in between nodes attributed to different classes. "same" - lifted edges will only be inserted in between nodes attributed to the same class.
  • n_threads: The number of threads used for the calculation.
Returns:

The lifted uv ids (= superpixel ids connected by the lifted edge). The lifted costs (= cost associated with each lifted edge).

def lifted_problem_from_segmentation( rag, watershed: numpy.ndarray, input_segmentation: numpy.ndarray, overlap_threshold: float, graph_depth: int, same_segment_cost: float, different_segment_cost: float, mode: str = 'all', n_threads: Optional[int] = None) -> Tuple[numpy.ndarray, numpy.ndarray]:
def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Each watershed node is attributed to the segment of `input_segmentation` it overlaps with most,
    provided that the normalized overlap reaches `overlap_threshold`. Nodes without attribution get
    the label 0, which is passed as ignore label to the lifted neighborhood computation.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included.
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of:
            "all" - lifted edges will be inserted in between all nodes with attribution.
            "different" - lifted edges will only be inserted in between nodes attributed to different classes.
            "same" - lifted edges will only be inserted in between nodes attributed to the same class.
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids (= superpixel ids connected by the lifted edge).
        The lifted costs (= cost associated with each lifted edge).

    Raises:
        ImportError: If nifty.distributed could not be imported.
    """
    # nifty.distributed is an optional import of this module; fail with a clear
    # message instead of an AttributeError on None further down.
    if ndist is None:
        raise ImportError("lifted_problem_from_segmentation requires nifty.distributed, which could not be imported.")

    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    # compute the overlaps
    ovlp_comp = ngt.overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.numberOfNodes, "%i, %i" % (n_labels, rag.numberOfNodes)

    # initialise the arrays for node labels, to be
    # dense in the watershed id space (even if some ws-ids are not present)
    node_labels = np.zeros(n_labels, dtype="uint64")

    # extract the overlap values and node labels from the overlap computation results.
    # The overlap arrays are requested unsorted, so the first entry is not guaranteed to be
    # the dominant one: select the segment with the largest overlap explicitly.
    best_labels, best_overlaps = [], []
    for ws_id in ws_ids:
        ovlp = ovlp_comp.overlapArraysNormalized(ws_id, sorted=False)
        best = int(np.argmax(ovlp[1]))
        best_labels.append(ovlp[0][best])
        best_overlaps.append(ovlp[1][best])
    node_label_vals = np.array(best_labels)
    overlap_values = np.array(best_overlaps)

    # nodes with insufficient overlap are not attributed (label 0 = ignore label)
    node_label_vals[overlap_values < overlap_threshold] = 0
    assert len(node_label_vals) == len(ws_ids)
    node_labels[ws_ids] = node_label_vals

    # find all lifted edges up to the graph depth between mapped nodes
    # NOTE we need to convert to the different graph type for now, but
    # it would be nice to support all nifty graphs at some type
    uv_ids = rag.uvIds()
    g_temp = ndist.Graph(uv_ids)

    lifted_uvs = ndist.liftedNeighborhoodFromNodeLabels(g_temp, node_labels, graph_depth, mode=mode,
                                                        numberOfThreads=n_threads, ignoreLabel=0)

    # no lifted edges were found (e.g. no node reached the overlap threshold):
    # return empty costs instead of failing on the max of an empty array
    if lifted_uvs.size == 0:
        return lifted_uvs, np.zeros(0, dtype="float64")

    # make sure that the lifted uv ids are in range of the node labels
    assert lifted_uvs.max() < rag.numberOfNodes, "%i, %i" % (int(lifted_uvs.max()),
                                                             rag.numberOfNodes)

    # edges between nodes with the same attribution get the same_segment_cost, all others the different_segment_cost
    lifted_labels = node_labels[lifted_uvs]
    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs = np.where(same_mask, same_segment_cost, different_segment_cost).astype("float64")

    return lifted_uvs, lifted_costs

Compute lifted problem from segmentation by mapping segments to superpixels.

Arguments:
  • rag: The region adjacency graph.
  • watershed: The watershed over-segmentation.
  • input_segmentation: The segmentation used to determine node attribution.
  • overlap_threshold: The minimal overlap to assign a segment id to node.
  • graph_depth: The maximal graph depth up to which lifted edges will be included
  • same_segment_cost: The cost for edges between nodes with same segment id attribution.
  • different_segment_cost: The cost for edges between nodes with different segment id attribution.
  • mode: The mode for insertion of lifted edges. One of: "all" - lifted edges will be inserted in between all nodes with attribution. "different" - lifted edges will only be inserted in between nodes attributed to different classes. "same" - lifted edges will only be inserted in between nodes attributed to the same class.
  • n_threads: The number of threads used for the calculation.
Returns:

The lifted uv ids (= superpixel ids connected by the lifted edge). The lifted costs (= cost associated with each lifted edge).

def get_stitch_edges( rag, seg: numpy.ndarray, block_shape: Tuple[int, ...], n_threads: Optional[int] = None, verbose: bool = False) -> numpy.ndarray:
def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    For every block and every axis, the label pairs on the two sides of the block's
    lower face are looked up in the rag; the edges found this way are marked in the mask.
    If no such edge exists (e.g. the blocking consists of a single block),
    an all-False mask is returned.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation.
        verbose: Whether to be verbose.

    Returns:
        The edge mask indicating edges between blocks.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = nifty.tools.blocking([0] * ndim, seg.shape, block_shape)

    def find_stitch_edges(block_id):
        stitch_edges = []
        block = blocking.getBlock(block_id)
        for axis in range(ndim):
            # presumably -1 signals that there is no lower neighbor block along this axis
            if blocking.getNeighborId(block_id, axis, True) == -1:
                continue
            # face_a: first slice of this block along the axis
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )
            # face_b: the slice directly before it, i.e. on the other side of the block border
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )

            labels_a = seg[face_a].ravel()
            labels_b = seg[face_b].ravel()

            # unique label pairs that touch across the block border
            uv_ids = np.unique(np.column_stack([labels_a, labels_b]), axis=0)

            # discard pairs for which the lookup yields -1
            # (presumably pairs without an edge in the rag, e.g. identical labels on both sides)
            edge_ids = rag.findEdges(uv_ids)
            stitch_edges.append(edge_ids[edge_ids != -1])

        if not stitch_edges:
            return None
        return np.unique(np.concatenate(stitch_edges))

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(find_stitch_edges, range(n_blocks))
        if verbose:
            results = tqdm(results, total=n_blocks)
        # consume the results while the pool is still open, dropping blocks without lower neighbors
        block_edges = [edges for edges in results if edges is not None]

    full_edges = np.zeros(rag.numberOfEdges, dtype="bool")
    # np.concatenate fails for an empty list, which happens if no block has a lower neighbor
    if block_edges:
        full_edges[np.unique(np.concatenate(block_edges))] = True
    return full_edges

Get the edges between blocks.

Arguments:
  • rag: The region adjacency graph.
  • seg: The segmentation underlying the rag.
  • block_shape: The shape of the blocking.
  • n_threads: The number of threads used for the calculation.
  • verbose: Whether to be verbose.
Returns:

The edge mask indicating edges between blocks.

def project_node_labels_to_pixels( rag, node_labels: numpy.ndarray, n_threads: Optional[int] = None) -> numpy.ndarray:
def project_node_labels_to_pixels(rag, node_labels: np.ndarray, n_threads: Optional[int] = None) -> np.ndarray:
    """Map per-node label values back onto the pixels to obtain a segmentation.

    Args:
        rag: The region adjacency graph.
        node_labels: The array with node labels, one entry per node of the rag.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The segmentation.

    Raises:
        ValueError: If the number of node labels does not match the number of nodes in the rag.
    """
    n_given = len(node_labels)
    n_nodes = rag.numberOfNodes
    if n_given != n_nodes:
        raise ValueError("Incompatible number of node labels: %i, %i" % (n_given, n_nodes))

    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    return nrag.projectScalarNodeDataToPixels(rag, node_labels, numberOfThreads=n_threads)

Project label values for graph nodes back to pixels to obtain segmentation.

Arguments:
  • rag: The region adjacency graph.
  • node_labels: The array with node labels.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The segmentation.

def compute_z_edge_mask(rag, watershed: numpy.ndarray) -> numpy.ndarray:
def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Compute edge mask of in-between plane edges for flat superpixels.

    Flat superpixels are volumetric superpixels that are independent across slices.
    This function does not check whether the input watersheds are actually flat:
    each node is assigned the index of the last slice (along axis 0) it occurs in.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    # record for every node id the slice it lives in
    z_of_node = np.zeros(rag.numberOfNodes, dtype="uint32")
    for z, plane in enumerate(watershed):
        z_of_node[plane] = z

    # an edge runs in between slices if its two nodes have different slice indices
    edges = rag.uvIds()
    z_u = z_of_node[edges[:, 0]]
    z_v = z_of_node[edges[:, 1]]
    return z_u != z_v

Compute edge mask of in-between plane edges for flat superpixels.

Flat superpixels are volumetric superpixels that are independent across slices. This function does not check whether the input watersheds are actually flat.

Arguments:
  • rag: The region adjacency graph.
  • watershed: The underlying watershed over-segmentation (superpixels).
Returns:

The edge mask indicating in-between slice edges.