elf.segmentation.features

  1import multiprocessing
  2from concurrent import futures
  3from typing import Dict, List, Optional, Tuple
  4
  5import bioimage_cpp as bic
  6import numpy as np
  7from scipy.stats import kurtosis, skew
  8from skimage.measure import regionprops_table
  9
 10from tqdm import tqdm
 11from .multicut import transform_probabilities_to_costs
 12
 13
 14# Map fastfilters/vigra filter names to bic.filters callables.
 15_BIC_FILTERS = {
 16    "gaussianSmoothing": bic.filters.gaussian_smoothing,
 17    "gaussianGradientMagnitude": bic.filters.gaussian_gradient_magnitude,
 18    "laplacianOfGaussian": bic.filters.laplacian_of_gaussian,
 19    "hessianOfGaussianEigenvalues": bic.filters.hessian_of_gaussian_eigenvalues,
 20    "structureTensorEigenvalues": bic.filters.structure_tensor_eigenvalues,
 21    "gaussianDerivative": bic.filters.gaussian_derivative,
 22}
 23
 24
 25def _apply_filter(filter_name, image, sigma):
 26    """@private"""
 27    fu = _BIC_FILTERS[filter_name]
 28    if image.dtype not in (np.float32, np.float64, np.uint8, np.uint16):
 29        image = image.astype("float32")
 30    return fu(image, sigma)
 31
 32
 33#
 34# Region Adjacency Graph and Features
 35#
 36
 37def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
 38    """Compute region adjacency graph of segmentation.
 39
 40    Args:
 41        segmentation: The segmentation.
 42        n_labels: Deprecated; ignored. Kept for backwards-compatibility.
 43        n_threads: The number of threads used, set to cpu count by default.
 44
 45    Returns:
 46        The region adjacency graph (`bioimage_cpp.graph.RegionAdjacencyGraph`).
 47    """
 48    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
 49    if segmentation.dtype not in (np.uint32, np.uint64, np.int32, np.int64):
 50        segmentation = segmentation.astype("uint32")
 51    rag = bic.graph.region_adjacency_graph(segmentation, number_of_threads=n_threads)
 52    return rag
 53
 54
 55def compute_boundary_features(
 56    rag,
 57    segmentation: np.ndarray,
 58    boundary_map: np.ndarray,
 59    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
 60    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
 61    n_threads: Optional[int] = None,
 62) -> np.ndarray:
 63    """Compute edge features from boundary map.
 64
 65    Args:
 66        rag: The region adjacency graph.
 67        segmentation: The over-segmentation used to construct the RAG.
 68        boundary_map: The boundary map.
 69        min_value: Deprecated; ignored.
 70        max_value: Deprecated; ignored.
 71        n_threads: The number of threads used, set to cpu count by default.
 72
 73    Returns:
 74        The edge features. Output has 12 columns
 75        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).
 76    """
 77    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
 78    if segmentation.shape != boundary_map.shape:
 79        raise ValueError("Incompatible shapes: %s, %s" % (str(segmentation.shape), str(boundary_map.shape)))
 80    features = bic.graph.features.edge_map_features_complex(
 81        rag, segmentation, boundary_map, number_of_threads=n_threads,
 82    )
 83    return features
 84
 85
 86def compute_affinity_features(
 87    rag,
 88    segmentation: np.ndarray,
 89    affinity_map: np.ndarray,
 90    offsets: List[List[int]],
 91    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
 92    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
 93    n_threads: Optional[int] = None,
 94) -> np.ndarray:
 95    """Compute edge features from affinity map.
 96
 97    Args:
 98        rag: The region adjacency graph.
 99        segmentation: The over-segmentation used to construct the RAG.
100        affinity_map: The affinity map.
101        offsets: The offsets corresponding to the affinity channels.
102        min_value: Deprecated; ignored.
103        max_value: Deprecated; ignored.
104        n_threads: The number of threads used, set to cpu count by default.
105
106    Returns:
107        The edge features. Output has 12 columns
108        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).
109    """
110    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
111    if segmentation.shape != affinity_map.shape[1:]:
112        raise ValueError("Incompatible shapes: %s, %s" % (str(segmentation.shape), str(affinity_map.shape[1:])))
113    if len(offsets) != affinity_map.shape[0]:
114        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets),
115                                                                                  affinity_map.shape[0]))
116    features = bic.graph.features.affinity_features_complex(
117        rag, segmentation, affinity_map, offsets, number_of_threads=n_threads,
118    )
119    return features
120
121
def compute_boundary_mean_and_length(
    rag, segmentation: np.ndarray, input_: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Accumulate the mean input value and the boundary length for each edge.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input map. Must have the same shape as the segmentation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features with two columns (mean, size).

    Raises:
        ValueError: If the shapes of segmentation and input map differ.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    seg_shape, input_shape = segmentation.shape, input_.shape
    if seg_shape != input_shape:
        message = "Incompatible shapes: %s, %s" % (str(seg_shape), str(input_shape))
        raise ValueError(message)

    return bic.graph.features.edge_map_features(
        rag, segmentation, input_, number_of_threads=n_threads,
    )
143
144
145# TODO generalize and move to elf.features.parallel
def _filter_2d(input_, filter_name, sigma, n_threads):
    """Apply the filter to each slice along the first axis of `input_`, using a thread pool.

    Returns the per-slice responses concatenated along the first axis; responses without
    a channel axis get a trailing channel axis of size one.
    """

    def _filter_slice(plane):
        filtered = _apply_filter(filter_name, plane, sigma)
        if filtered.ndim not in (2, 3):
            raise RuntimeError("Invalid filter response")
        # scalar 2d responses get a channel-last axis
        if filtered.ndim == 2:
            filtered = filtered[..., None]
        # leading axis of size one, so that the slices can be concatenated
        return filtered[None]

    with futures.ThreadPoolExecutor(n_threads) as pool:
        pending = [pool.submit(_filter_slice, input_[z]) for z in range(input_.shape[0])]
        per_slice = [task.result() for task in pending]

    return np.concatenate(per_slice, axis=0)
164
165
def compute_boundary_features_with_filters(
    rag,
    segmentation: np.ndarray,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Optional[Dict[str, List[float]]] = None,
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads, set to cpu count by default.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            By default "gaussianSmoothing", "laplacianOfGaussian" and "hessianOfGaussianEigenvalues"
            are applied, each with the sigma values 1.6, 4.2 and 8.3.

    Returns:
        The edge features. The features of all filter responses (and of all their channels)
        are concatenated along the second axis.

    Raises:
        ValueError: If `filters` does not contain any (filter name, sigma) combination.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    # The default is created per call, so that no mutable object is shared via the signature.
    if filters is None:
        filters = {
            "gaussianSmoothing": [1.6, 4.2, 8.3],
            "laplacianOfGaussian": [1.6, 4.2, 8.3],
            "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3],
        }

    filter_settings = [(filter_name, sigma) for filter_name, sigmas in filters.items() for sigma in sigmas]
    if not filter_settings:
        # Without this check the final np.concatenate fails with a far less helpful message.
        raise ValueError("No filters given: expected at least one filter name with at least one sigma value.")

    def _accumulate(response, feature_threads):
        # Accumulate every channel of the (channel-last) response separately.
        feats = [
            compute_boundary_features(rag, segmentation, response[..., chan], n_threads=feature_threads)
            for chan in range(response.shape[-1])
        ]
        out = np.concatenate(feats, axis=1)
        assert len(out) == rag.number_of_edges, f"{len(out)}, {rag.number_of_edges}"
        return out

    # apply 2d: the filters are processed one after another,
    # the threads are used inside of the filter and feature computation
    if apply_2d:

        def _compute_2d(filter_name, sigma):
            response = _filter_2d(input_, filter_name, sigma, n_threads)
            assert response.ndim == 4
            return _accumulate(response, n_threads)

        features = [_compute_2d(filter_name, sigma) for filter_name, sigma in filter_settings]

    # apply 3d: we parallelize over the whole filter + feature computation
    # this can be very memory intensive, and it would be better to parallelize inside
    # of the loop, but 3d parallel filters in elf.parallel.filters are not working properly yet
    else:

        def _compute_3d(filter_name, sigma):
            response = _apply_filter(filter_name, input_, sigma)
            # scalar responses get a channel-last axis
            if response.ndim == input_.ndim:
                response = response[..., None]
            return _accumulate(response, 1)

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, filter_name, sigma) for filter_name, sigma in filter_settings]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.number_of_edges
    return features
239
240
241# Intensity statistics that skimage.measure.regionprops does not provide natively.
242# Each callback receives the region's cropped (regionmask, intensity_image); see
243# `_region_features`. The function names double as the keys in the regionprops table.
244def _quantiles(regionmask, intensity_image):
245    """@private"""
246    return np.percentile(intensity_image[regionmask], [0, 10, 25, 50, 75, 90, 100])
247
248
249def _kurtosis(regionmask, intensity_image):
250    """@private"""
251    values = intensity_image[regionmask]
252    if values.size < 2 or values.min() == values.max():
253        return 0.0
254    return kurtosis(values)
255
256
257def _skewness(regionmask, intensity_image):
258    """@private"""
259    values = intensity_image[regionmask]
260    if values.size < 2 or values.min() == values.max():
261        return 0.0
262    return skew(values)
263
264
265def _variance(regionmask, intensity_image):
266    """@private"""
267    return np.var(intensity_image[regionmask])
268
269
270def _sum(regionmask, intensity_image):
271    """@private"""
272    return intensity_image[regionmask].sum()
273
274
275# Map vigra `extractRegionFeatures` names to their source in a skimage regionprops table.
276# Names starting with "_" are computed via the extra-property callbacks above; the rest are
277# native regionprops properties (array-valued ones are expanded into "<name>-<i>" columns).
278_REGION_FEATURE_KEYS = {
279    "Count": "num_pixels",
280    "Maximum": "intensity_max",
281    "Minimum": "intensity_min",
282    "mean": "intensity_mean",
283    "RegionCenter": "centroid",
284    "Weighted<RegionCenter>": "centroid_weighted",
285    "RegionRadii": "inertia_tensor_eigvals",
286    "Quantiles": "_quantiles",
287    "Kurtosis": "_kurtosis",
288    "Skewness": "_skewness",
289    "Variance": "_variance",
290    "Sum": "_sum",
291}
# Extra-property callbacks, keyed by their own function name (which is also the "_"-prefixed table key).
_REGION_FEATURE_EXTRA = {
    callback.__name__: callback
    for callback in (_quantiles, _kurtosis, _skewness, _variance, _sum)
}
299
300
def _region_features(input_map: np.ndarray, segmentation: np.ndarray, feature_names: List[str]) -> Dict:
    """@private

    Stand-in for ``vigra.analysis.extractRegionFeatures`` that is built on
    ``skimage.measure.regionprops_table``. The result maps every requested feature name to a
    dense array that is indexed by label id (``0 .. segmentation.max()``). Features with a
    single column are returned as 1D arrays, features with several columns
    (coordinates, quantiles, radii) as 2D arrays. Rows of label ids that do not occur
    in the segmentation stay zero.
    """
    # non-integer segmentations are cast, so that the label shift below is well defined
    label_image = segmentation if segmentation.dtype.kind in "iu" else segmentation.astype("int64")
    table_keys = [_REGION_FEATURE_KEYS[name] for name in feature_names]

    # split the requested keys into native properties and callbacks; drop duplicates, keep the order
    native_props, callbacks = [], []
    for key in table_keys:
        if key.startswith("_"):
            callback = _REGION_FEATURE_EXTRA[key]
            if callback not in callbacks:
                callbacks.append(callback)
        elif key not in native_props:
            native_props.append(key)

    # skimage treats label 0 as background; shift by 1 so the original label 0 is included.
    props = regionprops_table(
        label_image + 1,
        intensity_image=input_map.astype("float32", copy=False),
        properties=("label", *native_props),
        extra_properties=(tuple(callbacks) or None),
    )
    label_ids = np.asarray(props["label"]) - 1
    n_nodes = int(label_image.max()) + 1

    def _columns(base):
        # single-column property: stored directly under its name
        if base in props:
            return np.asarray(props[base], dtype="float32")[:, None]
        # multi-column property: collect "<base>-0", "<base>-1", ... until the first missing key
        collected = []
        index = 0
        while True:
            column_key = f"{base}-{index}"
            if column_key not in props:
                break
            collected.append(np.asarray(props[column_key], dtype="float32"))
            index += 1
        return np.stack(collected, axis=1)

    result = {}
    for name, base in zip(feature_names, table_keys):
        values = _columns(base)
        if name == "RegionRadii":  # vigra returns radii = sqrt of the coordinate-covariance eigenvalues
            values = np.sqrt(np.maximum(values, 0.0))
        dense = np.zeros((n_nodes, values.shape[1]), dtype="float32")
        dense[label_ids] = values
        if dense.shape[1] == 1:
            result[name] = dense[:, 0]
        else:
            result[name] = dense
    return result
342
343
def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from region statistics of the nodes that an edge connects.

    The statistics of the two nodes of an edge are combined via minimum, maximum,
    absolute difference and sum (in this column order). The region centers are combined
    via the squared difference per coordinate. NaN values are replaced via `np.nan_to_num`.

    Args:
        uv_ids: The edge uv ids.
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features.
    """
    # NOTE: n_threads is resolved here, but not used by the rest of this function.
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    # the node features are computed in a single pass for both feature groups
    stat_feature_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                          "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_feature_names = ["Weighted<RegionCenter>", "RegionCenter"]
    node_features = _region_features(input_map, segmentation, stat_feature_names + coord_feature_names)

    def _as_matrix(names):
        # stack the features column-wise, 1d features become a single column
        columns = []
        for name in names:
            feat = node_features[name]
            columns.append(feat if feat.ndim != 1 else feat[:, None])
        return np.concatenate(columns, axis=1)

    stat_features = _as_matrix(stat_feature_names)
    coord_features = _as_matrix(coord_feature_names)

    node_u = uv_ids[:, 0]
    node_v = uv_ids[:, 1]

    # image statistics: [min, max, absdiff, sum] of the two node values
    stats_u, stats_v = stat_features[node_u], stat_features[node_v]
    edge_features = [
        np.minimum(stats_u, stats_v),
        np.maximum(stats_u, stats_v),
        np.abs(stats_u - stats_v),
        stats_u + stats_v,
    ]

    # coordinates: squared difference per coordinate
    coords_u, coords_v = coord_features[node_u], coord_features[node_v]
    edge_features.append((coords_u - coords_v) ** 2)

    features = np.nan_to_num(np.concatenate(edge_features, axis=1))
    assert len(features) == len(uv_ids)
    return features
394
395
396#
397# Grid Graph and Features
398#
399
def compute_grid_graph(shape: Tuple[int, ...]):
    """Create the grid graph for data of the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The grid graph.
    """
    grid_graph = bic.graph.grid_graph(shape)
    return grid_graph
410
411
412def _nn_offsets(ndim):
413    return [[-1 if i == d else 0 for i in range(ndim)] for d in range(ndim)]
414
415
416def _apply_strides(edges, weights, strides, randomize_strides):
417    """Subsample (edges, weights) along the spatial periodicity defined by `strides`.
418
419    Mirrors the behaviour of nifty's strides/randomize_strides parameter without
420    spatial information: we simply keep one out of every `prod(strides)` entries
421    (or a random subset of the same size if `randomize_strides` is True).
422    """
423    if strides is None:
424        return edges, weights
425    keep = int(np.prod(strides))
426    if keep <= 1:
427        return edges, weights
428    n = len(edges)
429    if randomize_strides:
430        idx = np.random.choice(n, size=max(1, n // keep), replace=False)
431        idx.sort()
432    else:
433        idx = np.arange(0, n, keep)
434    return edges[idx], weights[idx]
435
436
def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive edge features for the given grid graph from an image.

    Args:
        grid_graph: The grid graph.
        image: The image, from which the features will be derived. Either a scalar image
            with the shape of the grid graph or an image with one additional (leading) axis.
        mode: Feature accumulation method. For multi-channel images, one of
            "l1", "l2", "cosine". For scalar images (without channels) only
            grid-boundary averaging is supported (any mode value is accepted).
        offsets: The offsets, which correspond to the affinity channels.
            Not supported for scalar images.
        strides: The strides used to subsample edges that are computed from offsets.
            Only applied when offsets are given.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        NotImplementedError: If offsets are passed together with a scalar image.
        ValueError: If the image dimension does not fit the grid graph or the mode is invalid.
    """
    gndim = len(grid_graph.shape)
    is_scalar_image = image.ndim == gndim

    # scalar image: features come from the grid boundaries, offsets are not available
    if is_scalar_image and offsets is not None:
        raise NotImplementedError("Offsets with scalar images are not supported.")
    if is_scalar_image:
        boundary_weights = bic.graph.features.grid_boundary_features(grid_graph, image.astype("float32"))
        return grid_graph.uv_ids(), boundary_weights

    if image.ndim != gndim + 1:
        raise ValueError(f"Invalid image dimension {image.ndim}, expected {gndim} or {gndim + 1}")

    modes = ("l1", "l2", "cosine")
    if mode not in modes:
        raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")

    # without explicit offsets the nearest-neighbor offsets are used
    use_offsets = _nn_offsets(gndim) if offsets is None else offsets
    distances = bic.affinities.compute_embedding_distances(
        image.astype("float32"), use_offsets, norm=mode,
    )

    if offsets is None:
        nn_weights, _valid = bic.graph.features.grid_affinity_features(grid_graph, distances, use_offsets)
        return grid_graph.uv_ids(), nn_weights

    # arbitrary offsets: combine the valid local edges with the lifted edges, then subsample
    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, distances, use_offsets,
    )
    all_edges = np.concatenate([grid_graph.uv_ids()[local_valid], lifted_uvs], axis=0)
    all_weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)
    return _apply_strides(all_edges, all_weights, strides, randomize_strides)
497
498
def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive edge features for the given grid graph from affinities.

    Args:
        grid_graph: The grid graph.
        affinities: The affinity map with shape (channels, *grid_graph.shape).
        offsets: The offsets, which correspond to the affinity channels. If None, the affinities
            must have one channel per axis, and neither strides nor mask may be given.
        strides: The strides used to subsample edges that are computed from offsets.
        mask: Mask to exclude from the edge and feature computation. Edges for which both
            nodes are outside of the mask are removed. Cannot be combined with strides.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the affinities do not have one more dimension than the grid graph.
    """
    gndim = len(grid_graph.shape)
    if affinities.ndim != gndim + 1:
        raise ValueError("affinities must have shape (channels, *grid_graph.shape)")

    # nearest neighbor case: one affinity channel per axis
    if offsets is None:
        assert affinities.shape[0] == gndim
        assert strides is None
        assert mask is None
        nn_weights, _valid = bic.graph.features.grid_affinity_features(grid_graph, affinities, _nn_offsets(gndim))
        return grid_graph.uv_ids(), nn_weights

    # general case: valid local edges followed by the lifted edges
    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, affinities, offsets,
    )
    edges = np.concatenate([grid_graph.uv_ids()[local_valid], lifted_uvs], axis=0)
    weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)

    if mask is None:
        return _apply_strides(edges, weights, strides, randomize_strides)

    assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
    shape = tuple(grid_graph.shape)
    assert mask.shape == shape, (
        "compute_grid_graph_affinity_features with a per-pixel mask expects mask.shape == grid_graph.shape; "
        "per-channel edge masks are only supported on legacy nifty grid graphs."
    )
    # drop all edges for which both nodes are masked out
    all_node_ids = np.arange(np.prod(shape), dtype="uint64").reshape(shape)
    masked_ids = all_node_ids[~mask]
    n_masked_nodes = np.isin(edges, masked_ids).sum(axis=1)
    not_fully_masked = n_masked_nodes != 2
    return edges[not_fully_masked], weights[not_fully_masked]
555
556
def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Overwrite the weights of grid graph edges that touch masked nodes.

    Edges that connect two masked nodes and edges that connect a masked with
    a non-masked node are set to fixed values. The weights are modified in-place.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        weights: The edge weights, one value per edge of the grid graph.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights (the same array that was passed as `weights`).
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    grid_shape = tuple(grid_graph.shape)
    assert mask.shape == grid_shape, f"{mask.shape}, {grid_shape}"

    # ids of all nodes that are outside of the mask
    all_node_ids = np.arange(np.prod(grid_shape), dtype="uint64").reshape(grid_shape)
    masked_ids = all_node_ids[~mask]

    uv_ids = grid_graph.uv_ids()
    assert len(uv_ids) == len(weights)

    # number of masked nodes per edge: 2 = fully masked, 1 = transition, 0 = unaffected
    n_masked_nodes = np.isin(uv_ids, masked_ids).sum(axis=1)
    weights[n_masked_nodes == 2] = masked_edge_weight
    weights[n_masked_nodes == 1] = transition_edge_weight
    return weights
593
594
def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Drop edges between masked nodes and fix the weight of edges between masked and non-masked nodes.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The edge uv-ids.
        The edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    grid_shape = tuple(grid_graph.shape)
    assert mask.shape == grid_shape, f"{mask.shape}, {grid_shape}"

    # ids of all nodes that are outside of the mask
    all_node_ids = np.arange(np.prod(grid_shape), dtype="uint64").reshape(grid_shape)
    masked_ids = all_node_ids[~mask]

    # number of masked nodes per edge: 2 = fully masked, 1 = transition, 0 = unaffected
    n_masked_nodes = np.isin(edges, masked_ids).sum(axis=1)
    retained = n_masked_nodes != 2

    # boolean indexing creates new arrays, so the input weights are left untouched
    retained_edges = edges[retained]
    retained_weights = weights[retained]
    retained_weights[n_masked_nodes[retained] == 1] = transition_edge_weight
    return retained_edges, retained_weights
625
626
627#
628# Lifted Features
629#
630
def lifted_edges_from_graph_neighborhood(graph, max_graph_distance):
    """@private

    Return the lifted edges of `graph` up to `max_graph_distance`.
    Raises a ValueError if `max_graph_distance` is smaller than 2.
    """
    if max_graph_distance < 2:
        raise ValueError(f"Graph distance must be greater equal 2, got {max_graph_distance}")
    # All nodes get the same label (zero). Together with mode='all' this yields every node pair
    # within the BFS hop window [2, max_graph_distance] (base-graph edges excluded).
    uniform_labels = np.zeros(graph.number_of_nodes, dtype="uint64")
    return bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        graph, uniform_labels, graph_depth=max_graph_distance, mode="all",
    )
643
644
def feats_to_costs_default(lifted_labels, lifted_features):
    """@private

    Default cost function for lifted edges: the product of the two class probabilities
    of an edge is transformed to a cost. `lifted_labels` is not used.
    """
    # We assume that the two nodes of a lifted edge always belong to different classes
    # (mode = "different"), so all edges are set to be repulsive.
    # The higher the class probabilities, the more repulsive the edge should be,
    # hence the two probabilities are multiplied.
    prob_u = lifted_features[:, 0]
    prob_v = lifted_features[:, 1]
    return transform_probabilities_to_costs(prob_u * prob_v)
656
657
def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Set up a lifted problem by mapping probability maps to the superpixels.

    For every probability map the mean value per superpixel is computed. Superpixels with
    a mean above the threshold are assigned to the class of that map; later maps
    override the assignment of earlier maps.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watersheds.
        assignment_threshold: Minimal expression level to assign a class to a graph node.
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
            It is called with the node labels and the node features of the lifted edges.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids.
        The lifted costs.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()
    assert isinstance(input_maps, (list, tuple))
    assert all(isinstance(inp, np.ndarray) for inp in input_maps)
    expected_shape = watershed.shape
    assert all(inp.shape == expected_shape for inp in input_maps)

    n_nodes = int(watershed.max()) + 1
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")
    # NOTE(review): class ids start at 0, which is also the value passed as ignore_label below;
    # nodes assigned to the first map are presumably treated as unassigned - confirm that this is intended.
    for class_id, prob_map in enumerate(input_maps):
        node_means = _region_features(prob_map, watershed, ["mean"])["mean"]
        assigned = node_means > assignment_threshold
        node_labels[assigned] = class_id
        node_features[assigned] = node_means[assigned]

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag, node_labels, graph_depth=graph_depth, mode=mode,
        ignore_label=0, number_of_threads=n_threads,
    )

    # look up labels and features for both nodes of each lifted edge
    labels_per_edge = node_labels[lifted_uvs]
    features_per_edge = node_features[lifted_uvs]
    return lifted_uvs, feats_to_costs(labels_per_edge, features_per_edge)
708
709
def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Each superpixel is assigned the id of the segment with the best overlap, or 0 if the
    overlap fraction is below the threshold. Lifted edges are then derived from these
    node labels (0 is passed as the ignore label).

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
            Must have the same shape as the watershed.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included.
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation.

    Returns:
        The lifted uv ids.
        The lifted costs.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    ovlp = bic.utils.segmentation_overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    # np.unique returns sorted ids, so the last entry is the maximal superpixel id
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.number_of_nodes, "%i, %i" % (n_labels, rag.number_of_nodes)

    # find the segment with the best overlap for each superpixel
    node_labels = np.zeros(n_labels, dtype="uint64")
    node_label_vals = np.zeros(len(ws_ids), dtype="uint64")
    overlap_values = np.zeros(len(ws_ids), dtype="float64")
    for i, ws_id in enumerate(ws_ids):
        best = ovlp.best_overlap_for_label_a(int(ws_id), ignore_zero=False)
        node_label_vals[i] = best.label
        overlap_values[i] = best.fraction
    # superpixels without sufficient overlap are set to 0
    node_label_vals[overlap_values < overlap_threshold] = 0
    node_labels[ws_ids] = node_label_vals

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag, node_labels, graph_depth=graph_depth, mode=mode,
        ignore_label=0, number_of_threads=n_threads,
    )
    # The sanity check is skipped if no lifted edges were found,
    # because the max of a zero-size array raises a ValueError.
    if len(lifted_uvs) > 0:
        assert lifted_uvs.max() < rag.number_of_nodes, "%i, %i" % (int(lifted_uvs.max()), rag.number_of_nodes)
    lifted_labels = node_labels[lifted_uvs]
    lifted_costs = np.zeros(len(lifted_labels), dtype="float64")

    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs[same_mask] = same_segment_cost
    lifted_costs[~same_mask] = different_segment_cost

    return lifted_uvs, lifted_costs
769
770
771#
772# Misc
773#
774
def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    The segmentation is tiled with blocks of shape `block_shape`. For every block and every axis for which
    `blocking.get_neighbor_id(block_id, axis, True)` reports a neighbor, the labels on the first slice
    of the block are paired with the labels on the slice directly before it, and the resulting
    label pairs are looked up in the rag.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation, set to cpu count by default.
        verbose: Whether to be verbose.

    Returns:
        The edge mask indicating edges between blocks. The mask is all False if no edge between
        blocks is found, e.g. because the blocking consists of a single block.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = bic.utils.Blocking([0] * ndim, list(seg.shape), list(block_shape))

    def find_stitch_edges(block_id):
        block = blocking.get_block(block_id)
        block_edges = []
        for axis in range(ndim):
            # -1 signals that the block has no neighbor along this axis.
            if blocking.get_neighbor_id(block_id, axis, True) == -1:
                continue
            # The first slice of this block along the axis ...
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, (beg, end) in enumerate(zip(block.begin, block.end))
            )
            # ... and the slice directly before the block begin.
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, (beg, end) in enumerate(zip(block.begin, block.end))
            )

            # Pair up the labels that face each other across the block boundary.
            uv_ids = np.stack([seg[face_a].ravel(), seg[face_b].ravel()], axis=1)
            uv_ids = np.unique(uv_ids, axis=0)

            # Label pairs without an edge in the rag are reported as -1 and are dropped.
            edge_ids = rag.find_edges(uv_ids)
            block_edges.append(edge_ids[edge_ids != -1])

        return np.unique(np.concatenate(block_edges)) if block_edges else None

    n_blocks = blocking.number_of_blocks
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(find_stitch_edges, range(n_blocks))
        if verbose:
            results = tqdm(results, total=n_blocks)
        # Collect the results while the executor is still alive; blocks without neighbors yield None.
        results = [res for res in results if res is not None]

    full_edges = np.zeros(rag.number_of_edges, dtype="bool")
    # np.concatenate raises for an empty list, so only mark edges if any block produced a result.
    if results:
        full_edges[np.unique(np.concatenate(results))] = True
    return full_edges
847
848
def project_node_labels_to_pixels(
    rag, segmentation: np.ndarray, node_labels: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Map the labels of the graph nodes onto the pixels of the over-segmentation.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        node_labels: The array with node labels, one entry per node of the RAG.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The segmentation.

    Raises:
        ValueError: If the number of node labels differs from the number of nodes in the RAG.
    """
    n_node_labels, n_nodes = len(node_labels), rag.number_of_nodes
    if n_node_labels != n_nodes:
        raise ValueError("Incompatible number of node labels: %i, %i" % (n_node_labels, n_nodes))

    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    # bic.graph.project_node_labels_to_pixels requires integer dtypes for both arrays.
    integer_dtypes = (np.uint32, np.uint64, np.int32, np.int64)
    if segmentation.dtype not in integer_dtypes:
        segmentation = segmentation.astype("uint64")
    if node_labels.dtype not in integer_dtypes:
        node_labels = node_labels.astype("uint64")

    return bic.graph.project_node_labels_to_pixels(rag, segmentation, node_labels, number_of_threads=n_threads)
873
874
def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Mark the edges that connect superpixels from different slices, for flat superpixels.

    Every node is assigned the index of the last slice along axis 0 that contains it.
    An edge is marked if its two nodes are assigned different slice indices.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    n_slices = watershed.shape[0]
    slice_of_node = np.zeros(rag.number_of_nodes, dtype="uint32")
    for slice_id in range(n_slices):
        # The labels in the slice index into the node array.
        slice_of_node[watershed[slice_id]] = slice_id

    edges = rag.uv_ids()
    z_u = slice_of_node[edges[:, 0]]
    z_v = slice_of_node[edges[:, 1]]
    return z_u != z_v
def compute_rag( segmentation: numpy.ndarray, n_labels: int | None = None, n_threads: int | None = None):
def compute_rag(segmentation: np.ndarray, n_labels: Optional[int] = None, n_threads: Optional[int] = None):
    """Build the region adjacency graph of a segmentation.

    Segmentations with a dtype other than uint32, uint64, int32 or int64 are cast to uint32 first.

    Args:
        segmentation: The segmentation.
        n_labels: Deprecated; ignored. Kept for backwards-compatibility.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The region adjacency graph (`bioimage_cpp.graph.RegionAdjacencyGraph`).
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    integer_dtypes = (np.uint32, np.uint64, np.int32, np.int64)
    labels = segmentation if segmentation.dtype in integer_dtypes else segmentation.astype("uint32")
    return bic.graph.region_adjacency_graph(labels, number_of_threads=n_threads)

Compute region adjacency graph of segmentation.

Arguments:
  • segmentation: The segmentation.
  • n_labels: Deprecated; ignored. Kept for backwards-compatibility.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The region adjacency graph (bioimage_cpp.graph.RegionAdjacencyGraph).

def compute_boundary_features( rag, segmentation: numpy.ndarray, boundary_map: numpy.ndarray, min_value: float = 0.0, max_value: float = 1.0, n_threads: int | None = None) -> numpy.ndarray:
def compute_boundary_features(
    rag,
    segmentation: np.ndarray,
    boundary_map: np.ndarray,
    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
    n_threads: Optional[int] = None,
) -> np.ndarray:
    """Accumulate a boundary map over the edges of the RAG.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        boundary_map: The boundary map. Must have the same shape as the segmentation.
        min_value: Deprecated; ignored.
        max_value: Deprecated; ignored.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features. Output has 12 columns
        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).

    Raises:
        ValueError: If the shapes of segmentation and boundary map differ.
    """
    seg_shape, map_shape = segmentation.shape, boundary_map.shape
    if seg_shape != map_shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(seg_shape), str(map_shape)))

    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    return bic.graph.features.edge_map_features_complex(
        rag, segmentation, boundary_map, number_of_threads=n_threads,
    )

Compute edge features from boundary map.

Arguments:
  • rag: The region adjacency graph.
  • segmentation: The over-segmentation used to construct the RAG.
  • boundary_map: The boundary map.
  • min_value: Deprecated; ignored.
  • max_value: Deprecated; ignored.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features. Output has 12 columns (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).

def compute_affinity_features( rag, segmentation: numpy.ndarray, affinity_map: numpy.ndarray, offsets: List[List[int]], min_value: float = 0.0, max_value: float = 1.0, n_threads: int | None = None) -> numpy.ndarray:
 87def compute_affinity_features(
 88    rag,
 89    segmentation: np.ndarray,
 90    affinity_map: np.ndarray,
 91    offsets: List[List[int]],
 92    min_value: float = 0.0,  # noqa: ARG001 — deprecated, ignored
 93    max_value: float = 1.0,  # noqa: ARG001 — deprecated, ignored
 94    n_threads: Optional[int] = None,
 95) -> np.ndarray:
 96    """Compute edge features from affinity map.
 97
 98    Args:
 99        rag: The region adjacency graph.
100        segmentation: The over-segmentation used to construct the RAG.
101        affinity_map: The affinity map.
102        offsets: The offsets corresponding to the affinity channels.
103        min_value: Deprecated; ignored.
104        max_value: Deprecated; ignored.
105        n_threads: The number of threads used, set to cpu count by default.
106
107    Returns:
108        The edge features. Output has 12 columns
109        (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).
110    """
111    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
112    if segmentation.shape != affinity_map.shape[1:]:
113        raise ValueError("Incompatible shapes: %s, %s" % (str(segmentation.shape), str(affinity_map.shape[1:])))
114    if len(offsets) != affinity_map.shape[0]:
115        raise ValueError("Incompatible number of channels and offsets: %i, %i" % (len(offsets),
116                                                                                  affinity_map.shape[0]))
117    features = bic.graph.features.affinity_features_complex(
118        rag, segmentation, affinity_map, offsets, number_of_threads=n_threads,
119    )
120    return features

Compute edge features from affinity map.

Arguments:
  • rag: The region adjacency graph.
  • segmentation: The over-segmentation used to construct the RAG.
  • affinity_map: The affinity map.
  • offsets: The offsets corresponding to the affinity channels.
  • min_value: Deprecated; ignored.
  • max_value: Deprecated; ignored.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features. Output has 12 columns (mean, median, std, min, max, p5, p10, p25, p75, p90, p95, size).

def compute_boundary_mean_and_length( rag, segmentation: numpy.ndarray, input_: numpy.ndarray, n_threads: int | None = None) -> numpy.ndarray:
def compute_boundary_mean_and_length(
    rag, segmentation: np.ndarray, input_: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Accumulate the mean value and the length of the boundaries for the edges of the RAG.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input map. Must have the same shape as the segmentation.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The edge features with two columns (mean, size).

    Raises:
        ValueError: If the shapes of segmentation and input map differ.
    """
    seg_shape, input_shape = segmentation.shape, input_.shape
    if seg_shape != input_shape:
        raise ValueError("Incompatible shapes: %s, %s" % (str(seg_shape), str(input_shape)))

    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    return bic.graph.features.edge_map_features(
        rag, segmentation, input_, number_of_threads=n_threads,
    )

Compute mean value and length of boundaries.

Arguments:
  • rag: The region adjacency graph.
  • segmentation: The over-segmentation used to construct the RAG.
  • input_: The input map.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features with two columns (mean, size).

def compute_boundary_features_with_filters( rag, segmentation: numpy.ndarray, input_: numpy.ndarray, apply_2d: bool = False, n_threads: int | None = None, filters: Dict[str, List[float]] = {'gaussianSmoothing': [1.6, 4.2, 8.3], 'laplacianOfGaussian': [1.6, 4.2, 8.3], 'hessianOfGaussianEigenvalues': [1.6, 4.2, 8.3]}) -> numpy.ndarray:
def compute_boundary_features_with_filters(
    rag,
    segmentation: np.ndarray,
    input_: np.ndarray,
    apply_2d: bool = False,
    n_threads: Optional[int] = None,
    filters: Dict[str, List[float]] = {"gaussianSmoothing": [1.6, 4.2, 8.3],
                                       "laplacianOfGaussian": [1.6, 4.2, 8.3],
                                       "hessianOfGaussianEigenvalues": [1.6, 4.2, 8.3]}
) -> np.ndarray:
    """Compute boundary features accumulated over filter responses on input.

    For every combination of filter and sigma the filter response is computed, and each channel
    of the response is accumulated over the edges of the RAG with `compute_boundary_features`.
    The features for all channels and all filter settings are concatenated along the feature axis.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        input_: The input data.
        apply_2d: Whether to apply the filters in 2d for 3d input data.
        n_threads: The number of threads, set to cpu count by default.
        filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
            The default dictionary is only read, never modified.

    Returns:
        The edge features.

    Raises:
        ValueError: If `filters` does not contain any combination of filter and sigma.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    filter_settings = [(filter_name, sigma) for filter_name, sigmas in filters.items() for sigma in sigmas]
    if not filter_settings:
        raise ValueError("No filter and sigma combination was given, expected at least one filter with a sigma.")

    def _accumulate_channels(response, feature_threads):
        # Accumulate each response channel separately and stack the results along the feature axis.
        feats = [
            compute_boundary_features(rag, segmentation, response[..., chan], n_threads=feature_threads)
            for chan in range(response.shape[-1])
        ]
        out = np.concatenate(feats, axis=1)
        assert len(out) == rag.number_of_edges, f"{len(out)}, {rag.number_of_edges}"
        return out

    # apply 2d: the filter settings are processed one after the other, and the threads
    # are used within the filter and the feature computation.
    if apply_2d:

        def _compute_2d(filter_name, sigma):
            response = _filter_2d(input_, filter_name, sigma, n_threads)
            assert response.ndim == 4
            return _accumulate_channels(response, n_threads)

        features = [_compute_2d(filter_name, sigma) for filter_name, sigma in filter_settings]

    # apply 3d: we parallelize over the whole filter + feature computation
    # this can be very memory intensive, and it would be better to parallelize inside
    # of the loop, but 3d parallel filters in elf.parallel.filters are not working properly yet
    else:

        def _compute_3d(filter_name, sigma):
            response = _apply_filter(filter_name, input_, sigma)
            # Single-channel responses get a trailing channel axis, so that all responses are handled alike.
            if response.ndim == input_.ndim:
                response = response[..., None]
            # Each task runs single threaded, the parallelization happens over the filter settings.
            return _accumulate_channels(response, 1)

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(_compute_3d, filter_name, sigma) for filter_name, sigma in filter_settings]
            features = [t.result() for t in tasks]

    features = np.concatenate(features, axis=1)
    assert len(features) == rag.number_of_edges
    return features

Compute boundary features accumulated over filter responses on input.

Arguments:
  • rag: The region adjacency graph.
  • segmentation: The over-segmentation used to construct the RAG.
  • input_: The input data.
  • apply_2d: Whether to apply the filters in 2d for 3d input data.
  • n_threads: The number of threads.
  • filters: The filters to apply, expects a dictionary mapping filter names to sigma values.
Returns:

The edge features.

def compute_region_features( uv_ids: numpy.ndarray, input_map: numpy.ndarray, segmentation: numpy.ndarray, n_threads: int | None = None) -> numpy.ndarray:
def compute_region_features(
    uv_ids: np.ndarray,
    input_map: np.ndarray,
    segmentation: np.ndarray,
    n_threads: Optional[int] = None
) -> np.ndarray:
    """Compute edge features from an input map accumulated over segmentation and mapped to edges.

    Node features are computed for the segments first. The features of the two nodes of each edge
    are then combined: statistics based features via minimum, maximum, absolute difference and sum,
    coordinate based features via the squared difference per coordinate.

    Args:
        uv_ids: The edge uv ids.
        input_map: The input data.
        segmentation: The segmentation.
        n_threads: Currently not used by this function.

    Returns:
        The edge features. NaN values are replaced via `np.nan_to_num`.
    """
    # compute the node features
    stat_feature_names = ["Count", "Kurtosis", "Maximum", "Minimum", "Quantiles",
                          "RegionRadii", "Skewness", "Sum", "Variance"]
    coord_feature_names = ["Weighted<RegionCenter>", "RegionCenter"]
    node_features = _region_features(input_map, segmentation, stat_feature_names + coord_feature_names)

    def _stack_features(names):
        # Features with a single value per node are 1d, they need a feature axis before stacking.
        feats = [node_features[name] for name in names]
        return np.concatenate([feat[:, None] if feat.ndim == 1 else feat for feat in feats], axis=1)

    stat_features = _stack_features(stat_feature_names)
    coord_features = _stack_features(coord_feature_names)

    u, v = uv_ids[:, 0], uv_ids[:, 1]

    # the image statistics based features are combined via [min, max, absdiff, sum]
    stats_u, stats_v = stat_features[u], stat_features[v]
    features = [np.minimum(stats_u, stats_v), np.maximum(stats_u, stats_v),
                np.abs(stats_u - stats_v), stats_u + stats_v]

    # the coordinate based features are combined via the squared difference per coordinate
    coords_u, coords_v = coord_features[u], coord_features[v]
    features.append((coords_u - coords_v) ** 2)

    features = np.nan_to_num(np.concatenate(features, axis=1))
    assert len(features) == len(uv_ids)
    return features

Compute edge features from an input map accumulated over segmentation and mapped to edges.

Arguments:
  • uv_ids: The edge uv ids.
  • input_map: The input data.
  • segmentation: The segmentation.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The edge features.

def compute_grid_graph(shape: Tuple[int, ...]):
def compute_grid_graph(shape: Tuple[int, ...]):
    """Build the grid graph for data of the given shape.

    Args:
        shape: The shape of the data.

    Returns:
        The grid graph.
    """
    grid_graph = bic.graph.grid_graph(shape)
    return grid_graph

Compute grid graph for the given shape.

Arguments:
  • shape: The shape of the data.
Returns:

The grid graph.

def compute_grid_graph_image_features( grid_graph, image: numpy.ndarray, mode: str, offsets: List[List[int]] | None = None, strides: List[int] | None = None, randomize_strides: bool = False) -> Tuple[numpy.ndarray, numpy.ndarray]:
def compute_grid_graph_image_features(
    grid_graph,
    image: np.ndarray,
    mode: str,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive edge features for the given grid graph from an image.

    Args:
        grid_graph: The grid graph.
        image: The image, from which the features will be derived. Either a scalar image with the
            dimensionality of the grid graph or an image with one additional axis.
        mode: Feature accumulation method. For multi-channel images, one of
            "l1", "l2", "cosine". For scalar images (without channels) only
            grid-boundary averaging is supported (any mode value is accepted).
        offsets: The offsets, which correspond to the affinity channels.
        strides: The strides used to subsample edges that are computed from offsets.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        NotImplementedError: If offsets are passed together with a scalar image.
        ValueError: If the image dimension does not fit the grid graph or if the mode is invalid.
    """
    gndim = len(grid_graph.shape)

    # Images without an additional axis: only the scalar image case is valid.
    if image.ndim != gndim + 1:
        if image.ndim != gndim:
            raise ValueError(f"Invalid image dimension {image.ndim}, expected {gndim} or {gndim + 1}")
        if offsets is not None:
            raise NotImplementedError("Offsets with scalar images are not supported.")
        weights = bic.graph.features.grid_boundary_features(grid_graph, image.astype("float32"))
        return grid_graph.uv_ids(), weights

    modes = ("l1", "l2", "cosine")
    if mode not in modes:
        raise ValueError(f"Invalid feature mode {mode}, expect one of {modes}")

    embedding = image.astype("float32")

    # Without offsets the affinities are computed between adjacent pixels via nearest-neighbor offsets.
    if offsets is None:
        nn_offs = _nn_offsets(gndim)
        affs = bic.affinities.compute_embedding_distances(embedding, nn_offs, norm=mode)
        weights, _valid = bic.graph.features.grid_affinity_features(grid_graph, affs, nn_offs)
        return grid_graph.uv_ids(), weights

    # General path with arbitrary offsets: compute affinities, then combine local and lifted edges.
    affs = bic.affinities.compute_embedding_distances(embedding, offsets, norm=mode)
    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, affs, offsets,
    )
    local_uvs = grid_graph.uv_ids()[local_valid]
    edges = np.concatenate([local_uvs, lifted_uvs], axis=0)
    weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)
    return _apply_strides(edges, weights, strides, randomize_strides)

Compute edge features for image for the given grid_graph.

Arguments:
  • grid_graph: The grid graph.
  • image: The image, from which the features will be derived.
  • mode: Feature accumulation method. For multi-channel images, one of "l1", "l2", "cosine". For scalar images (without channels) only grid-boundary averaging is supported (any mode value is accepted).
  • offsets: The offsets, which correspond to the affinity channels.
  • strides: The strides used to subsample edges that are computed from offsets.
  • randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:

The uv ids of the edges. The edge features.

def compute_grid_graph_affinity_features( grid_graph, affinities: numpy.ndarray, offsets: List[List[int]] | None = None, strides: List[int] | None = None, mask: numpy.ndarray | None = None, randomize_strides: bool = False) -> Tuple[numpy.ndarray, numpy.ndarray]:
def compute_grid_graph_affinity_features(
    grid_graph,
    affinities: np.ndarray,
    offsets: Optional[List[List[int]]] = None,
    strides: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
    randomize_strides: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive edge features for the given grid graph from affinities.

    Args:
        grid_graph: The grid graph.
        affinities: The affinity map, with a leading channel axis followed by the axes of the grid graph.
        offsets: The offsets, which correspond to the affinity channels.
            If not given, nearest-neighbor offsets are used and neither strides nor mask may be passed.
        strides: The strides used to subsample edges that are computed from offsets.
        mask: Mask to exclude from the edge and feature computation. Edges for which both nodes
            lie outside of the mask (mask is False) are removed. Cannot be combined with strides.
        randomize_strides: Whether to subsample randomly instead of using regular strides.

    Returns:
        The uv ids of the edges.
        The edge features.

    Raises:
        ValueError: If the affinities do not have one more dimension than the grid graph.
    """
    gndim = len(grid_graph.shape)
    if affinities.ndim != gndim + 1:
        raise ValueError("affinities must have shape (channels, *grid_graph.shape)")

    # Without offsets: one affinity channel per axis, interpreted with nearest-neighbor offsets.
    if offsets is None:
        assert affinities.shape[0] == gndim
        assert strides is None
        assert mask is None
        nn_offs = _nn_offsets(gndim)
        weights, _valid = bic.graph.features.grid_affinity_features(grid_graph, affinities, nn_offs)
        return grid_graph.uv_ids(), weights

    # With offsets: combine the valid local edges with the lifted edges.
    local_w, local_valid, lifted_uvs, lifted_w, _ = bic.graph.features.grid_affinity_features_with_lifted(
        grid_graph, affinities, offsets,
    )
    local_uvs = grid_graph.uv_ids()[local_valid]
    edges = np.concatenate([local_uvs, lifted_uvs], axis=0)
    weights = np.concatenate([local_w[local_valid], lifted_w], axis=0)

    if mask is None:
        return _apply_strides(edges, weights, strides, randomize_strides)

    assert strides is None and not randomize_strides, "Strides and mask cannot be used at the same time"
    shape = tuple(grid_graph.shape)
    assert mask.shape == shape, (
        "compute_grid_graph_affinity_features with a per-pixel mask expects mask.shape == grid_graph.shape; "
        "per-channel edge masks are only supported on legacy nifty grid graphs."
    )
    # Node ids are enumerated over the grid; collect the ids of the nodes outside of the mask.
    node_ids = np.arange(np.prod(shape), dtype="uint64").reshape(shape)
    masked_ids = node_ids[~mask]
    # Count the masked nodes per edge and drop the edges for which both nodes are masked.
    n_masked_nodes = np.isin(edges, masked_ids).sum(axis=1)
    keep = n_masked_nodes != 2
    return edges[keep], weights[keep]

Compute edge features from affinities for the given grid graph.

Arguments:
  • grid_graph: The grid graph.
  • affinities: The affinity map.
  • offsets: The offsets, which correspond to the affinity channels.
  • strides: The strides used to subsample edges that are computed from offsets.
  • mask: Mask to exclude from the edge and feature computation.
  • randomize_strides: Whether to subsample randomly instead of using regular strides.
Returns:

The uv ids of the edges. The edge features.

def apply_mask_to_grid_graph_weights( grid_graph, mask: numpy.ndarray, weights: numpy.ndarray, masked_edge_weight: float = 0.0, transition_edge_weight: float = 1.0) -> numpy.ndarray:
def apply_mask_to_grid_graph_weights(
    grid_graph,
    mask: np.ndarray,
    weights: np.ndarray,
    masked_edge_weight: float = 0.0,
    transition_edge_weight: float = 1.0,
) -> np.ndarray:
    """Overwrite the weights of grid graph edges that touch masked nodes.

    Edges that connect two masked nodes and edges that connect a masked with a non-masked node
    are each set to a fixed value. The weights are modified in-place and also returned.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        weights: The edge weights, one entry per edge of the grid graph.
        masked_edge_weight: The value for edges that connect two masked nodes.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The masked edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    shape = tuple(grid_graph.shape)
    assert mask.shape == shape, f"{mask.shape}, {shape}"

    # Node ids are enumerated over the grid; collect the ids of the masked (mask is False) nodes.
    node_ids = np.arange(np.prod(shape), dtype="uint64").reshape(shape)
    masked_ids = node_ids[~mask]

    edges = grid_graph.uv_ids()
    assert len(edges) == len(weights)

    # Number of masked nodes per edge: 2 -> fully masked edge, 1 -> transition edge.
    n_masked_nodes = np.isin(edges, masked_ids).sum(axis=1)
    weights[n_masked_nodes == 2] = masked_edge_weight
    weights[n_masked_nodes == 1] = transition_edge_weight
    return weights

Mask edges in grid graph.

Set the weights derived from a grid graph to a fixed value, for edges that connect masked nodes and edges that connect masked and unmasked nodes.

Arguments:
  • grid_graph: The grid graph.
  • mask: The binary mask, foreground (=non-masked) is True.
  • weights: The edge weights.
  • masked_edge_weight: The value for edges that connect two masked nodes.
  • transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:

The masked edge weights.

def apply_mask_to_grid_graph_edges_and_weights( grid_graph, mask: numpy.ndarray, edges: numpy.ndarray, weights: numpy.ndarray, transition_edge_weight: float = 1.0) -> Tuple[numpy.ndarray, numpy.ndarray]:
def apply_mask_to_grid_graph_edges_and_weights(
    grid_graph, mask: np.ndarray, edges: np.ndarray, weights: np.ndarray, transition_edge_weight: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Drop edges between masked nodes and fix the weight of edges between masked and non-masked nodes.

    Args:
        grid_graph: The grid graph.
        mask: The binary mask, foreground (=non-masked) is True.
        edges: The edges (uv-ids).
        weights: The edge weights.
        transition_edge_weight: The value for edges that connect a masked with a non-masked node.

    Returns:
        The edge uv-ids.
        The edge weights.
    """
    assert np.dtype(mask.dtype) == np.dtype("bool")
    shape = tuple(grid_graph.shape)
    assert mask.shape == shape, f"{mask.shape}, {shape}"

    # Node ids are enumerated over the grid; collect the ids of the masked (mask is False) nodes.
    node_ids = np.arange(np.prod(shape), dtype="uint64").reshape(shape)
    masked_ids = node_ids[~mask]

    # Number of masked nodes per edge: edges with two masked nodes are removed.
    n_masked_nodes = np.isin(edges, masked_ids).sum(axis=1)
    keep = n_masked_nodes != 2
    kept_edges, kept_weights = edges[keep], weights[keep]

    # The remaining edges with exactly one masked node are the transition edges.
    kept_weights[n_masked_nodes[keep] == 1] = transition_edge_weight
    return kept_edges, kept_weights

Remove uv ids that connect masked nodes and set weights that connect masked to non-masked nodes to a fixed value.

Arguments:
  • grid_graph: The grid graph.
  • mask: The binary mask, foreground (=non-masked) is True.
  • edges: The edges (uv-ids).
  • weights: The edge weights.
  • transition_edge_weight: The value for edges that connect a masked with a non-masked node.
Returns:

The edge uv-ids. The edge weights.

def lifted_problem_from_probabilities(rag, watershed: numpy.ndarray, input_maps: List[numpy.ndarray], assignment_threshold: float, graph_depth: int, feats_to_costs: callable = feats_to_costs_default, mode: str = 'different', n_threads: int | None = None) -> Tuple[numpy.ndarray, numpy.ndarray]:
def lifted_problem_from_probabilities(
    rag,
    watershed: np.ndarray,
    input_maps: List[np.ndarray],
    assignment_threshold: float,
    graph_depth: int,
    feats_to_costs: callable = feats_to_costs_default,
    mode: str = "different",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Derive a lifted multicut problem from per-class probability maps.

    Each probability map is averaged over the superpixels of `watershed`. A node
    is assigned the index of a map when its mean probability exceeds
    `assignment_threshold`; if several maps qualify, the last one in
    `input_maps` wins. Lifted edges are then inserted between labeled nodes and
    their costs are computed by `feats_to_costs`.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_maps: List of probability maps. Each map must have the same shape as the watersheds.
        assignment_threshold: Minimal expression level to assign a class to a graph node.
        graph_depth: Maximal graph depth up to which lifted edges will be included.
        feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
            It is called with the per-edge node labels and the per-edge mean probabilities.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation, set to cpu count by default.

    Returns:
        The lifted uv ids.
        The lifted costs.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    assert isinstance(input_maps, (list, tuple))
    for prob_map in input_maps:
        assert isinstance(prob_map, np.ndarray)
    shape = watershed.shape
    for prob_map in input_maps:
        assert prob_map.shape == shape

    # One slot per superpixel id, including ids that do not occur in the watershed.
    n_nodes = 1 + int(watershed.max())
    node_labels = np.zeros(n_nodes, dtype="uint64")
    node_features = np.zeros(n_nodes, dtype="float32")

    # NOTE(review): class ids start at 0, which is also the initial value of
    # `node_labels` and the `ignore_label` passed below, so nodes assigned to the
    # first map are indistinguishable from unassigned nodes. Confirm this is intended.
    for class_id, prob_map in enumerate(input_maps):
        mean_prob = _region_features(prob_map, watershed, ["mean"])["mean"]
        assigned = mean_prob > assignment_threshold
        node_labels[assigned] = class_id
        node_features[assigned] = mean_prob[assigned]

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag,
        node_labels,
        graph_depth=graph_depth,
        mode=mode,
        ignore_label=0,
        number_of_threads=n_threads,
    )

    # Gather the labels and mean probabilities of both endpoints of each lifted edge.
    labels_per_edge = node_labels[lifted_uvs]
    features_per_edge = node_features[lifted_uvs]
    return lifted_uvs, feats_to_costs(labels_per_edge, features_per_edge)

Compute lifted problem from probability maps by mapping them to superpixels.

Arguments:
  • rag: The region adjacency graph.
  • watershed: The watershed over-segmentation.
  • input_maps: List of probability maps. Each map must have the same shape as the watersheds.
  • assignment_threshold: Minimal expression level to assign a class to a graph node.
  • graph_depth: Maximal graph depth up to which lifted edges will be included.
  • feats_to_costs: Function to calculate the lifted costs from the class assignment probabilities.
  • mode: The mode for insertion of lifted edges. One of "all", "different", "same".
  • n_threads: The number of threads used for the calculation.
Returns:

The lifted uv ids. The lifted costs.

def lifted_problem_from_segmentation( rag, watershed: numpy.ndarray, input_segmentation: numpy.ndarray, overlap_threshold: float, graph_depth: int, same_segment_cost: float, different_segment_cost: float, mode: str = 'all', n_threads: int | None = None) -> Tuple[numpy.ndarray, numpy.ndarray]:
def lifted_problem_from_segmentation(
    rag,
    watershed: np.ndarray,
    input_segmentation: np.ndarray,
    overlap_threshold: float,
    graph_depth: int,
    same_segment_cost: float,
    different_segment_cost: float,
    mode: str = "all",
    n_threads: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute lifted problem from segmentation by mapping segments to superpixels.

    Every superpixel is attributed the segment id it overlaps with most, provided
    the overlap fraction reaches `overlap_threshold`; otherwise it is attributed 0.
    Lifted edges are inserted between attributed nodes and receive
    `same_segment_cost` or `different_segment_cost` depending on whether both
    endpoints carry the same segment id.

    Args:
        rag: The region adjacency graph.
        watershed: The watershed over-segmentation.
        input_segmentation: The segmentation used to determine node attribution.
        overlap_threshold: The minimal overlap to assign a segment id to node.
        graph_depth: The maximal graph depth up to which lifted edges will be included.
        same_segment_cost: The cost for edges between nodes with same segment id attribution.
        different_segment_cost: The cost for edges between nodes with different segment id attribution.
        mode: The mode for insertion of lifted edges. One of "all", "different", "same".
        n_threads: The number of threads used for the calculation, set to cpu count by default.

    Returns:
        The lifted uv ids.
        The lifted costs. Empty if no lifted edges were found.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    assert input_segmentation.shape == watershed.shape

    ovlp = bic.utils.segmentation_overlap(watershed, input_segmentation)
    ws_ids = np.unique(watershed)
    # np.unique returns sorted ids, so the last entry is the maximal superpixel id.
    n_labels = int(ws_ids[-1]) + 1
    assert n_labels == rag.number_of_nodes, "%i, %i" % (n_labels, rag.number_of_nodes)

    # Best overlapping segment (and its overlap fraction) for each superpixel present.
    node_label_vals = np.zeros(len(ws_ids), dtype="uint64")
    overlap_values = np.zeros(len(ws_ids), dtype="float64")
    for i, ws_id in enumerate(ws_ids):
        best = ovlp.best_overlap_for_label_a(int(ws_id), ignore_zero=False)
        node_label_vals[i] = best.label
        overlap_values[i] = best.fraction
    # Superpixels with insufficient overlap get label 0, which is passed as ignore_label below.
    node_label_vals[overlap_values < overlap_threshold] = 0

    # Scatter to a dense array indexed by superpixel id; ids that are absent stay 0.
    node_labels = np.zeros(n_labels, dtype="uint64")
    node_labels[ws_ids] = node_label_vals

    lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels(
        rag, node_labels, graph_depth=graph_depth, mode=mode,
        ignore_label=0, number_of_threads=n_threads,
    )

    # No lifted edges is a valid outcome (e.g. no superpixel reaches the overlap
    # threshold). Return early, since `max()` of an empty array raises ValueError.
    if len(lifted_uvs) == 0:
        return lifted_uvs, np.zeros(0, dtype="float64")

    assert lifted_uvs.max() < rag.number_of_nodes, "%i, %i" % (int(lifted_uvs.max()), rag.number_of_nodes)
    lifted_labels = node_labels[lifted_uvs]
    lifted_costs = np.zeros(len(lifted_labels), dtype="float64")

    same_mask = lifted_labels[:, 0] == lifted_labels[:, 1]
    lifted_costs[same_mask] = same_segment_cost
    lifted_costs[~same_mask] = different_segment_cost

    return lifted_uvs, lifted_costs

Compute lifted problem from segmentation by mapping segments to superpixels.

Arguments:
  • rag: The region adjacency graph.
  • watershed: The watershed over-segmentation.
  • input_segmentation: The segmentation used to determine node attribution.
  • overlap_threshold: The minimal overlap to assign a segment id to node.
  • graph_depth: The maximal graph depth up to which lifted edges will be included.
  • same_segment_cost: The cost for edges between nodes with same segment id attribution.
  • different_segment_cost: The cost for edges between nodes with different segment id attribution.
  • mode: The mode for insertion of lifted edges. One of "all", "different", "same".
  • n_threads: The number of threads used for the calculation.
Returns:

The lifted uv ids. The lifted costs.

def get_stitch_edges( rag, seg: numpy.ndarray, block_shape: Tuple[int, ...], n_threads: int | None = None, verbose: bool = False) -> numpy.ndarray:
def get_stitch_edges(
    rag,
    seg: np.ndarray,
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False
) -> np.ndarray:
    """Get the edges between blocks.

    The segmentation is partitioned into blocks of `block_shape`. For every block
    and axis, the labels on the first slice of the block are paired with the labels
    on the slice directly before it, and the rag edges matching these label pairs
    are marked as stitch edges.

    Args:
        rag: The region adjacency graph.
        seg: The segmentation underlying the rag.
        block_shape: The shape of the blocking.
        n_threads: The number of threads used for the calculation, set to cpu count by default.
        verbose: Whether to be verbose.

    Returns:
        The edge mask indicating edges between blocks. All False if there are no such edges,
        for example when a single block covers the whole segmentation.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    ndim = seg.ndim
    blocking = bic.utils.Blocking([0] * ndim, list(seg.shape), list(block_shape))
    n_blocks = blocking.number_of_blocks

    def find_stitch_edges(block_id):
        """Return the unique stitch edge ids of one block, or None if it has none to check."""
        stitch_edges = []
        block = blocking.get_block(block_id)
        for axis in range(ndim):
            # -1 signals a missing neighbor; presumably `True` selects the lower
            # neighbor, consistent with face_b lying at `begin - 1` below.
            if blocking.get_neighbor_id(block_id, axis, True) == -1:
                continue
            # First slice of this block along the axis ...
            face_a = tuple(
                beg if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )
            # ... and the adjacent slice just outside of the block.
            face_b = tuple(
                beg - 1 if d == axis else slice(beg, end)
                for d, beg, end in zip(range(ndim), block.begin, block.end)
            )

            labels_a = seg[face_a].ravel()
            labels_b = seg[face_b].ravel()

            uv_ids = np.unique(np.stack([labels_a, labels_b], axis=1), axis=0)

            # Label pairs without a corresponding rag edge are reported as -1 and dropped.
            edge_ids = rag.find_edges(uv_ids)
            stitch_edges.append(edge_ids[edge_ids != -1])

        if not stitch_edges:
            return None
        return np.unique(np.concatenate(stitch_edges))

    # Collect the results while the pool is still open, in both branches.
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(find_stitch_edges, range(n_blocks))
        if verbose:
            results = tqdm(results, total=n_blocks)
        per_block_edges = [res for res in results if res is not None]

    full_edges = np.zeros(rag.number_of_edges, dtype="bool")
    # np.concatenate raises on an empty list, so only index if any block contributed edges.
    if per_block_edges:
        full_edges[np.concatenate(per_block_edges)] = True
    return full_edges

Get the edges between blocks.

Arguments:
  • rag: The region adjacency graph.
  • seg: The segmentation underlying the rag.
  • block_shape: The shape of the blocking.
  • n_threads: The number of threads used for the calculation.
  • verbose: Whether to be verbose.
Returns:

The edge mask indicating edges between blocks.

def project_node_labels_to_pixels( rag, segmentation: numpy.ndarray, node_labels: numpy.ndarray, n_threads: int | None = None) -> numpy.ndarray:
def project_node_labels_to_pixels(
    rag, segmentation: np.ndarray, node_labels: np.ndarray, n_threads: Optional[int] = None,
) -> np.ndarray:
    """Paint per-node labels onto the pixels of the over-segmentation.

    Args:
        rag: The region adjacency graph.
        segmentation: The over-segmentation used to construct the RAG.
        node_labels: The array with node labels, one entry per node of the RAG.
        n_threads: The number of threads used, set to cpu count by default.

    Returns:
        The segmentation.

    Raises:
        ValueError: If the number of node labels differs from the number of nodes in the RAG.
    """
    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    n_given = len(node_labels)
    if n_given != rag.number_of_nodes:
        raise ValueError("Incompatible number of node labels: %i, %i" % (n_given, rag.number_of_nodes))

    # bic.graph.project_node_labels_to_pixels requires integer dtypes for both arrays,
    # anything else is cast to uint64.
    integer_dtypes = (np.uint32, np.uint64, np.int32, np.int64)
    if segmentation.dtype not in integer_dtypes:
        segmentation = segmentation.astype("uint64")
    if node_labels.dtype not in integer_dtypes:
        node_labels = node_labels.astype("uint64")

    return bic.graph.project_node_labels_to_pixels(rag, segmentation, node_labels, number_of_threads=n_threads)

Project label values for graph nodes back to pixels to obtain segmentation.

Arguments:
  • rag: The region adjacency graph.
  • segmentation: The over-segmentation used to construct the RAG.
  • node_labels: The array with node labels.
  • n_threads: The number of threads used, set to cpu count by default.
Returns:

The segmentation.

def compute_z_edge_mask(rag, watershed: numpy.ndarray) -> numpy.ndarray:
def compute_z_edge_mask(rag, watershed: np.ndarray) -> np.ndarray:
    """Mark the rag edges that connect superpixels from different planes.

    Each node is assigned the index of the plane (along the first axis) it occurs
    in; if a node occurs in several planes the last one wins. An edge is flagged
    when its two nodes are assigned different planes.

    Args:
        rag: The region adjacency graph.
        watershed: The underlying watershed over-segmentation (superpixels).

    Returns:
        The edge mask indicating in-between slice edges.
    """
    plane_of_node = np.zeros(rag.number_of_nodes, dtype="uint32")
    for z, plane in enumerate(watershed):
        plane_of_node[plane] = z

    edges = rag.uv_ids()
    plane_u = plane_of_node[edges[:, 0]]
    plane_v = plane_of_node[edges[:, 1]]
    return plane_u != plane_v

Compute edge mask of in-between plane edges for flat superpixels.

Arguments:
  • rag: The region adjacency graph.
  • watershed: The underlying watershed over-segmentation (superpixels).
Returns:

The edge mask indicating in-between slice edges.