elf.segmentation.embeddings

  1# IMPORTANT do threadctl import first (before numpy imports)
  2from threadpoolctl import threadpool_limits
  3from typing import List, Optional
  4
  5import numpy as np
  6import vigra
  7try:
  8    import hdbscan
  9except ImportError:
 10    hdbscan = None
 11
 12from scipy.ndimage import shift
 13from sklearn.cluster import MeanShift
 14from sklearn.decomposition import PCA
 15
 16from .features import (compute_grid_graph,
 17                       compute_grid_graph_affinity_features,
 18                       compute_grid_graph_image_features)
 19from .multicut import compute_edge_costs
 20from .mutex_watershed import mutex_watershed_clustering
 21
 22#
 23# utils
 24#
 25
 26
 27def embedding_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
 28    """Compute PCA of per-pixel embeddings.
 29
 30    Args:
 31        embeddings: The per-pixel embeddings.
 32        n_components: The number of PCA components.
 33        as_rgb: Whether to reshape the output so that it can be displayed as RGB image.
 34
 35    Returns:
 36        The PCA of the embeddings.
 37    """
 38    if as_rgb and n_components != 3:
 39        raise ValueError("")
 40
 41    pca = PCA(n_components=n_components)
 42    embed_dim = embeddings.shape[0]
 43    shape = embeddings.shape[1:]
 44
 45    embed_flat = embeddings.reshape(embed_dim, -1).T
 46    embed_flat = pca.fit_transform(embed_flat).T
 47    embed_flat = embed_flat.reshape((n_components,) + shape)
 48
 49    if as_rgb:
 50        embed_flat = 255 * (embed_flat - embed_flat.min()) / np.ptp(embed_flat)
 51        embed_flat = embed_flat.astype("uint8")
 52
 53    return embed_flat
 54
 55
 56def _embeddings_to_probabilities(embed1, embed2, delta, embedding_axis):
 57    probs = (2 * delta - np.linalg.norm(embed1 - embed2, axis=embedding_axis)) / (2 * delta)
 58    probs = np.maximum(probs, 0) ** 2
 59    return probs
 60
 61
 62def edge_probabilities_from_embeddings(
 63    embeddings: np.ndarray, segmentation: np.ndarray, rag, delta: float
 64) -> np.ndarray:
 65    """Derive edge probabilities from pixel embeddings.
 66
 67    Args:
 68        embeddings: The pixel embeddings.
 69        segmentation: The segmentation.
 70        rag: The region adjacency graph derived from the segmentation.
 71        delta: The delta factor used in the push force when training the embeddings.
 72
 73    Returns:
 74        The edge probabilties.
 75    """
 76    n_nodes = rag.numberOfNodes
 77    embed_dim = embeddings.shape[0]
 78
 79    segmentation = segmentation.astype("uint32")
 80    mean_embeddings = np.zeros((n_nodes, embed_dim), dtype="float32")
 81    for cid in range(embed_dim):
 82        mean_embed = vigra.analysis.extractRegionFeatures(embeddings[cid], segmentation, features=["mean"])["mean"]
 83        mean_embeddings[:, cid] = mean_embed
 84
 85    uv_ids = rag.uvIds()
 86    embed_u = mean_embeddings[uv_ids[:, 0]]
 87    embed_v = mean_embeddings[uv_ids[:, 1]]
 88    edge_probabilities = 1. - _embeddings_to_probabilities(embed_u, embed_v, delta, embedding_axis=1)
 89    return edge_probabilities
 90
 91
 92# Could probably be implemented more efficiently with shift kernels instead of explicit call to shift.
 93# (or implement in C++ to save memory)
 94def embeddings_to_affinities(
 95    embeddings: np.ndarray,
 96    offsets: List[List[int]],
 97    delta: float,
 98    invert: bool = False,
 99) -> np.ndarray:
100    """Convert pixel embeddings to affinities.
101
102    Computes the affinity according to the formula
103    a_ij = max((2 * delta - ||x_i - x_j||) / 2 * delta, 0) ** 2,
104    where delta is the push force used in training the embeddings.
105    Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction":
106    https://arxiv.org/pdf/1909.09872.pdf
107
108    Args:
109        embeddings: The pixel embeddings.
110        offsets: The offset vectors for which to compute affinities.
111        delta: The delta factor used in the push force when training the embeddings.
112        invert: Whether to invert the affinites.
113
114    Returns:
115        The affinity values.
116    """
117    ndim = embeddings.ndim - 1
118    if not all(len(off) == ndim for off in offsets):
119        raise ValueError("Incosistent dimension of offsets and embeddings")
120
121    n_channels = len(offsets)
122    shape = embeddings.shape[1:]
123    affinities = np.zeros((n_channels,) + shape, dtype="float32")
124
125    for cid, off in enumerate(offsets):
126        # we need to  shift in the other direction in order to
127        # get the correct offset
128        # also, we need to add a zero shift in the first axis
129        shift_off = [0] + [-o for o in off]
130        # we could also shift via np.pad and slicing
131        shifted = shift(embeddings, shift_off, order=0, prefilter=False)
132        affs = _embeddings_to_probabilities(embeddings, shifted, delta, embedding_axis=0)
133        affinities[cid] = affs
134
135    if invert:
136        affinities = 1. - affinities
137
138    return affinities
139
140
141#
142# density based segmentation
143#
144
145
146def _cluster(embeddings, clustering_alg, semantic_mask=None, remove_largest=False):
147    output_shape = embeddings.shape[1:]
148    # reshape (E, D, H, W) -> (E, D * H * W) and transpose -> (D * H * W, E)
149    flattened_embeddings = embeddings.reshape(embeddings.shape[0], -1).transpose()
150
151    result = np.zeros(flattened_embeddings.shape[0])
152
153    if semantic_mask is not None:
154        flattened_mask = semantic_mask.reshape(-1)
155        assert flattened_mask.shape[0] == flattened_embeddings.shape[0]
156    else:
157        flattened_mask = np.ones(flattened_embeddings.shape[0])
158
159    if flattened_mask.sum() == 0:
160        # return zeros for empty masks
161        return result.reshape(output_shape)
162
163    # cluster only within the foreground mask
164    clusters = clustering_alg.fit_predict(flattened_embeddings[flattened_mask == 1])
165    # always increase the labels by 1 cause clustering results start from 0 and we may loose one object
166    result[flattened_mask == 1] = clusters + 1
167
168    if remove_largest:
169        # set largest object to 0-label
170        ids, counts = np.unique(result, return_counts=True)
171        result[ids[np.argmax(counts)] == result] = 0
172
173    return result.reshape(output_shape)
174
175
def segment_hdbscan(
    embeddings: np.ndarray,
    min_size: int, eps: float,
    remove_largest: bool,
    n_jobs: int = 1,
) -> np.ndarray:
    """Segment pixel embeddings by clustering them with HDBSCAN.

    Args:
        embeddings: The pixel embeddings.
        min_size: The minimal segment size, passed to HDBSCAN as `min_cluster_size`.
        eps: Epsilon factor for HDBSCAN, passed as `cluster_selection_epsilon`.
        remove_largest: Whether to remove the largest (=background) object.
        n_jobs: The number of jobs for parallelizing HDBSCAN.

    Returns:
        The segmentation.
    """
    # hdbscan is an optional dependency and is None if it could not be imported.
    assert hdbscan is not None, "Needs hdbscan library"
    # Restrict the native thread pools to the requested number of jobs while clustering.
    with threadpool_limits(limits=n_jobs):
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=min_size,
            cluster_selection_epsilon=eps,
            core_dist_n_jobs=n_jobs,
        )
        labels = _cluster(embeddings, clusterer, remove_largest=remove_largest)
        segmentation = labels.astype("uint64")
    return segmentation
201
202
def segment_mean_shift(embeddings: np.ndarray, bandwidth: float, n_jobs: int = 1) -> np.ndarray:
    """Segment pixel embeddings by clustering them with mean shift.

    Args:
        embeddings: The pixel embeddings.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    # Restrict the native thread pools to the requested number of jobs while clustering.
    with threadpool_limits(limits=n_jobs):
        mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        labels = _cluster(embeddings, mean_shift)
        segmentation = labels.astype("uint64")
    return segmentation
218
219
def segment_consistency(
    embeddings1: np.ndarray,
    embeddings2: np.ndarray,
    bandwidth: float,
    iou_threshold: float,
    num_anchors: int,
    skip_zero: bool = True,
    n_jobs: int = 1
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings via mean shift and consistency.

    First, the segmentation is computed using mean shift. Then, for each instance in this
    segmentation the corresponding instance mask is derived from the second set of embeddings.
    Masks that have a low IOU with the corresponding instance mask are removed.

    Works for embeddings with an arbitrary number of spatial dimensions.

    Args:
        embeddings1: The first set of pixel embeddings, used for mean shift clustering.
        embeddings2: The second set of pixel embeddings, used for consistency.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
            It is also used as distance threshold for deriving the instance masks from `embeddings2`.
        iou_threshold: The threshold for consistency filtering.
        num_anchors: The number of anchors for computing the instance masks for consistency.
        skip_zero: Whether to skip the background label.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    def _iou(gt, seg):
        # The epsilon keeps the ratio well defined for an empty union.
        epsilon = 1e-5
        inter = (gt & seg).sum()
        union = (gt | seg).sum()
        return (inter + epsilon) / (union + epsilon)

    with threadpool_limits(limits=n_jobs):
        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        clusters = _cluster(embeddings1, clustering)

    n_spatial_dims = clusters.ndim
    for label_id in np.unique(clusters):
        if label_id == 0 and skip_zero:
            continue

        mask = clusters == label_id
        # One coordinate array per spatial axis, so this works for 2d as well as 3d data.
        coords = np.nonzero(mask)
        n_pixels = len(coords[0])

        iou_table = []
        for _ in range(num_anchors):
            ind = np.random.randint(n_pixels)
            # Get a random embedding anchor of this instance from the second embeddings.
            anchor_index = (slice(None),) + tuple(coord[ind] for coord in coords)
            anchor_emb = embeddings2[anchor_index]
            # Add the singleton dims necessary for broadcasting over the spatial axes.
            anchor_emb = anchor_emb.reshape((-1,) + (1,) * n_spatial_dims)
            # Compute the instance mask from the second embeddings.
            inst_mask = np.linalg.norm(embeddings2 - anchor_emb, axis=0) < bandwidth
            iou_table.append(_iou(mask, inst_mask))
        # Choose the final IoU as the median over all anchors.
        final_iou = np.median(iou_table)

        if final_iou < iou_threshold:
            clusters[mask] = 0

    return clusters.astype("uint64")
283
284
285#
286# affinity based segmentation
287#
288
289
290def _ensure_mask_is_zero(seg, mask):
291    inv_mask = ~mask
292    mask_id = seg[inv_mask][0]
293    if mask_id == 0:
294        return seg
295
296    seg_ids = np.unique(seg[mask])
297    if 0 in seg_ids:
298        seg[seg == 0] = mask_id
299    seg[inv_mask] = 0
300
301    return seg
302
303
304def _get_lr_offsets(offsets):
305    lr_offsets = [
306        off for off in offsets if np.sum(np.abs(off)) > 1
307    ]
308    return lr_offsets
309
310
311def _apply_mask(mask, g, weights, lr_edges, lr_weights):
312    assert np.dtype(mask.dtype) == np.dtype("bool")
313    node_ids = g.projectNodeIdsToPixels()
314    assert node_ids.shape == mask.shape == tuple(g.shape), f"{node_ids.shape}, {mask.shape}, {g.shape}"
315    masked_ids = node_ids[~mask]
316
317    # local edges:
318    # - set edges that connect masked nodes to max attractive
319    # - set edges that connect masked and non-masked nodes to max repulsive
320    local_edge_state = np.isin(g.uvIds(), masked_ids).sum(axis=1)
321    local_masked_edges = local_edge_state == 2
322    local_transition_edges = local_edge_state == 1
323    weights[local_masked_edges] = 0.0
324    weights[local_transition_edges] = 1.0
325
326    # lr edges:
327    # - remove edges that connect masked nodes
328    # - set all edges that connect masked and non-masked nodes to max repulsive
329    lr_edge_state = np.isin(lr_edges, masked_ids).sum(axis=1)
330    lr_keep_edges = lr_edge_state != 2
331
332    lr_edges, lr_weights, lr_edge_state = (lr_edges[lr_keep_edges],
333                                           lr_weights[lr_keep_edges],
334                                           lr_edge_state[lr_keep_edges])
335    lr_transition_edges = lr_edge_state == 1
336    lr_weights[lr_transition_edges] = 1.0
337
338    return weights, lr_edges, lr_weights
339
340
341# weight functions may normalize the weight values based on some statistics
342# calculated for all weights. It's important to apply this weighting on a per offset channel
343# basis, because long-range weights may be much larger than the short range weights.
def _process_weights(g, edges, weights, weight_function, beta,
                     offsets=None, strides=None, randomize_strides=None):
    # Optionally apply `weight_function` per offset channel and recompute the edge
    # features from the weighted channels; optionally convert the weights to costs with `beta`.
    # Without a weight function, the given edges and weights are passed on as they are.

    def _weighted_channels(edge_ids):
        # Bring the edge weights into the per-channel pixel layout given by `edge_ids`.
        # Invalid edges are marked with -1: they are looked up at index 0 and then zeroed.
        invalid_edges = edge_ids == -1
        edge_ids[invalid_edges] = 0
        channels = weights[edge_ids]
        channels[invalid_edges] = 0
        # The weight function is applied to each channel separately.
        for chan_id, channel in enumerate(channels):
            channels[chan_id] = weight_function(channel)
        return channels

    if weight_function is not None:
        if offsets is None:
            # Local edges of the grid graph.
            channels = _weighted_channels(g.projectEdgeIdsToPixels())
            edges, weights = compute_grid_graph_affinity_features(
                g, channels
            )
            assert len(weights) == g.numberOfEdges
        else:
            # Long-range edges for the given offsets; the strides are applied here.
            channels = _weighted_channels(g.projectEdgeIdsToPixelsWithOffsets(offsets))
            edges, weights = compute_grid_graph_affinity_features(
                g, channels, offsets=offsets,
                strides=strides, randomize_strides=randomize_strides
            )

    if beta is not None:
        weights = compute_edge_costs(weights, beta=beta)

    return edges, weights
387
388
def _embeddings_to_problem(embed, distance_type, beta=None,
                           offsets=None, strides=None, weight_function=None,
                           mask=None):
    # Build the grid graph for the spatial shape of the embeddings and derive its edge weights.
    # Returns (graph, weights) if no offsets are given,
    # and (graph, weights, lr_edges, lr_weights) otherwise.
    graph = compute_grid_graph(embed.shape[1:])
    _, local_weights = compute_grid_graph_image_features(graph, embed, distance_type)
    _, local_weights = _process_weights(graph, None, local_weights, weight_function, beta)
    if offsets is None:
        # NOTE(review): the mask is not applied on this path - confirm that this is intended.
        return graph, local_weights

    long_range_offsets = _get_lr_offsets(offsets)

    # The strides are only used in the feature computation if no weight function is applied,
    # otherwise they are applied later, when the weights are processed.
    if weight_function is None:
        feature_strides, randomize = strides, True
    else:
        feature_strides, randomize = None, False

    lr_edges, lr_weights = compute_grid_graph_image_features(
        graph, embed, distance_type,
        offsets=long_range_offsets, strides=feature_strides, randomize_strides=randomize,
    )

    if mask is not None:
        local_weights, lr_edges, lr_weights = _apply_mask(mask, graph, local_weights, lr_edges, lr_weights)

    lr_edges, lr_weights = _process_weights(
        graph, lr_edges, lr_weights, weight_function, beta,
        offsets=long_range_offsets, strides=strides, randomize_strides=randomize,
    )
    return graph, local_weights, lr_edges, lr_weights
415
416
417# weight function based on the seung paper, using the push delta
418# of the discriminative loss term.
def discriminative_loss_weight(dist, delta):
    """@private
    """
    # Similarity that falls off linearly with the distance and reaches zero at 2 * delta.
    similarity = np.maximum((2 * delta - dist) / (2 * delta), 0)
    # The weight is one minus the squared similarity: 0 at distance zero, 1 from 2 * delta on.
    return 1. - similarity ** 2
425
426
def segment_embeddings_mws(
    embeddings: np.ndarray,
    distance_type: str,
    offsets: List[List[int]],
    bias: float = 0.0,
    strides: Optional[List[int]] = None,
    weight_function: Optional[callable] = None,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Compute a segmentation by computing a mutex watershed based on pixel embeddings.

    Args:
        embeddings: The pixel embeddings, with the embedding channels in the first axis.
        distance_type: The distance type for deriving affinities from embeddings.
        offsets: The affinity offsets.
        bias: Additional bias that is added to the costs of the mutex (long-range) edges.
            This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
        strides: The strides for sub-sampling repulsive mutex edges.
        weight_function: Optional function for weighting the affinity values.
        mask: Mask to ignore in the segmentation.

    Returns:
        The segmentation.
    """
    g, costs, mutex_uvs, mutex_costs = _embeddings_to_problem(
        embeddings, distance_type, beta=None,
        offsets=offsets, strides=strides,
        weight_function=weight_function,
        mask=mask
    )
    # Apply the bias for both signs: negative values are documented as valid
    # and must not be silently ignored.
    if bias != 0:
        mutex_costs += bias
    uvs = g.uvIds()
    seg = mutex_watershed_clustering(uvs, mutex_uvs, costs, mutex_costs).reshape(embeddings.shape[1:])
    if mask is not None:
        # Make sure that the masked-out region is mapped to the label 0.
        seg = _ensure_mask_is_zero(seg, mask)
    return seg
def embedding_pca( embeddings: numpy.ndarray, n_components: int = 3, as_rgb: bool = True) -> numpy.ndarray:
28def embedding_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
29    """Compute PCA of per-pixel embeddings.
30
31    Args:
32        embeddings: The per-pixel embeddings.
33        n_components: The number of PCA components.
34        as_rgb: Whether to reshape the output so that it can be displayed as RGB image.
35
36    Returns:
37        The PCA of the embeddings.
38    """
39    if as_rgb and n_components != 3:
40        raise ValueError("")
41
42    pca = PCA(n_components=n_components)
43    embed_dim = embeddings.shape[0]
44    shape = embeddings.shape[1:]
45
46    embed_flat = embeddings.reshape(embed_dim, -1).T
47    embed_flat = pca.fit_transform(embed_flat).T
48    embed_flat = embed_flat.reshape((n_components,) + shape)
49
50    if as_rgb:
51        embed_flat = 255 * (embed_flat - embed_flat.min()) / np.ptp(embed_flat)
52        embed_flat = embed_flat.astype("uint8")
53
54    return embed_flat

Compute PCA of per-pixel embeddings.

Arguments:
  • embeddings: The per-pixel embeddings.
  • n_components: The number of PCA components.
  • as_rgb: Whether to reshape the output so that it can be displayed as RGB image.
Returns:

The PCA of the embeddings.

def edge_probabilities_from_embeddings( embeddings: numpy.ndarray, segmentation: numpy.ndarray, rag, delta: float) -> numpy.ndarray:
63def edge_probabilities_from_embeddings(
64    embeddings: np.ndarray, segmentation: np.ndarray, rag, delta: float
65) -> np.ndarray:
66    """Derive edge probabilities from pixel embeddings.
67
68    Args:
69        embeddings: The pixel embeddings.
70        segmentation: The segmentation.
71        rag: The region adjacency graph derived from the segmentation.
72        delta: The delta factor used in the push force when training the embeddings.
73
74    Returns:
75        The edge probabilties.
76    """
77    n_nodes = rag.numberOfNodes
78    embed_dim = embeddings.shape[0]
79
80    segmentation = segmentation.astype("uint32")
81    mean_embeddings = np.zeros((n_nodes, embed_dim), dtype="float32")
82    for cid in range(embed_dim):
83        mean_embed = vigra.analysis.extractRegionFeatures(embeddings[cid], segmentation, features=["mean"])["mean"]
84        mean_embeddings[:, cid] = mean_embed
85
86    uv_ids = rag.uvIds()
87    embed_u = mean_embeddings[uv_ids[:, 0]]
88    embed_v = mean_embeddings[uv_ids[:, 1]]
89    edge_probabilities = 1. - _embeddings_to_probabilities(embed_u, embed_v, delta, embedding_axis=1)
90    return edge_probabilities

Derive edge probabilities from pixel embeddings.

Arguments:
  • embeddings: The pixel embeddings.
  • segmentation: The segmentation.
  • rag: The region adjacency graph derived from the segmentation.
  • delta: The delta factor used in the push force when training the embeddings.
Returns:

The edge probabilities.

def embeddings_to_affinities( embeddings: numpy.ndarray, offsets: List[List[int]], delta: float, invert: bool = False) -> numpy.ndarray:
 95def embeddings_to_affinities(
 96    embeddings: np.ndarray,
 97    offsets: List[List[int]],
 98    delta: float,
 99    invert: bool = False,
100) -> np.ndarray:
101    """Convert pixel embeddings to affinities.
102
103    Computes the affinity according to the formula
104    a_ij = max((2 * delta - ||x_i - x_j||) / 2 * delta, 0) ** 2,
105    where delta is the push force used in training the embeddings.
106    Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction":
107    https://arxiv.org/pdf/1909.09872.pdf
108
109    Args:
110        embeddings: The pixel embeddings.
111        offsets: The offset vectors for which to compute affinities.
112        delta: The delta factor used in the push force when training the embeddings.
113        invert: Whether to invert the affinites.
114
115    Returns:
116        The affinity values.
117    """
118    ndim = embeddings.ndim - 1
119    if not all(len(off) == ndim for off in offsets):
120        raise ValueError("Incosistent dimension of offsets and embeddings")
121
122    n_channels = len(offsets)
123    shape = embeddings.shape[1:]
124    affinities = np.zeros((n_channels,) + shape, dtype="float32")
125
126    for cid, off in enumerate(offsets):
127        # we need to  shift in the other direction in order to
128        # get the correct offset
129        # also, we need to add a zero shift in the first axis
130        shift_off = [0] + [-o for o in off]
131        # we could also shift via np.pad and slicing
132        shifted = shift(embeddings, shift_off, order=0, prefilter=False)
133        affs = _embeddings_to_probabilities(embeddings, shifted, delta, embedding_axis=0)
134        affinities[cid] = affs
135
136    if invert:
137        affinities = 1. - affinities
138
139    return affinities

Convert pixel embeddings to affinities.

Computes the affinity according to the formula a_ij = max((2 * delta - ||x_i - x_j||) / (2 * delta), 0) ** 2, where delta is the push force used in training the embeddings. Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction": https://arxiv.org/pdf/1909.09872.pdf

Arguments:
  • embeddings: The pixel embeddings.
  • offsets: The offset vectors for which to compute affinities.
  • delta: The delta factor used in the push force when training the embeddings.
  • invert: Whether to invert the affinities.
Returns:

The affinity values.

def segment_hdbscan( embeddings: numpy.ndarray, min_size: int, eps: float, remove_largest: bool, n_jobs: int = 1) -> numpy.ndarray:
177def segment_hdbscan(
178    embeddings: np.ndarray,
179    min_size: int, eps: float,
180    remove_largest: bool,
181    n_jobs: int = 1,
182) -> np.ndarray:
183    """Compute a segmentation by clustering pixel emeddings with HDBSCAN.
184
185    Args:
186        embeddings: The pixel embeddings.
187        min_size: The minimal segment size.
188        eps: Epsilon factor for HDBSCAN.
189        remove_largest: Whether to remove the largest (=background) object.
190        n_jobs: The number of jobs for parallelizing HDBSCAN.
191
192    Returns:
193        The segmentation.
194    """
195    assert hdbscan is not None, "Needs hdbscan library"
196    with threadpool_limits(limits=n_jobs):
197        clustering = hdbscan.HDBSCAN(
198            min_cluster_size=min_size, cluster_selection_epsilon=eps, core_dist_n_jobs=n_jobs
199        )
200        result = _cluster(embeddings, clustering, remove_largest=remove_largest).astype("uint64")
201    return result

Compute a segmentation by clustering pixel embeddings with HDBSCAN.

Arguments:
  • embeddings: The pixel embeddings.
  • min_size: The minimal segment size.
  • eps: Epsilon factor for HDBSCAN.
  • remove_largest: Whether to remove the largest (=background) object.
  • n_jobs: The number of jobs for parallelizing HDBSCAN.
Returns:

The segmentation.

def segment_mean_shift( embeddings: numpy.ndarray, bandwidth: float, n_jobs: int = 1) -> numpy.ndarray:
204def segment_mean_shift(embeddings: np.ndarray, bandwidth: float, n_jobs: int = 1) -> np.ndarray:
205    """Compute a segmentation by clustering pixel emeddings with mean shift.
206
207    Args:
208        embeddings: The pixel embeddings.
209        bandwidth: The bandwidth parameter for the mean shift algorithm.
210        n_jobs: The number of jobs for parallelizing MeanShift.
211
212    Returns:
213        The segmentation.
214    """
215    with threadpool_limits(limits=n_jobs):
216        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
217        result = _cluster(embeddings, clustering).astype("uint64")
218    return result

Compute a segmentation by clustering pixel embeddings with mean shift.

Arguments:
  • embeddings: The pixel embeddings.
  • bandwidth: The bandwidth parameter for the mean shift algorithm.
  • n_jobs: The number of jobs for parallelizing MeanShift.
Returns:

The segmentation.

def segment_consistency( embeddings1: numpy.ndarray, embeddings2: numpy.ndarray, bandwidth: float, iou_threshold: float, num_anchors: int, skip_zero: bool = True, n_jobs: int = 1) -> numpy.ndarray:
221def segment_consistency(
222    embeddings1: np.ndarray,
223    embeddings2: np.ndarray,
224    bandwidth: float,
225    iou_threshold: float,
226    num_anchors: int,
227    skip_zero: bool = True,
228    n_jobs: int = 1
229) -> np.ndarray:
230    """Compute a segmentation by clustering pixel emeddings via mean shift and consistency.
231
232    First, the segmentation is computed using mean shift. Then, for each instance in this
233    segmentation the corresponding instance mask is derived from the second set of embeddings.
234    Masks that have a low IOU with the corresponding instance mask are removed.
235
236    Args:
237        embeddings1: The first set of pixel embeddings, used for mean shift clustering.
238        embeddings2: The second set of pixel embeddings, used for consistency.
239        bandwidth: The bandwidth parameter for the mean shift algorithm.
240        iou_threshold: The threshold for consistency filtering.
241        num_anchors: The number of anchors for computing the instance masks for consistency.
242        skip_zero: Whether to skip the background label.
243        n_jobs: The number of jobs for parallelizing MeanShift.
244
245    Returns:
246        The segmentation.
247    """
248    def _iou(gt, seg):
249        epsilon = 1e-5
250        inter = (gt & seg).sum()
251        union = (gt | seg).sum()
252
253        iou = (inter + epsilon) / (union + epsilon)
254        return iou
255
256    with threadpool_limits(limits=n_jobs):
257        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
258        clusters = _cluster(embeddings1, clustering)
259
260    for label_id in np.unique(clusters):
261        if label_id == 0 and skip_zero:
262            continue
263
264        mask = clusters == label_id
265        iou_table = []
266        # FIXME: make it work for 3d
267        y, x = np.nonzero(mask)
268        for _ in range(num_anchors):
269            ind = np.random.randint(len(y))
270            # get random embedding anchor from emb-g
271            anchor_emb = embeddings2[:, y[ind], x[ind]]
272            # add necessary singleton dims
273            anchor_emb = anchor_emb[:, None, None]
274            # compute the instance mask from emb2
275            inst_mask = np.linalg.norm(embeddings2 - anchor_emb, axis=0) < bandwidth
276            iou_table.append(_iou(mask, inst_mask))
277        # choose final IoU as a median
278        final_iou = np.median(iou_table)
279
280        if final_iou < iou_threshold:
281            clusters[mask] = 0
282
283    return clusters.astype("uint64")

Compute a segmentation by clustering pixel embeddings via mean shift and consistency.

First, the segmentation is computed using mean shift. Then, for each instance in this segmentation the corresponding instance mask is derived from the second set of embeddings. Masks that have a low IOU with the corresponding instance mask are removed.

Arguments:
  • embeddings1: The first set of pixel embeddings, used for mean shift clustering.
  • embeddings2: The second set of pixel embeddings, used for consistency.
  • bandwidth: The bandwidth parameter for the mean shift algorithm.
  • iou_threshold: The threshold for consistency filtering.
  • num_anchors: The number of anchors for computing the instance masks for consistency.
  • skip_zero: Whether to skip the background label.
  • n_jobs: The number of jobs for parallelizing MeanShift.
Returns:

The segmentation.

def segment_embeddings_mws( embeddings: numpy.ndarray, distance_type: str, offsets: List[List[int]], bias: float = 0.0, strides: List[int] = None, weight_function: Optional[<built-in function callable>] = None, mask: Optional[numpy.ndarray] = None) -> numpy.ndarray:
428def segment_embeddings_mws(
429    embeddings: np.ndarray,
430    distance_type: str,
431    offsets: List[List[int]],
432    bias: float = 0.0,
433    strides: List[int] = None,
434    weight_function: Optional[callable] = None,
435    mask: Optional[np.ndarray] = None,
436) -> np.ndarray:
437    """Compute a segmentation by computing a mutex watershed based on pixel emeddings.
438
439    Args:
440        embeddings: The pixel embeddings.
441        distance_type: The distance type for deriving affinities from embeddings.
442        offsets: The affinity offsets.
443        bias: Additional bias factor to apply to the affinities.
444            This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
445        strides: The strides for sub-sampling repulsive mutex edges.
446        weight_function: Optional function for weighting the affinity values.
447        mask: Mask to ignore in the segmentation.
448
449    Returns:
450        The segmentation.
451    """
452    g, costs, mutex_uvs, mutex_costs = _embeddings_to_problem(
453        embeddings, distance_type, beta=None,
454        offsets=offsets, strides=strides,
455        weight_function=weight_function,
456        mask=mask
457    )
458    if bias > 0:
459        mutex_costs += bias
460    uvs = g.uvIds()
461    seg = mutex_watershed_clustering(uvs, mutex_uvs, costs, mutex_costs).reshape(embeddings.shape[1:])
462    if mask is not None:
463        seg = _ensure_mask_is_zero(seg, mask)
464    return seg

Compute a segmentation by computing a mutex watershed based on pixel embeddings.

Arguments:
  • embeddings: The pixel embeddings.
  • distance_type: The distance type for deriving affinities from embeddings.
  • offsets: The affinity offsets.
  • bias: Additional bias factor to apply to the affinities. This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
  • strides: The strides for sub-sampling repulsive mutex edges.
  • weight_function: Optional function for weighting the affinity values.
  • mask: Mask to ignore in the segmentation.
Returns:

The segmentation.