elf.segmentation.embeddings

  1# IMPORTANT do threadctl import first (before numpy imports)
  2from threadpoolctl import threadpool_limits
  3from typing import List, Optional
  4
  5import numpy as np
  6try:
  7    import hdbscan
  8except ImportError:
  9    hdbscan = None
 10
 11from scipy.ndimage import shift
 12from sklearn.cluster import MeanShift
 13from sklearn.decomposition import PCA
 14
 15from .features import (_region_features,
 16                       compute_grid_graph,
 17                       compute_grid_graph_affinity_features,
 18                       compute_grid_graph_image_features)
 19from .multicut import compute_edge_costs
 20from .mutex_watershed import mutex_watershed_clustering
 21
 22#
 23# utils
 24#
 25
 26
 27def embedding_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
 28    """Compute PCA of per-pixel embeddings.
 29
 30    Args:
 31        embeddings: The per-pixel embeddings.
 32        n_components: The number of PCA components.
 33        as_rgb: Whether to reshape the output so that it can be displayed as RGB image.
 34
 35    Returns:
 36        The PCA of the embeddings.
 37    """
 38    if as_rgb and n_components != 3:
 39        raise ValueError("")
 40
 41    pca = PCA(n_components=n_components)
 42    embed_dim = embeddings.shape[0]
 43    shape = embeddings.shape[1:]
 44
 45    embed_flat = embeddings.reshape(embed_dim, -1).T
 46    embed_flat = pca.fit_transform(embed_flat).T
 47    embed_flat = embed_flat.reshape((n_components,) + shape)
 48
 49    if as_rgb:
 50        embed_flat = 255 * (embed_flat - embed_flat.min()) / np.ptp(embed_flat)
 51        embed_flat = embed_flat.astype("uint8")
 52
 53    return embed_flat
 54
 55
 56def _embeddings_to_probabilities(embed1, embed2, delta, embedding_axis):
 57    probs = (2 * delta - np.linalg.norm(embed1 - embed2, axis=embedding_axis)) / (2 * delta)
 58    probs = np.maximum(probs, 0) ** 2
 59    return probs
 60
 61
 62def edge_probabilities_from_embeddings(
 63    embeddings: np.ndarray, segmentation: np.ndarray, rag, delta: float
 64) -> np.ndarray:
 65    """Derive edge probabilities from pixel embeddings.
 66
 67    Args:
 68        embeddings: The pixel embeddings.
 69        segmentation: The segmentation.
 70        rag: The region adjacency graph derived from the segmentation.
 71        delta: The delta factor used in the push force when training the embeddings.
 72
 73    Returns:
 74        The edge probabilties.
 75    """
 76    n_nodes = rag.number_of_nodes
 77    embed_dim = embeddings.shape[0]
 78
 79    segmentation = segmentation.astype("uint32")
 80    mean_embeddings = np.zeros((n_nodes, embed_dim), dtype="float32")
 81    for cid in range(embed_dim):
 82        mean_embed = _region_features(embeddings[cid], segmentation, ["mean"])["mean"]
 83        mean_embeddings[:, cid] = mean_embed
 84
 85    uv_ids = rag.uv_ids()
 86    embed_u = mean_embeddings[uv_ids[:, 0]]
 87    embed_v = mean_embeddings[uv_ids[:, 1]]
 88    edge_probabilities = 1. - _embeddings_to_probabilities(embed_u, embed_v, delta, embedding_axis=1)
 89    return edge_probabilities
 90
 91
 92# Could probably be implemented more efficiently with shift kernels instead of explicit call to shift.
 93# (or implement in C++ to save memory)
 94def embeddings_to_affinities(
 95    embeddings: np.ndarray,
 96    offsets: List[List[int]],
 97    delta: float,
 98    invert: bool = False,
 99) -> np.ndarray:
100    """Convert pixel embeddings to affinities.
101
102    Computes the affinity according to the formula
103    a_ij = max((2 * delta - ||x_i - x_j||) / 2 * delta, 0) ** 2,
104    where delta is the push force used in training the embeddings.
105    Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction":
106    https://arxiv.org/pdf/1909.09872.pdf
107
108    Args:
109        embeddings: The pixel embeddings.
110        offsets: The offset vectors for which to compute affinities.
111        delta: The delta factor used in the push force when training the embeddings.
112        invert: Whether to invert the affinites.
113
114    Returns:
115        The affinity values.
116    """
117    ndim = embeddings.ndim - 1
118    if not all(len(off) == ndim for off in offsets):
119        raise ValueError("Incosistent dimension of offsets and embeddings")
120
121    n_channels = len(offsets)
122    shape = embeddings.shape[1:]
123    affinities = np.zeros((n_channels,) + shape, dtype="float32")
124
125    for cid, off in enumerate(offsets):
126        # we need to  shift in the other direction in order to
127        # get the correct offset
128        # also, we need to add a zero shift in the first axis
129        shift_off = [0] + [-o for o in off]
130        # we could also shift via np.pad and slicing
131        shifted = shift(embeddings, shift_off, order=0, prefilter=False)
132        affs = _embeddings_to_probabilities(embeddings, shifted, delta, embedding_axis=0)
133        affinities[cid] = affs
134
135    if invert:
136        affinities = 1. - affinities
137
138    return affinities
139
140
141#
142# density based segmentation
143#
144
145
146def _cluster(embeddings, clustering_alg, semantic_mask=None, remove_largest=False):
147    output_shape = embeddings.shape[1:]
148    # reshape (E, D, H, W) -> (E, D * H * W) and transpose -> (D * H * W, E)
149    flattened_embeddings = embeddings.reshape(embeddings.shape[0], -1).transpose()
150
151    result = np.zeros(flattened_embeddings.shape[0])
152
153    if semantic_mask is not None:
154        flattened_mask = semantic_mask.reshape(-1)
155        assert flattened_mask.shape[0] == flattened_embeddings.shape[0]
156    else:
157        flattened_mask = np.ones(flattened_embeddings.shape[0])
158
159    if flattened_mask.sum() == 0:
160        # return zeros for empty masks
161        return result.reshape(output_shape)
162
163    # cluster only within the foreground mask
164    clusters = clustering_alg.fit_predict(flattened_embeddings[flattened_mask == 1])
165    # always increase the labels by 1 cause clustering results start from 0 and we may loose one object
166    result[flattened_mask == 1] = clusters + 1
167
168    if remove_largest:
169        # set largest object to 0-label
170        ids, counts = np.unique(result, return_counts=True)
171        result[ids[np.argmax(counts)] == result] = 0
172
173    return result.reshape(output_shape)
174
175
def segment_hdbscan(
    embeddings: np.ndarray,
    min_size: int, eps: float,
    remove_largest: bool,
    n_jobs: int = 1,
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with HDBSCAN.

    Args:
        embeddings: The pixel embeddings.
        min_size: The minimal segment size.
        eps: Epsilon factor for HDBSCAN.
        remove_largest: Whether to remove the largest (=background) object.
        n_jobs: The number of jobs for parallelizing HDBSCAN.

    Returns:
        The segmentation.

    Raises:
        ImportError: If the optional hdbscan library is not installed.
    """
    # Explicit check instead of an assert, so that the error is also raised when running with -O.
    if hdbscan is None:
        raise ImportError("Needs hdbscan library")
    with threadpool_limits(limits=n_jobs):
        clustering = hdbscan.HDBSCAN(
            min_cluster_size=min_size, cluster_selection_epsilon=eps, core_dist_n_jobs=n_jobs
        )
        result = _cluster(embeddings, clustering, remove_largest=remove_largest).astype("uint64")
    return result
201
202
def segment_mean_shift(embeddings: np.ndarray, bandwidth: float, n_jobs: int = 1) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with mean shift.

    Args:
        embeddings: The pixel embeddings.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    # The thread pools are limited to n_jobs while the clustering runs.
    with threadpool_limits(limits=n_jobs):
        mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        labels = _cluster(embeddings, mean_shift)
        segmentation = labels.astype("uint64")
    return segmentation
218
219
def segment_consistency(
    embeddings1: np.ndarray,
    embeddings2: np.ndarray,
    bandwidth: float,
    iou_threshold: float,
    num_anchors: int,
    skip_zero: bool = True,
    n_jobs: int = 1
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings via mean shift and consistency.

    First, the segmentation is computed using mean shift. Then, for each instance in this
    segmentation the corresponding instance mask is derived from the second set of embeddings.
    Masks that have a low IOU with the corresponding instance mask are removed.

    Works for embeddings with any number of spatial dimensions (e.g. 2D and 3D).

    Args:
        embeddings1: The first set of pixel embeddings, used for mean shift clustering.
        embeddings2: The second set of pixel embeddings, used for consistency.
            Indexed with pixel coordinates of the clustering result, so its spatial
            shape has to match the one of `embeddings1`.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        iou_threshold: The threshold for consistency filtering.
        num_anchors: The number of anchors for computing the instance masks for consistency.
        skip_zero: Whether to skip the background label.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    def _iou(gt, seg):
        # The epsilon keeps the ratio defined when the union is empty.
        epsilon = 1e-5
        inter = (gt & seg).sum()
        union = (gt | seg).sum()
        return (inter + epsilon) / (union + epsilon)

    with threadpool_limits(limits=n_jobs):
        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        clusters = _cluster(embeddings1, clustering)

    # Shape that broadcasts a single anchor embedding against embeddings2:
    # (E, 1, ..., 1) with one singleton axis per spatial dimension.
    anchor_shape = (-1,) + (1,) * clusters.ndim

    for label_id in np.unique(clusters):
        if label_id == 0 and skip_zero:
            continue

        mask = clusters == label_id
        # One coordinate array per spatial dimension.
        coords = np.nonzero(mask)
        n_pixels = len(coords[0])

        iou_table = []
        for _ in range(num_anchors):
            ind = np.random.randint(n_pixels)
            # Pick the embedding of a random pixel inside the instance as anchor.
            anchor_index = (slice(None),) + tuple(coord[ind] for coord in coords)
            anchor_emb = embeddings2[anchor_index].reshape(anchor_shape)
            # The instance mask is everything within `bandwidth` of the anchor in embedding space.
            inst_mask = np.linalg.norm(embeddings2 - anchor_emb, axis=0) < bandwidth
            iou_table.append(_iou(mask, inst_mask))
        # The final IoU is the median over all anchors.
        final_iou = np.median(iou_table)

        if final_iou < iou_threshold:
            clusters[mask] = 0

    return clusters.astype("uint64")
283
284
285#
286# affinity based segmentation
287#
288
289
290def _ensure_mask_is_zero(seg, mask):
291    inv_mask = ~mask
292    mask_id = seg[inv_mask][0]
293    if mask_id == 0:
294        return seg
295
296    seg_ids = np.unique(seg[mask])
297    if 0 in seg_ids:
298        seg[seg == 0] = mask_id
299    seg[inv_mask] = 0
300
301    return seg
302
303
304def _get_lr_offsets(offsets):
305    lr_offsets = [
306        off for off in offsets if np.sum(np.abs(off)) > 1
307    ]
308    return lr_offsets
309
310
311def _apply_mask(mask, g, weights, lr_edges, lr_weights):
312    assert np.dtype(mask.dtype) == np.dtype("bool")
313    # bic.graph.GridGraph2D/3D uses row-major node ids (np.arange(prod(shape)).reshape(shape)),
314    # so we rebuild that without a per-graph projection method.
315    shape = tuple(g.shape)
316    node_ids = np.arange(np.prod(shape), dtype="uint64").reshape(shape)
317    assert node_ids.shape == mask.shape == shape, f"{node_ids.shape}, {mask.shape}, {shape}"
318    masked_ids = node_ids[~mask]
319
320    # local edges:
321    # - set edges that connect masked nodes to max attractive
322    # - set edges that connect masked and non-masked nodes to max repulsive
323    local_edge_state = np.isin(g.uv_ids(), masked_ids).sum(axis=1)
324    local_masked_edges = local_edge_state == 2
325    local_transition_edges = local_edge_state == 1
326    weights[local_masked_edges] = 0.0
327    weights[local_transition_edges] = 1.0
328
329    # lr edges:
330    # - remove edges that connect masked nodes
331    # - set all edges that connect masked and non-masked nodes to max repulsive
332    lr_edge_state = np.isin(lr_edges, masked_ids).sum(axis=1)
333    lr_keep_edges = lr_edge_state != 2
334
335    lr_edges, lr_weights, lr_edge_state = (lr_edges[lr_keep_edges],
336                                           lr_weights[lr_keep_edges],
337                                           lr_edge_state[lr_keep_edges])
338    lr_transition_edges = lr_edge_state == 1
339    lr_weights[lr_transition_edges] = 1.0
340
341    return weights, lr_edges, lr_weights
342
343
344# weight functions may normalize the weight values based on some statistics
345# calculated for all weights. It's important to apply this weighting on a per offset channel
346# basis, because long-range weights may be much larger than the short range weights.
def _process_weights(g, edges, weights, weight_function, beta,
                     offsets=None, strides=None, randomize_strides=None):
    # Optionally apply `weight_function` per offset channel, then optionally convert
    # the weights to costs via `compute_edge_costs`. Returns the (edges, weights) pair.
    if weight_function is not None:
        # Map the edge ids to pixel layout with one channel per offset.
        # Without offsets the local edges of the grid graph are used.
        if offsets is None:
            pixel_edge_ids = g.project_edge_ids_to_pixels()
        else:
            pixel_edge_ids, _ = g.project_edge_ids_to_pixels_with_offsets(np.asarray(offsets))

        # Entries marked with -1 have no edge. They are pointed to edge 0 so that
        # the lookup is valid, and their weight is reset to 0 afterwards.
        no_edge = pixel_edge_ids == -1
        pixel_edge_ids[no_edge] = 0
        pixel_weights = weights[pixel_edge_ids]
        pixel_weights[no_edge] = 0

        # The weight function is applied separately to each channel.
        for channel, channel_weights in enumerate(pixel_weights):
            pixel_weights[channel] = weight_function(channel_weights)

        # Map the per-pixel weights back to edge weights.
        if offsets is None:
            edges, weights = compute_grid_graph_affinity_features(g, pixel_weights)
            assert len(weights) == g.number_of_edges
        else:
            edges, weights = compute_grid_graph_affinity_features(
                g, pixel_weights, offsets=offsets,
                strides=strides, randomize_strides=randomize_strides,
            )

    if beta is not None:
        weights = compute_edge_costs(weights, beta=beta)

    return edges, weights
388
389
def _embeddings_to_problem(embed, distance_type, beta=None,
                           offsets=None, strides=None, weight_function=None,
                           mask=None):
    # Build the grid graph for the spatial shape of the embeddings and derive the edge weights.
    # Returns (graph, weights) without offsets, otherwise (graph, weights, lr_edges, lr_weights).
    grid = compute_grid_graph(embed.shape[1:])
    _, local_weights = compute_grid_graph_image_features(grid, embed, distance_type)
    _, local_weights = _process_weights(grid, None, local_weights, weight_function, beta)
    if offsets is None:
        return grid, local_weights

    long_range_offsets = _get_lr_offsets(offsets)

    # Strides are only used for the feature computation if no weight function is given.
    # Otherwise they are applied later, in _process_weights.
    if weight_function is None:
        feature_strides, randomize = strides, True
    else:
        feature_strides, randomize = None, False

    lr_edges, lr_weights = compute_grid_graph_image_features(
        grid, embed, distance_type,
        offsets=long_range_offsets, strides=feature_strides, randomize_strides=randomize,
    )

    if mask is not None:
        local_weights, lr_edges, lr_weights = _apply_mask(mask, grid, local_weights, lr_edges, lr_weights)

    lr_edges, lr_weights = _process_weights(
        grid, lr_edges, lr_weights, weight_function, beta,
        offsets=long_range_offsets, strides=strides, randomize_strides=randomize,
    )
    return grid, local_weights, lr_edges, lr_weights
416
417
418# weight function based on the seung paper, using the push delta
419# of the discriminative loss term.
def discriminative_loss_weight(dist, delta):
    """@private
    """
    # Similarity falls off linearly from 1 at distance 0 to 0 at distance 2 * delta,
    # and is clipped at 0 beyond that.
    similarity = np.maximum((2 * delta - dist) / (2 * delta), 0)
    # The squared similarity is inverted, so that small distances give small weights.
    return 1. - similarity ** 2
426
427
def segment_embeddings_mws(
    embeddings: np.ndarray,
    distance_type: str,
    offsets: List[List[int]],
    bias: float = 0.0,
    strides: Optional[List[int]] = None,
    weight_function: Optional[callable] = None,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Compute a segmentation by computing a mutex watershed based on pixel embeddings.

    Args:
        embeddings: The pixel embeddings.
        distance_type: The distance type for deriving affinities from embeddings.
        offsets: The affinity offsets.
        bias: Additional bias factor that is added to the costs of the mutex edges.
            This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
        strides: The strides for sub-sampling repulsive mutex edges.
        weight_function: Optional function for weighting the affinity values.
        mask: Mask to ignore in the segmentation.

    Returns:
        The segmentation.
    """
    g, costs, mutex_uvs, mutex_costs = _embeddings_to_problem(
        embeddings, distance_type, beta=None,
        offsets=offsets, strides=strides,
        weight_function=weight_function,
        mask=mask
    )
    # Apply any non-zero bias: negative values are documented as valid,
    # so they must not be silently dropped.
    if bias != 0:
        mutex_costs += bias
    uvs = g.uv_ids()
    seg = mutex_watershed_clustering(uvs, mutex_uvs, costs, mutex_costs).reshape(embeddings.shape[1:])
    if mask is not None:
        seg = _ensure_mask_is_zero(seg, mask)
    return seg
def embedding_pca( embeddings: numpy.ndarray, n_components: int = 3, as_rgb: bool = True) -> numpy.ndarray:
def embedding_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
    """Compute PCA of per-pixel embeddings.

    Args:
        embeddings: The per-pixel embeddings, with the embedding dimension as first axis.
        n_components: The number of PCA components.
        as_rgb: Whether to rescale the output to the range [0, 255] and cast it to uint8,
            so that it can be displayed as RGB image. Requires `n_components == 3`.

    Returns:
        The PCA of the embeddings, with shape `(n_components,) + embeddings.shape[1:]`.

    Raises:
        ValueError: If `as_rgb` is set and `n_components` is not 3.
    """
    if as_rgb and n_components != 3:
        raise ValueError(f"as_rgb=True requires n_components=3, got n_components={n_components}.")

    pca = PCA(n_components=n_components)
    embed_dim = embeddings.shape[0]
    shape = embeddings.shape[1:]

    # PCA expects one sample per row: (n_pixels, embed_dim).
    embed_flat = embeddings.reshape(embed_dim, -1).T
    embed_flat = pca.fit_transform(embed_flat).T
    embed_flat = embed_flat.reshape((n_components,) + shape)

    if as_rgb:
        value_range = np.ptp(embed_flat)
        if value_range > 0:
            embed_flat = 255 * (embed_flat - embed_flat.min()) / value_range
        else:
            # Constant projection: dividing by a zero range would produce NaNs, so map to black instead.
            embed_flat = np.zeros_like(embed_flat)
        embed_flat = embed_flat.astype("uint8")

    return embed_flat

Compute PCA of per-pixel embeddings.

Arguments:
  • embeddings: The per-pixel embeddings.
  • n_components: The number of PCA components.
  • as_rgb: Whether to reshape the output so that it can be displayed as RGB image.
Returns:

The PCA of the embeddings.

def edge_probabilities_from_embeddings( embeddings: numpy.ndarray, segmentation: numpy.ndarray, rag, delta: float) -> numpy.ndarray:
def edge_probabilities_from_embeddings(
    embeddings: np.ndarray, segmentation: np.ndarray, rag, delta: float
) -> np.ndarray:
    """Derive edge probabilities from pixel embeddings.

    The embeddings are averaged per segment. The probability of each edge of the region adjacency
    graph is then computed from the distance between the mean embeddings of its two segments.

    Args:
        embeddings: The pixel embeddings.
        segmentation: The segmentation.
        rag: The region adjacency graph derived from the segmentation.
        delta: The delta factor used in the push force when training the embeddings.

    Returns:
        The edge probabilities.
    """
    labels = segmentation.astype("uint32")
    n_channels = embeddings.shape[0]

    # One row per graph node, one column per embedding channel.
    node_embeddings = np.zeros((rag.number_of_nodes, n_channels), dtype="float32")
    for channel in range(n_channels):
        node_embeddings[:, channel] = _region_features(embeddings[channel], labels, ["mean"])["mean"]

    edges = rag.uv_ids()
    similarity = _embeddings_to_probabilities(
        node_embeddings[edges[:, 0]], node_embeddings[edges[:, 1]], delta, embedding_axis=1
    )
    # Similar mean embeddings give a low edge probability.
    return 1. - similarity

Derive edge probabilities from pixel embeddings.

Arguments:
  • embeddings: The pixel embeddings.
  • segmentation: The segmentation.
  • rag: The region adjacency graph derived from the segmentation.
  • delta: The delta factor used in the push force when training the embeddings.
Returns:

The edge probabilities.

def embeddings_to_affinities( embeddings: numpy.ndarray, offsets: List[List[int]], delta: float, invert: bool = False) -> numpy.ndarray:
 95def embeddings_to_affinities(
 96    embeddings: np.ndarray,
 97    offsets: List[List[int]],
 98    delta: float,
 99    invert: bool = False,
100) -> np.ndarray:
101    """Convert pixel embeddings to affinities.
102
103    Computes the affinity according to the formula
104    a_ij = max((2 * delta - ||x_i - x_j||) / 2 * delta, 0) ** 2,
105    where delta is the push force used in training the embeddings.
106    Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction":
107    https://arxiv.org/pdf/1909.09872.pdf
108
109    Args:
110        embeddings: The pixel embeddings.
111        offsets: The offset vectors for which to compute affinities.
112        delta: The delta factor used in the push force when training the embeddings.
113        invert: Whether to invert the affinites.
114
115    Returns:
116        The affinity values.
117    """
118    ndim = embeddings.ndim - 1
119    if not all(len(off) == ndim for off in offsets):
120        raise ValueError("Incosistent dimension of offsets and embeddings")
121
122    n_channels = len(offsets)
123    shape = embeddings.shape[1:]
124    affinities = np.zeros((n_channels,) + shape, dtype="float32")
125
126    for cid, off in enumerate(offsets):
127        # we need to  shift in the other direction in order to
128        # get the correct offset
129        # also, we need to add a zero shift in the first axis
130        shift_off = [0] + [-o for o in off]
131        # we could also shift via np.pad and slicing
132        shifted = shift(embeddings, shift_off, order=0, prefilter=False)
133        affs = _embeddings_to_probabilities(embeddings, shifted, delta, embedding_axis=0)
134        affinities[cid] = affs
135
136    if invert:
137        affinities = 1. - affinities
138
139    return affinities

Convert pixel embeddings to affinities.

Computes the affinity according to the formula a_ij = max((2 * delta - ||x_i - x_j||) / (2 * delta), 0) ** 2, where delta is the push force used in training the embeddings. Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction": https://arxiv.org/pdf/1909.09872.pdf

Arguments:
  • embeddings: The pixel embeddings.
  • offsets: The offset vectors for which to compute affinities.
  • delta: The delta factor used in the push force when training the embeddings.
  • invert: Whether to invert the affinities.
Returns:

The affinity values.

def segment_hdbscan( embeddings: numpy.ndarray, min_size: int, eps: float, remove_largest: bool, n_jobs: int = 1) -> numpy.ndarray:
def segment_hdbscan(
    embeddings: np.ndarray,
    min_size: int, eps: float,
    remove_largest: bool,
    n_jobs: int = 1,
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with HDBSCAN.

    Args:
        embeddings: The pixel embeddings.
        min_size: The minimal segment size.
        eps: Epsilon factor for HDBSCAN.
        remove_largest: Whether to remove the largest (=background) object.
        n_jobs: The number of jobs for parallelizing HDBSCAN.

    Returns:
        The segmentation.

    Raises:
        ImportError: If the optional hdbscan library is not installed.
    """
    # Explicit check instead of an assert, so that the error is also raised when running with -O.
    if hdbscan is None:
        raise ImportError("Needs hdbscan library")
    with threadpool_limits(limits=n_jobs):
        clustering = hdbscan.HDBSCAN(
            min_cluster_size=min_size, cluster_selection_epsilon=eps, core_dist_n_jobs=n_jobs
        )
        result = _cluster(embeddings, clustering, remove_largest=remove_largest).astype("uint64")
    return result

Compute a segmentation by clustering pixel embeddings with HDBSCAN.

Arguments:
  • embeddings: The pixel embeddings.
  • min_size: The minimal segment size.
  • eps: Epsilon factor for HDBSCAN.
  • remove_largest: Whether to remove the largest (=background) object.
  • n_jobs: The number of jobs for parallelizing HDBSCAN.
Returns:

The segmentation.

def segment_mean_shift( embeddings: numpy.ndarray, bandwidth: float, n_jobs: int = 1) -> numpy.ndarray:
def segment_mean_shift(embeddings: np.ndarray, bandwidth: float, n_jobs: int = 1) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with mean shift.

    Args:
        embeddings: The pixel embeddings.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    # The thread pools are limited to n_jobs while the clustering runs.
    with threadpool_limits(limits=n_jobs):
        mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        labels = _cluster(embeddings, mean_shift)
        segmentation = labels.astype("uint64")
    return segmentation

Compute a segmentation by clustering pixel embeddings with mean shift.

Arguments:
  • embeddings: The pixel embeddings.
  • bandwidth: The bandwidth parameter for the mean shift algorithm.
  • n_jobs: The number of jobs for parallelizing MeanShift.
Returns:

The segmentation.

def segment_consistency( embeddings1: numpy.ndarray, embeddings2: numpy.ndarray, bandwidth: float, iou_threshold: float, num_anchors: int, skip_zero: bool = True, n_jobs: int = 1) -> numpy.ndarray:
def segment_consistency(
    embeddings1: np.ndarray,
    embeddings2: np.ndarray,
    bandwidth: float,
    iou_threshold: float,
    num_anchors: int,
    skip_zero: bool = True,
    n_jobs: int = 1
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings via mean shift and consistency.

    First, the segmentation is computed using mean shift. Then, for each instance in this
    segmentation the corresponding instance mask is derived from the second set of embeddings.
    Masks that have a low IOU with the corresponding instance mask are removed.

    Works for embeddings with any number of spatial dimensions (e.g. 2D and 3D).

    Args:
        embeddings1: The first set of pixel embeddings, used for mean shift clustering.
        embeddings2: The second set of pixel embeddings, used for consistency.
            Indexed with pixel coordinates of the clustering result, so its spatial
            shape has to match the one of `embeddings1`.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        iou_threshold: The threshold for consistency filtering.
        num_anchors: The number of anchors for computing the instance masks for consistency.
        skip_zero: Whether to skip the background label.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    def _iou(gt, seg):
        # The epsilon keeps the ratio defined when the union is empty.
        epsilon = 1e-5
        inter = (gt & seg).sum()
        union = (gt | seg).sum()
        return (inter + epsilon) / (union + epsilon)

    with threadpool_limits(limits=n_jobs):
        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        clusters = _cluster(embeddings1, clustering)

    # Shape that broadcasts a single anchor embedding against embeddings2:
    # (E, 1, ..., 1) with one singleton axis per spatial dimension.
    anchor_shape = (-1,) + (1,) * clusters.ndim

    for label_id in np.unique(clusters):
        if label_id == 0 and skip_zero:
            continue

        mask = clusters == label_id
        # One coordinate array per spatial dimension.
        coords = np.nonzero(mask)
        n_pixels = len(coords[0])

        iou_table = []
        for _ in range(num_anchors):
            ind = np.random.randint(n_pixels)
            # Pick the embedding of a random pixel inside the instance as anchor.
            anchor_index = (slice(None),) + tuple(coord[ind] for coord in coords)
            anchor_emb = embeddings2[anchor_index].reshape(anchor_shape)
            # The instance mask is everything within `bandwidth` of the anchor in embedding space.
            inst_mask = np.linalg.norm(embeddings2 - anchor_emb, axis=0) < bandwidth
            iou_table.append(_iou(mask, inst_mask))
        # The final IoU is the median over all anchors.
        final_iou = np.median(iou_table)

        if final_iou < iou_threshold:
            clusters[mask] = 0

    return clusters.astype("uint64")

Compute a segmentation by clustering pixel embeddings via mean shift and consistency.

First, the segmentation is computed using mean shift. Then, for each instance in this segmentation the corresponding instance mask is derived from the second set of embeddings. Masks that have a low IOU with the corresponding instance mask are removed.

Arguments:
  • embeddings1: The first set of pixel embeddings, used for mean shift clustering.
  • embeddings2: The second set of pixel embeddings, used for consistency.
  • bandwidth: The bandwidth parameter for the mean shift algorithm.
  • iou_threshold: The threshold for consistency filtering.
  • num_anchors: The number of anchors for computing the instance masks for consistency.
  • skip_zero: Whether to skip the background label.
  • n_jobs: The number of jobs for parallelizing MeanShift.
Returns:

The segmentation.

def segment_embeddings_mws( embeddings: numpy.ndarray, distance_type: str, offsets: List[List[int]], bias: float = 0.0, strides: List[int] = None, weight_function: callable | None = None, mask: numpy.ndarray | None = None) -> numpy.ndarray:
def segment_embeddings_mws(
    embeddings: np.ndarray,
    distance_type: str,
    offsets: List[List[int]],
    bias: float = 0.0,
    strides: Optional[List[int]] = None,
    weight_function: Optional[callable] = None,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Compute a segmentation by computing a mutex watershed based on pixel embeddings.

    Args:
        embeddings: The pixel embeddings.
        distance_type: The distance type for deriving affinities from embeddings.
        offsets: The affinity offsets.
        bias: Additional bias factor that is added to the costs of the mutex edges.
            This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
        strides: The strides for sub-sampling repulsive mutex edges.
        weight_function: Optional function for weighting the affinity values.
        mask: Mask to ignore in the segmentation.

    Returns:
        The segmentation.
    """
    g, costs, mutex_uvs, mutex_costs = _embeddings_to_problem(
        embeddings, distance_type, beta=None,
        offsets=offsets, strides=strides,
        weight_function=weight_function,
        mask=mask
    )
    # Apply any non-zero bias: negative values are documented as valid,
    # so they must not be silently dropped.
    if bias != 0:
        mutex_costs += bias
    uvs = g.uv_ids()
    seg = mutex_watershed_clustering(uvs, mutex_uvs, costs, mutex_costs).reshape(embeddings.shape[1:])
    if mask is not None:
        seg = _ensure_mask_is_zero(seg, mask)
    return seg

Compute a segmentation by computing a mutex watershed based on pixel embeddings.

Arguments:
  • embeddings: The pixel embeddings.
  • distance_type: The distance type for deriving affinities from embeddings.
  • offsets: The affinity offsets.
  • bias: Additional bias factor to apply to the affinities. This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
  • strides: The strides for sub-sampling repulsive mutex edges.
  • weight_function: Optional function for weighting the affinity values.
  • mask: Mask to ignore in the segmentation.
Returns:

The segmentation.