elf.segmentation.embeddings
# IMPORTANT: do the threadpoolctl import first (before numpy imports)
from threadpoolctl import threadpool_limits
from typing import List, Optional

import numpy as np
import vigra
try:
    import hdbscan
except ImportError:
    hdbscan = None

from scipy.ndimage import shift
from sklearn.cluster import MeanShift
from sklearn.decomposition import PCA

from .features import (compute_grid_graph,
                       compute_grid_graph_affinity_features,
                       compute_grid_graph_image_features)
from .multicut import compute_edge_costs
from .mutex_watershed import mutex_watershed_clustering

#
# utils
#


def embedding_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
    """Compute PCA of per-pixel embeddings.

    Args:
        embeddings: The per-pixel embeddings, with the embedding dimension as first axis.
        n_components: The number of PCA components.
        as_rgb: Whether to rescale the output to uint8 values in [0, 255],
            so that it can be displayed as RGB image. Requires `n_components == 3`.

    Returns:
        The PCA of the embeddings, with shape `(n_components,) + embeddings.shape[1:]`.

    Raises:
        ValueError: If `as_rgb` is set and `n_components` is not 3.
    """
    if as_rgb and n_components != 3:
        raise ValueError(
            f"Expected n_components=3 for RGB output (as_rgb=True), got n_components={n_components}."
        )

    embed_dim = embeddings.shape[0]
    spatial_shape = embeddings.shape[1:]

    # PCA expects samples along the first axis: (n_pixels, embed_dim).
    pca = PCA(n_components=n_components)
    flat = embeddings.reshape(embed_dim, -1).T
    projected = pca.fit_transform(flat).T.reshape((n_components,) + spatial_shape)

    if as_rgb:
        value_range = np.ptp(projected)
        if value_range > 0:
            projected = 255 * (projected - projected.min()) / value_range
        else:
            # Constant projection: avoid dividing by zero, return an all-black image.
            projected = np.zeros_like(projected)
        projected = projected.astype("uint8")

    return projected


def _embeddings_to_probabilities(embed1, embed2, delta, embedding_axis):
    """Map embedding distances to values in [0, 1].

    Computes max((2 * delta - ||embed1 - embed2||) / (2 * delta), 0) ** 2,
    where the norm is taken along `embedding_axis`.
    """
    probs = (2 * delta - np.linalg.norm(embed1 - embed2, axis=embedding_axis)) / (2 * delta)
    probs = np.maximum(probs, 0) ** 2
    return probs


def edge_probabilities_from_embeddings(
    embeddings: np.ndarray, segmentation: np.ndarray, rag, delta: float
) -> np.ndarray:
    """Derive edge probabilities from pixel embeddings.

    Args:
        embeddings: The pixel embeddings.
        segmentation: The segmentation.
        rag: The region adjacency graph derived from the segmentation.
        delta: The delta factor used in the push force when training the embeddings.

    Returns:
        The edge probabilities.
    """
    n_nodes = rag.numberOfNodes
    embed_dim = embeddings.shape[0]

    # Accumulate the mean embedding of each segment, one embedding channel at a time.
    segmentation = segmentation.astype("uint32")
    mean_embeddings = np.zeros((n_nodes, embed_dim), dtype="float32")
    for cid in range(embed_dim):
        mean_embed = vigra.analysis.extractRegionFeatures(embeddings[cid], segmentation, features=["mean"])["mean"]
        mean_embeddings[:, cid] = mean_embed

    # Similar mean embeddings give a value close to 1 from the helper; the result is its complement.
    uv_ids = rag.uvIds()
    embed_u = mean_embeddings[uv_ids[:, 0]]
    embed_v = mean_embeddings[uv_ids[:, 1]]
    edge_probabilities = 1. - _embeddings_to_probabilities(embed_u, embed_v, delta, embedding_axis=1)
    return edge_probabilities


# Could probably be implemented more efficiently with shift kernels instead of explicit call to shift.
# (or implement in C++ to save memory)
def embeddings_to_affinities(
    embeddings: np.ndarray,
    offsets: List[List[int]],
    delta: float,
    invert: bool = False,
) -> np.ndarray:
    """Convert pixel embeddings to affinities.

    Computes the affinity according to the formula
    a_ij = max((2 * delta - ||x_i - x_j||) / (2 * delta), 0) ** 2,
    where delta is the push force used in training the embeddings.
    Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction":
    https://arxiv.org/pdf/1909.09872.pdf

    Args:
        embeddings: The pixel embeddings, with the embedding dimension as first axis.
        offsets: The offset vectors for which to compute affinities.
        delta: The delta factor used in the push force when training the embeddings.
        invert: Whether to invert the affinities.

    Returns:
        The affinity values, with one channel per offset.

    Raises:
        ValueError: If an offset does not have one entry per spatial dimension of the embeddings.
    """
    ndim = embeddings.ndim - 1
    if not all(len(off) == ndim for off in offsets):
        raise ValueError("Inconsistent dimension of offsets and embeddings")

    n_channels = len(offsets)
    shape = embeddings.shape[1:]
    affinities = np.zeros((n_channels,) + shape, dtype="float32")

    for cid, off in enumerate(offsets):
        # we need to shift in the other direction in order to
        # get the correct offset
        # also, we need to add a zero shift in the first axis
        shift_off = [0] + [-o for o in off]
        # we could also shift via np.pad and slicing
        shifted = shift(embeddings, shift_off, order=0, prefilter=False)
        affinities[cid] = _embeddings_to_probabilities(embeddings, shifted, delta, embedding_axis=0)

    if invert:
        affinities = 1. - affinities

    return affinities


#
# density based segmentation
#


def _cluster(embeddings, clustering_alg, semantic_mask=None, remove_largest=False):
    """Cluster the pixel embeddings with `clustering_alg` and return the labels in image shape.

    Labels returned by `clustering_alg.fit_predict` are increased by one.
    Pixels outside of `semantic_mask` (if given) keep the label 0.
    """
    output_shape = embeddings.shape[1:]
    # reshape (E, D, H, W) -> (E, D * H * W) and transpose -> (D * H * W, E)
    flattened_embeddings = embeddings.reshape(embeddings.shape[0], -1).transpose()

    result = np.zeros(flattened_embeddings.shape[0])

    if semantic_mask is not None:
        flattened_mask = semantic_mask.reshape(-1)
        assert flattened_mask.shape[0] == flattened_embeddings.shape[0]
    else:
        flattened_mask = np.ones(flattened_embeddings.shape[0])

    if flattened_mask.sum() == 0:
        # return zeros for empty masks
        return result.reshape(output_shape)

    # cluster only within the foreground mask
    foreground = flattened_mask == 1
    clusters = clustering_alg.fit_predict(flattened_embeddings[foreground])
    # always increase the labels by 1 because clustering results start from 0 and we may lose one object
    result[foreground] = clusters + 1

    if remove_largest:
        # set largest object to 0-label
        ids, counts = np.unique(result, return_counts=True)
        result[ids[np.argmax(counts)] == result] = 0

    return result.reshape(output_shape)


def segment_hdbscan(
    embeddings: np.ndarray,
    min_size: int, eps: float,
    remove_largest: bool,
    n_jobs: int = 1,
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with HDBSCAN.

    Args:
        embeddings: The pixel embeddings.
        min_size: The minimal segment size.
        eps: Epsilon factor for HDBSCAN.
        remove_largest: Whether to remove the largest (=background) object.
        n_jobs: The number of jobs for parallelizing HDBSCAN.

    Returns:
        The segmentation.
    """
    assert hdbscan is not None, "Needs hdbscan library"
    with threadpool_limits(limits=n_jobs):
        clustering = hdbscan.HDBSCAN(
            min_cluster_size=min_size, cluster_selection_epsilon=eps, core_dist_n_jobs=n_jobs
        )
        result = _cluster(embeddings, clustering, remove_largest=remove_largest).astype("uint64")
    return result


def segment_mean_shift(embeddings: np.ndarray, bandwidth: float, n_jobs: int = 1) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with mean shift.

    Args:
        embeddings: The pixel embeddings.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    with threadpool_limits(limits=n_jobs):
        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        result = _cluster(embeddings, clustering).astype("uint64")
    return result


def segment_consistency(
    embeddings1: np.ndarray,
    embeddings2: np.ndarray,
    bandwidth: float,
    iou_threshold: float,
    num_anchors: int,
    skip_zero: bool = True,
    n_jobs: int = 1
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings via mean shift and consistency.

    First, the segmentation is computed using mean shift. Then, for each instance in this
    segmentation the corresponding instance mask is derived from the second set of embeddings.
    Masks that have a low IOU with the corresponding instance mask are removed.

    Args:
        embeddings1: The first set of pixel embeddings, used for mean shift clustering.
        embeddings2: The second set of pixel embeddings, used for consistency.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        iou_threshold: The threshold for consistency filtering.
        num_anchors: The number of anchors for computing the instance masks for consistency.
        skip_zero: Whether to skip the background label.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    def _iou(gt, seg):
        # The epsilon keeps the ratio defined for an empty union.
        epsilon = 1e-5
        inter = (gt & seg).sum()
        union = (gt | seg).sum()
        return (inter + epsilon) / (union + epsilon)

    with threadpool_limits(limits=n_jobs):
        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        clusters = _cluster(embeddings1, clustering)

    # Index that appends one singleton axis per spatial dimension to an anchor embedding.
    expand_spatial = (slice(None),) + (None,) * clusters.ndim

    for label_id in np.unique(clusters):
        if label_id == 0 and skip_zero:
            continue

        mask = clusters == label_id
        iou_table = []
        # One coordinate array per spatial dimension, so this works for 2d and 3d data alike.
        coords = np.nonzero(mask)
        n_pixels = len(coords[0])
        for _ in range(num_anchors):
            ind = np.random.randint(n_pixels)
            # get random embedding anchor from the second set of embeddings
            anchor_emb = embeddings2[(slice(None),) + tuple(coord[ind] for coord in coords)]
            # add necessary singleton dims
            anchor_emb = anchor_emb[expand_spatial]
            # compute the instance mask from emb2
            inst_mask = np.linalg.norm(embeddings2 - anchor_emb, axis=0) < bandwidth
            iou_table.append(_iou(mask, inst_mask))
        # choose final IoU as a median
        final_iou = np.median(iou_table)

        if final_iou < iou_threshold:
            clusters[mask] = 0

    return clusters.astype("uint64")


#
# affinity based segmentation
#


def _ensure_mask_is_zero(seg, mask):
    """Make sure that the pixels outside of `mask` carry the label 0 (modifies `seg` in place)."""
    inv_mask = ~mask
    if not inv_mask.any():
        # Nothing is masked out, so there is no mask label to reset.
        return seg

    mask_id = seg[inv_mask][0]
    if mask_id == 0:
        return seg

    # If the label 0 is used inside of the mask, re-assign it to the id
    # that gets freed by setting the masked-out region to 0.
    seg_ids = np.unique(seg[mask])
    if 0 in seg_ids:
        seg[seg == 0] = mask_id
    seg[inv_mask] = 0

    return seg


def _get_lr_offsets(offsets):
    """Return the long-range offsets, i.e. the offsets with an absolute sum larger than 1."""
    lr_offsets = [
        off for off in offsets if np.sum(np.abs(off)) > 1
    ]
    return lr_offsets


def _apply_mask(mask, g, weights, lr_edges, lr_weights):
    """Adapt the local and long-range edge weights so that the masked-out region is respected."""
    assert np.dtype(mask.dtype) == np.dtype("bool")
    node_ids = g.projectNodeIdsToPixels()
    assert node_ids.shape == mask.shape == tuple(g.shape), f"{node_ids.shape}, {mask.shape}, {g.shape}"
    masked_ids = node_ids[~mask]

    # local edges:
    # - set edges that connect masked nodes to max attractive
    # - set edges that connect masked and non-masked nodes to max repulsive
    # (the edge state counts how many of the two edge endpoints are masked)
    local_edge_state = np.isin(g.uvIds(), masked_ids).sum(axis=1)
    local_masked_edges = local_edge_state == 2
    local_transition_edges = local_edge_state == 1
    weights[local_masked_edges] = 0.0
    weights[local_transition_edges] = 1.0

    # lr edges:
    # - remove edges that connect masked nodes
    # - set all edges that connect masked and non-masked nodes to max repulsive
    lr_edge_state = np.isin(lr_edges, masked_ids).sum(axis=1)
    lr_keep_edges = lr_edge_state != 2

    lr_edges, lr_weights, lr_edge_state = (lr_edges[lr_keep_edges],
                                           lr_weights[lr_keep_edges],
                                           lr_edge_state[lr_keep_edges])
    lr_transition_edges = lr_edge_state == 1
    lr_weights[lr_transition_edges] = 1.0

    return weights, lr_edges, lr_weights


# weight functions may normalize the weight values based on some statistics
# calculated for all weights. It's important to apply this weighting on a per offset channel
# basis, because long-range weights may be much larger than the short range weights.
def _process_weights(g, edges, weights, weight_function, beta,
                     offsets=None, strides=None, randomize_strides=None):
    """Optionally apply `weight_function` per offset channel and convert the weights to costs.

    If `weight_function` is None, `edges` is passed through unchanged.
    If `beta` is None, the weights are not converted via `compute_edge_costs`.
    """
    if weight_function is not None:
        is_local = offsets is None

        # Project the edge weights to one channel per offset; -1 marks pixels without an edge.
        if is_local:
            edge_ids = g.projectEdgeIdsToPixels()
        else:
            edge_ids = g.projectEdgeIdsToPixelsWithOffsets(offsets)
        invalid_edges = edge_ids == -1
        edge_ids[invalid_edges] = 0
        weights = weights[edge_ids]
        weights[invalid_edges] = 0

        # Apply the weight function channel by channel (see the comment above this function).
        for chan_id, weightc in enumerate(weights):
            weights[chan_id] = weight_function(weightc)

        if is_local:
            edges, weights = compute_grid_graph_affinity_features(g, weights)
            assert len(weights) == g.numberOfEdges
        else:
            edges, weights = compute_grid_graph_affinity_features(
                g, weights, offsets=offsets,
                strides=strides, randomize_strides=randomize_strides
            )

    if beta is not None:
        weights = compute_edge_costs(weights, beta=beta)

    return edges, weights


def _embeddings_to_problem(embed, distance_type, beta=None,
                           offsets=None, strides=None, weight_function=None,
                           mask=None):
    """Build the grid graph problem from pixel embeddings.

    Returns `(graph, weights)` if `offsets` is None,
    otherwise `(graph, weights, lr_edges, lr_weights)`.
    """
    im_shape = embed.shape[1:]
    g = compute_grid_graph(im_shape)
    _, weights = compute_grid_graph_image_features(g, embed, distance_type)
    _, weights = _process_weights(g, None, weights, weight_function, beta)
    if offsets is None:
        return g, weights

    lr_offsets = _get_lr_offsets(offsets)

    # we only compute with strides if we are not applying a weight function, otherwise
    # strides are applied later!
    strides_, randomize_ = (strides, True) if weight_function is None else (None, False)

    lr_edges, lr_weights = compute_grid_graph_image_features(
        g, embed, distance_type, offsets=lr_offsets, strides=strides_, randomize_strides=randomize_
    )

    if mask is not None:
        # NOTE(review): the mask removes long-range edges; if a weight function is passed as well,
        # _process_weights indexes lr_weights with the projected edge ids afterwards.
        # Confirm that these ids still match the filtered lr_weights.
        weights, lr_edges, lr_weights = _apply_mask(mask, g, weights, lr_edges, lr_weights)

    lr_edges, lr_weights = _process_weights(g, lr_edges, lr_weights, weight_function, beta, offsets=lr_offsets,
                                            strides=strides, randomize_strides=randomize_)
    return g, weights, lr_edges, lr_weights


# weight function based on the seung paper, using the push delta
# of the discriminative loss term.
def discriminative_loss_weight(dist, delta):
    """@private
    """
    # 1 - max((2 * delta - dist) / (2 * delta), 0) ** 2:
    # 0 for identical embeddings, 1 for distances of 2 * delta or more.
    dist = (2 * delta - dist) / (2 * delta)
    dist = 1. - np.maximum(dist, 0) ** 2
    return dist


def segment_embeddings_mws(
    embeddings: np.ndarray,
    distance_type: str,
    offsets: List[List[int]],
    bias: float = 0.0,
    strides: Optional[List[int]] = None,
    weight_function: Optional[callable] = None,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Compute a segmentation by computing a mutex watershed based on pixel embeddings.

    Args:
        embeddings: The pixel embeddings.
        distance_type: The distance type for deriving affinities from embeddings.
        offsets: The affinity offsets.
        bias: Additional bias factor to apply to the affinities.
            This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
        strides: The strides for sub-sampling repulsive mutex edges.
        weight_function: Optional function for weighting the affinity values.
        mask: Mask to ignore in the segmentation.

    Returns:
        The segmentation.
    """
    g, costs, mutex_uvs, mutex_costs = _embeddings_to_problem(
        embeddings, distance_type, beta=None,
        offsets=offsets, strides=strides,
        weight_function=weight_function,
        mask=mask
    )
    # Apply the bias for both signs: negative values are documented to reduce over-segmentation.
    if bias != 0:
        mutex_costs += bias
    uvs = g.uvIds()
    seg = mutex_watershed_clustering(uvs, mutex_uvs, costs, mutex_costs).reshape(embeddings.shape[1:])
    if mask is not None:
        seg = _ensure_mask_is_zero(seg, mask)
    return seg
def embedding_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
    """Compute PCA of per-pixel embeddings.

    Args:
        embeddings: The per-pixel embeddings, with the embedding dimension as first axis.
        n_components: The number of PCA components.
        as_rgb: Whether to rescale the output to uint8 values in [0, 255],
            so that it can be displayed as RGB image. Requires `n_components == 3`.

    Returns:
        The PCA of the embeddings, with shape `(n_components,) + embeddings.shape[1:]`.

    Raises:
        ValueError: If `as_rgb` is set and `n_components` is not 3.
    """
    if as_rgb and n_components != 3:
        raise ValueError(
            f"Expected n_components=3 for RGB output (as_rgb=True), got n_components={n_components}."
        )

    embed_dim = embeddings.shape[0]
    spatial_shape = embeddings.shape[1:]

    # PCA expects samples along the first axis: (n_pixels, embed_dim).
    pca = PCA(n_components=n_components)
    flat = embeddings.reshape(embed_dim, -1).T
    projected = pca.fit_transform(flat).T.reshape((n_components,) + spatial_shape)

    if as_rgb:
        value_range = np.ptp(projected)
        if value_range > 0:
            projected = 255 * (projected - projected.min()) / value_range
        else:
            # Constant projection: avoid dividing by zero, return an all-black image.
            projected = np.zeros_like(projected)
        projected = projected.astype("uint8")

    return projected
Compute PCA of per-pixel embeddings.
Arguments:
- embeddings: The per-pixel embeddings.
- n_components: The number of PCA components.
- as_rgb: Whether to reshape the output so that it can be displayed as RGB image.
Returns:
The PCA of the embeddings.
def edge_probabilities_from_embeddings(
    embeddings: np.ndarray, segmentation: np.ndarray, rag, delta: float
) -> np.ndarray:
    """Derive edge probabilities from pixel embeddings.

    The mean embedding of each segment is computed and the distance between
    the mean embeddings of the two segments of each edge is mapped to a probability.

    Args:
        embeddings: The pixel embeddings.
        segmentation: The segmentation.
        rag: The region adjacency graph derived from the segmentation.
        delta: The delta factor used in the push force when training the embeddings.

    Returns:
        The edge probabilities.
    """
    labels = segmentation.astype("uint32")
    n_channels = embeddings.shape[0]

    # Per-segment mean embedding, filled one embedding channel at a time.
    node_embeddings = np.zeros((rag.numberOfNodes, n_channels), dtype="float32")
    for channel, channel_data in enumerate(embeddings):
        region_features = vigra.analysis.extractRegionFeatures(channel_data, labels, features=["mean"])
        node_embeddings[:, channel] = region_features["mean"]

    # Look up the mean embeddings of both segments for every edge.
    edges = rag.uvIds()
    source_embeddings = node_embeddings[edges[:, 0]]
    target_embeddings = node_embeddings[edges[:, 1]]
    similarity = _embeddings_to_probabilities(source_embeddings, target_embeddings, delta, embedding_axis=1)
    return 1. - similarity
Derive edge probabilities from pixel embeddings.
Arguments:
- embeddings: The pixel embeddings.
- segmentation: The segmentation.
- rag: The region adjacency graph derived from the segmentation.
- delta: The delta factor used in the push force when training the embeddings.
Returns:
The edge probabilities.
def embeddings_to_affinities(
    embeddings: np.ndarray,
    offsets: List[List[int]],
    delta: float,
    invert: bool = False,
) -> np.ndarray:
    """Convert pixel embeddings to affinities.

    Computes the affinity according to the formula
    a_ij = max((2 * delta - ||x_i - x_j||) / (2 * delta), 0) ** 2,
    where delta is the push force used in training the embeddings.
    Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction":
    https://arxiv.org/pdf/1909.09872.pdf

    Args:
        embeddings: The pixel embeddings, with the embedding dimension as first axis.
        offsets: The offset vectors for which to compute affinities.
        delta: The delta factor used in the push force when training the embeddings.
        invert: Whether to invert the affinities.

    Returns:
        The affinity values, with one channel per offset.

    Raises:
        ValueError: If an offset does not have one entry per spatial dimension of the embeddings.
    """
    ndim = embeddings.ndim - 1
    if not all(len(off) == ndim for off in offsets):
        raise ValueError("Inconsistent dimension of offsets and embeddings")

    n_channels = len(offsets)
    shape = embeddings.shape[1:]
    affinities = np.zeros((n_channels,) + shape, dtype="float32")

    for cid, off in enumerate(offsets):
        # we need to shift in the other direction in order to
        # get the correct offset
        # also, we need to add a zero shift in the first axis
        shift_off = [0] + [-o for o in off]
        # we could also shift via np.pad and slicing
        shifted = shift(embeddings, shift_off, order=0, prefilter=False)
        affinities[cid] = _embeddings_to_probabilities(embeddings, shifted, delta, embedding_axis=0)

    if invert:
        affinities = 1. - affinities

    return affinities
Convert pixel embeddings to affinities.
Computes the affinity according to the formula a_ij = max((2 * delta - ||x_i - x_j||) / (2 * delta), 0) ** 2, where delta is the push force used in training the embeddings. Introduced in "Learning Dense Voxel Embeddings for 3D Neuron Reconstruction": https://arxiv.org/pdf/1909.09872.pdf
Arguments:
- embeddings: The pixel embeddings.
- offsets: The offset vectors for which to compute affinities.
- delta: The delta factor used in the push force when training the embeddings.
- invert: Whether to invert the affinities.
Returns:
The affinity values.
def segment_hdbscan(
    embeddings: np.ndarray,
    min_size: int, eps: float,
    remove_largest: bool,
    n_jobs: int = 1,
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with HDBSCAN.

    Args:
        embeddings: The pixel embeddings.
        min_size: The minimal segment size.
        eps: Epsilon factor for HDBSCAN.
        remove_largest: Whether to remove the largest (=background) object.
        n_jobs: The number of jobs for parallelizing HDBSCAN.

    Returns:
        The segmentation.
    """
    assert hdbscan is not None, "Needs hdbscan library"
    # The thread limit has to be active while the clustering is fitted.
    with threadpool_limits(limits=n_jobs):
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=min_size,
            cluster_selection_epsilon=eps,
            core_dist_n_jobs=n_jobs,
        )
        labels = _cluster(embeddings, clusterer, remove_largest=remove_largest)
        segmentation = labels.astype("uint64")
    return segmentation
Compute a segmentation by clustering pixel embeddings with HDBSCAN.
Arguments:
- embeddings: The pixel embeddings.
- min_size: The minimal segment size.
- eps: Epsilon factor for HDBSCAN.
- remove_largest: Whether to remove the largest (=background) object.
- n_jobs: The number of jobs for parallelizing HDBSCAN.
Returns:
The segmentation.
def segment_mean_shift(embeddings: np.ndarray, bandwidth: float, n_jobs: int = 1) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings with mean shift.

    Args:
        embeddings: The pixel embeddings.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    # The thread limit has to be active while the clustering is fitted.
    with threadpool_limits(limits=n_jobs):
        mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        labels = _cluster(embeddings, mean_shift)
        segmentation = labels.astype("uint64")
    return segmentation
Compute a segmentation by clustering pixel embeddings with mean shift.
Arguments:
- embeddings: The pixel embeddings.
- bandwidth: The bandwidth parameter for the mean shift algorithm.
- n_jobs: The number of jobs for parallelizing MeanShift.
Returns:
The segmentation.
def segment_consistency(
    embeddings1: np.ndarray,
    embeddings2: np.ndarray,
    bandwidth: float,
    iou_threshold: float,
    num_anchors: int,
    skip_zero: bool = True,
    n_jobs: int = 1
) -> np.ndarray:
    """Compute a segmentation by clustering pixel embeddings via mean shift and consistency.

    First, the segmentation is computed using mean shift. Then, for each instance in this
    segmentation the corresponding instance mask is derived from the second set of embeddings.
    Masks that have a low IOU with the corresponding instance mask are removed.

    Args:
        embeddings1: The first set of pixel embeddings, used for mean shift clustering.
        embeddings2: The second set of pixel embeddings, used for consistency.
        bandwidth: The bandwidth parameter for the mean shift algorithm.
        iou_threshold: The threshold for consistency filtering.
        num_anchors: The number of anchors for computing the instance masks for consistency.
        skip_zero: Whether to skip the background label.
        n_jobs: The number of jobs for parallelizing MeanShift.

    Returns:
        The segmentation.
    """
    def _iou(gt, seg):
        # The epsilon keeps the ratio defined for an empty union.
        epsilon = 1e-5
        inter = (gt & seg).sum()
        union = (gt | seg).sum()
        return (inter + epsilon) / (union + epsilon)

    with threadpool_limits(limits=n_jobs):
        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        clusters = _cluster(embeddings1, clustering)

    # Index that appends one singleton axis per spatial dimension to an anchor embedding.
    expand_spatial = (slice(None),) + (None,) * clusters.ndim

    for label_id in np.unique(clusters):
        if label_id == 0 and skip_zero:
            continue

        mask = clusters == label_id
        iou_table = []
        # One coordinate array per spatial dimension, so this works for 2d and 3d data alike.
        coords = np.nonzero(mask)
        n_pixels = len(coords[0])
        for _ in range(num_anchors):
            ind = np.random.randint(n_pixels)
            # get random embedding anchor from the second set of embeddings
            anchor_emb = embeddings2[(slice(None),) + tuple(coord[ind] for coord in coords)]
            # add necessary singleton dims
            anchor_emb = anchor_emb[expand_spatial]
            # compute the instance mask from emb2
            inst_mask = np.linalg.norm(embeddings2 - anchor_emb, axis=0) < bandwidth
            iou_table.append(_iou(mask, inst_mask))
        # choose final IoU as a median
        final_iou = np.median(iou_table)

        if final_iou < iou_threshold:
            clusters[mask] = 0

    return clusters.astype("uint64")
Compute a segmentation by clustering pixel embeddings via mean shift and consistency.
First, the segmentation is computed using mean shift. Then, for each instance in this segmentation the corresponding instance mask is derived from the second set of embeddings. Masks that have a low IOU with the corresponding instance mask are removed.
Arguments:
- embeddings1: The first set of pixel embeddings, used for mean shift clustering.
- embeddings2: The second set of pixel embeddings, used for consistency.
- bandwidth: The bandwidth parameter for the mean shift algorithm.
- iou_threshold: The threshold for consistency filtering.
- num_anchors: The number of anchors for computing the instance masks for consistency.
- skip_zero: Whether to skip the background label.
- n_jobs: The number of jobs for parallelizing MeanShift.
Returns:
The segmentation.
def segment_embeddings_mws(
    embeddings: np.ndarray,
    distance_type: str,
    offsets: List[List[int]],
    bias: float = 0.0,
    strides: Optional[List[int]] = None,
    weight_function: Optional[callable] = None,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Compute a segmentation by computing a mutex watershed based on pixel embeddings.

    Args:
        embeddings: The pixel embeddings.
        distance_type: The distance type for deriving affinities from embeddings.
        offsets: The affinity offsets.
        bias: Additional bias factor to apply to the affinities.
            This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
        strides: The strides for sub-sampling repulsive mutex edges.
        weight_function: Optional function for weighting the affinity values.
        mask: Mask to ignore in the segmentation.

    Returns:
        The segmentation.
    """
    g, costs, mutex_uvs, mutex_costs = _embeddings_to_problem(
        embeddings, distance_type, beta=None,
        offsets=offsets, strides=strides,
        weight_function=weight_function,
        mask=mask
    )
    # Apply the bias for both signs: negative values are documented to reduce over-segmentation.
    if bias != 0:
        mutex_costs += bias
    uvs = g.uvIds()
    seg = mutex_watershed_clustering(uvs, mutex_uvs, costs, mutex_costs).reshape(embeddings.shape[1:])
    if mask is not None:
        seg = _ensure_mask_is_zero(seg, mask)
    return seg
Compute a segmentation by computing a mutex watershed based on pixel embeddings.
Arguments:
- embeddings: The pixel embeddings.
- distance_type: The distance type for deriving affinities from embeddings.
- offsets: The affinity offsets.
- bias: Additional bias factor to apply to the affinities. This can be used to reduce under-segmentation (positive value) or over-segmentation (negative value).
- strides: The strides for sub-sampling repulsive mutex edges.
- weight_function: Optional function for weighting the affinity values.
- mask: Mask to ignore in the segmentation.
Returns:
The segmentation.