
  1import math
  3import numpy as np
  4import torch
  5import torch.nn as nn
  7    from torch_scatter import scatter_mean
  8except ImportError:
  9    scatter_mean = None
 11from . import contrastive_impl as cimpl
 12from .affinity_side_loss import AffinitySideLoss
 13from .dice import DiceLoss
 16def compute_cluster_means(embeddings: torch.Tensor, target: torch.Tensor, n_instances: int) -> torch.Tensor:
 17    """Compute mean embeddings per instance.
 19    Args:
 20        embeddings: The tensor of pixel embeddings with shape: ExSPATIAL. E is the embedding dimension.
 21        target: One-hot encoded target instances with shape: SPATIAL.
 22        n_instances: The number of instances.
 24    Returns:
 25        The cluster means.
 26    """
 27    assert scatter_mean is not None, "torch_scatter is required"
 28    embeddings = embeddings.flatten(1)
 29    target = target.flatten()
 30    assert target.min() == 0, \
 31        "The target min value has to be zero, otherwise this will lead to errors in scatter."
 32    mean_embeddings = scatter_mean(embeddings, target, dim_size=n_instances)
 33    return mean_embeddings.transpose(1, 0)
 36def select_stable_anchor(
 37    embeddings: torch.Tensor,
 38    mean_embedding: torch.Tensor,
 39    object_mask: torch.Tensor,
 40    delta_var: float,
 41    norm: str = "fro"
 42) -> torch.Tensor:
 43    """Sample anchor embeddings from the object mask.
 45    Given a binary mask of an object (`object_mask`) and a `mean_embedding` vector within the mask,
 46    the function selects a pixel from the mask at random and returns its embedding only if it's closer than
 47    `delta_var` from the `mean_embedding`.
 49    Args:
 50        embeddings: The embeddings, a ExSPATIAL vector field of an image.
 51        mean_embedding: The E-dimensional mean of embeddings lying within the `object_mask`.
 52        object_mask: Binary image of a selected object.
 53        delta_var: The pull force margin of the contrastive loss.
 54        norm: The vector norm used. By default the frobenius norm is used.
 56    Returns:
 57        Embedding of a selected pixel within the mask or the mean embedding if stable anchor could be found.
 58    """
 59    indices = torch.nonzero(object_mask, as_tuple=True)
 60    # convert to numpy
 61    indices = [t.cpu().numpy() for t in indices]
 63    # randomize coordinates
 64    seed = np.random.randint(np.iinfo("int32").max)
 65    for t in indices:
 66        rs = np.random.RandomState(seed)
 67        rs.shuffle(t)
 69    for ind in range(len(indices[0])):
 70        if object_mask.dim() == 2:
 71            y, x = indices
 72            anchor_emb = embeddings[:, y[ind], x[ind]]
 73            anchor_emb = anchor_emb[..., None, None]
 74        else:
 75            z, y, x = indices
 76            anchor_emb = embeddings[:, z[ind], y[ind], x[ind]]
 77            anchor_emb = anchor_emb[..., None, None, None]
 78        dist_to_mean = torch.norm(mean_embedding - anchor_emb, norm)
 79        if dist_to_mean < delta_var:
 80            return anchor_emb
 81    # if stable anchor has not been found, return mean_embedding
 82    return mean_embedding
 85class GaussianKernel(nn.Module):
 """
 87    """
 88    def __init__(self, delta_var, pmaps_threshold):
 89        super().__init__()
 90        self.delta_var = delta_var
 91        # dist_var^2 = -2*sigma*ln(pmaps_threshold)
 92        self.two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))
 94    def forward(self, dist_map):
 95        return torch.exp(- dist_map * dist_map / self.two_sigma)
 98class CombinedAuxLoss(nn.Module):
 """
100    """
101    def __init__(self, losses, weights):
102        super().__init__()
103        self.losses = losses
104        self.weights = weights
106    def forward(self, embeddings, target, instance_pmaps, instance_masks):
107        result = 0.
108        for loss, weight in zip(self.losses, self.weights):
109            if isinstance(loss, AffinitySideLoss):
110                # add batch axis / batch and channel axis for embeddings, target
111                result += weight * loss(embeddings[None], target[None, None])
112            elif instance_masks is not None:
113                result += weight * loss(instance_pmaps, instance_masks).mean()
114        return result
class ContrastiveLossBase(nn.Module):
    """Base class of the contrastive (discriminative) embedding losses.

    The loss combines a variance term (pull of pixel embeddings towards their cluster mean), a distance
    term (push between cluster means), a regularization term (norm of the cluster means), an optional
    unlabeled push term and an instance term that subclasses implement in `compute_instance_term`.

    Args:
        delta_var: The hinge distance for the variance term.
        delta_dist: The hinge distance for the distance term and the unlabeled push term.
        norm: The vector norm used. By default the frobenius norm is used.
        alpha: Weight of the variance term.
        beta: Weight of the distance term.
        gamma: Weight of the regularization term.
        unlabeled_push_weight: Weight of the unlabeled push term; the term is only computed if positive.
        instance_term_weight: Weight of the instance term.
        impl: Unused; kept for backward compatibility.
    """
    def __init__(self, delta_var, delta_dist,
                 norm="fro", alpha=1., beta=1., gamma=0.001, unlabeled_push_weight=0.0,
                 instance_term_weight=1.0, impl=None):
        assert scatter_mean is not None, "Spoco loss requires pytorch_scatter"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.unlabeled_push_weight = unlabeled_push_weight
        self.unlabeled_push = unlabeled_push_weight > 0
        self.instance_term_weight = instance_term_weight

    def __str__(self):
        return super().__str__() + f"\ndelta_var: {self.delta_var}\ndelta_dist: {self.delta_dist}" \
                                   f"\nalpha: {self.alpha}\nbeta: {self.beta}\ngamma: {self.gamma}" \
                                   f"\nunlabeled_push_weight: {self.unlabeled_push_weight}" \
                                   f"\ninstance_term_weight: {self.instance_term_weight}"

    def _compute_variance_term(self, cluster_means, embeddings, target, instance_counts, ignore_zero_label):
        """Computes the variance term, i.e. intra-cluster pull force that draws embeddings towards the mean embedding

        C - number of clusters (instances)
        E - embedding dimension
        SPATIAL - volume shape, i.e. DxHxW for 3D/ HxW for 2D

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE)
            embeddings: embeddings vectors per instance, tensor (ExSPATIAL)
            target: instance label tensor (SPATIAL)
            instance_counts: number of voxels per instance
            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
        """
        assert target.dim() in (2, 3)
        ignore_labels = [0] if ignore_zero_label else None
        # the scatter implementation expects a leading batch axis
        return cimpl._compute_variance_term_scatter(
            cluster_means, embeddings.unsqueeze(0), target.unsqueeze(0),
            self.norm, self.delta_var, instance_counts, ignore_labels
        )

    def _compute_unlabeled_push(self, cluster_means, embeddings, target):
        """Computes the push force between the mean embeddings of the labeled clusters and the
        embeddings of the pixels with label 0.

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE); row 0 belongs to the 0-label
            embeddings: pixel embeddings, tensor (ExSPATIAL)
            target: instance label tensor (SPATIAL)
        """
        assert target.dim() in (2, 3)
        n_instances = cluster_means.shape[0]

        # permute embedding dimension at the end
        if target.dim() == 2:
            embeddings = embeddings.permute(1, 2, 0)
        else:
            embeddings = embeddings.permute(1, 2, 3, 0)

        # decrease number of instances `C` since we're ignoring 0-label
        n_instances -= 1
        # if there is only 0-label in the target return 0
        if n_instances == 0:
            return 0.0

        background_mask = target == 0
        n_background = background_mask.sum()
        background_push = 0.0
        # skip embedding corresponding to the background pixels
        for cluster_mean in cluster_means[1:]:
            # compute distances between embeddings and a given cluster_mean
            dist_to_mean = torch.norm(embeddings - cluster_mean, self.norm, dim=-1)
            # apply background mask and compute hinge
            dist_hinged = torch.clamp((self.delta_dist - dist_to_mean) * background_mask, min=0) ** 2
            background_push += torch.sum(dist_hinged) / n_background

        # normalize by the number of instances
        return background_push / n_instances

    def _compute_distance_term(self, cluster_means, ignore_zero_label):
        """
        Compute the distance term, i.e an inter-cluster push-force that pushes clusters away from each other, increasing
        the distance between cluster centers

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE)
            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
        """
        ignore_labels = [0] if ignore_zero_label else None
        return cimpl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist, ignore_labels)

    def _compute_regularizer_term(self, cluster_means):
        """
        Computes the regularizer term, i.e. a small pull-force that draws all clusters towards origin to keep
        the network activations bounded
        """
        # compute the norm of the mean embeddings
        norms = torch.norm(cluster_means, p=self.norm, dim=1)
        # return the average norm over the clusters
        return torch.sum(norms) / cluster_means.size(0)

    def compute_instance_term(self, embeddings, cluster_means, target):
        """Computes auxiliary loss based on embeddings and a given list of target
        instances together with their mean embeddings.

        Args:
            embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
            cluster_means (torch.tensor): mean embeddings per instance (CxE)
            target (torch.tensor): ground truth instance segmentation (SPATIAL)

        Returns:
            float: value of the instance-based term
        """
        raise NotImplementedError

    def forward(self, input_, target):
        """
        Args:
             input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
                expects float32 tensor
             target (torch.tensor): ground truth instance segmentation (Nx1xDxHxW)
                expects int64 tensor
        Returns:
            Combined loss averaged over the batch, defined per sample as:
                alpha * variance_term + beta * distance_term + gamma * regularization_term
                + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
        """
        # enable calling this loss from the spoco trainer, which passes a tuple
        if isinstance(input_, tuple):
            assert len(input_) == 2
            input_ = input_[0]

        n_batches = input_.shape[0]
        # compute the loss for each sample in the batch separately and sum it up
        loss = 0.0
        for single_input, single_target in zip(input_, target):
            # compare spatial dimensions
            assert single_input.shape[1:] == single_target.shape[1:], f"{single_input.shape}, {single_target.shape}"
            assert single_target.shape[0] == 1
            single_target = single_target[0]

            contains_bg = 0 in single_target
            ignore_zero_label = self.unlabeled_push and contains_bg

            # get the instance ids and their voxel counts
            instance_ids, instance_counts = torch.unique(single_target, return_counts=True)

            # get the number of instances
            C = instance_ids.size(0)

            # compute mean embeddings (output is of shape CxE)
            cluster_means = compute_cluster_means(single_input, single_target, C)

            # compute variance term, i.e. pull force
            variance_term = self._compute_variance_term(
                cluster_means, single_input, single_target, instance_counts, ignore_zero_label
            )

            # compute unlabeled push force, i.e. push force between
            # the mean cluster embeddings and embeddings of background pixels;
            # only computed if the push is enabled and the patch contains the 0-label
            unlabeled_push_term = 0.0
            if self.unlabeled_push and contains_bg:
                unlabeled_push_term = self._compute_unlabeled_push(cluster_means, single_input, single_target)

            # compute the instance-based auxiliary loss
            instance_term = self.compute_instance_term(single_input, cluster_means, single_target)

            # compute distance term, i.e. push force
            distance_term = self._compute_distance_term(cluster_means, ignore_zero_label)

            # compute regularization term
            regularization_term = self._compute_regularizer_term(cluster_means)

            # compute the loss of this sample and add it to the batch total
            # (a separate name is required: reusing `loss` would discard the previous samples)
            sample_loss = self.alpha * variance_term + \
                self.beta * distance_term + \
                self.gamma * regularization_term + \
                self.instance_term_weight * instance_term + \
                self.unlabeled_push_weight * unlabeled_push_term

            loss = loss + sample_loss

        # reduce across the batch dimension; `/` also works if `loss` is still a float
        return loss / n_batches
class ExtendedContrastiveLoss(ContrastiveLossBase):
    """Contrastive loss extended with instance-based loss term and background push term.

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings".

    Args:
        delta_var: The hinge distance for the variance term in the discriminative loss.
        delta_dist: The hinge distance for the distance term in the discriminative loss.
        norm: The norm to use.
        alpha: Weight for the variance term of the discriminative loss.
        beta: Weight for the distance term of the discriminative loss.
        gamma: Weight for the regularization term of the discriminative loss.
        unlabeled_push_weight: The weight term for the unlabeled loss term.
        instance_term_weight: The weight term for the instance loss term.
        aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
        pmaps_threshold: The instance probability that the Gaussian kernel assigns to pixels at distance
            `delta_var` from the anchor embedding.
        kwargs: Additional keyword arguments for other loss terms: 'offset_ranges' and 'n_samples' for the
            affinity loss, 'dice_weight' and 'aff_weight' for combining both losses with 'dice_aff'.

    Raises:
        ValueError: If `aux_loss` is not one of the supported values.
    """
    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        unlabeled_push_weight: float = 1.0,
        instance_term_weight: float = 1.0,
        aux_loss: str = "dice",
        pmaps_threshold: float = 0.9,
        **kwargs,
    ):
        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight)
        # Init auxiliary loss.
        if aux_loss not in ("dice", "affinity", "dice_aff"):
            raise ValueError(f"Invalid aux_loss {aux_loss}, expected one of 'dice', 'affinity', 'dice_aff'.")
        # The weights are only used when dice and affinity losses are combined ('dice_aff').
        self.dice_weight = kwargs.get("dice_weight", 1.0)
        self.aff_weight = kwargs.get("aff_weight", 1.0)
        if aux_loss in ("affinity", "dice_aff"):
            self.aff_loss = self._make_affinity_loss(delta_dist, kwargs)
        else:
            self.aff_loss = None
        self.dice_loss = DiceLoss() if aux_loss in ("dice", "dice_aff") else None

        # Init dist_to_mask kernel which maps distance to the cluster center to instance probability map.
        self.dist_to_mask = GaussianKernel(delta_var=self.delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold
        }
        self.init_kwargs.update(kwargs)

    @staticmethod
    def _make_affinity_loss(delta_dist, kwargs):
        """Build the affinity side loss from the optional 'offset_ranges' / 'n_samples' kwargs."""
        return AffinitySideLoss(
            delta=delta_dist,
            offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
            n_samples=kwargs.get("n_samples", 9)
        )

    # FIXME stacking per instance here makes this very memory hungry,
    # NOTE: nothing is ever appended to the lists, so this currently always returns (None, None).
    def _create_instance_pmaps_and_masks(self, embeddings, anchors, target):
        inst_pmaps = []
        inst_masks = []

        if not inst_masks:
            return None, None

        # stack along batch dimension
        inst_pmaps = torch.stack(inst_pmaps)
        inst_masks = torch.stack(inst_masks)

        return inst_pmaps, inst_masks

    def compute_instance_term(self, embeddings, cluster_means, target):
        """Compute the instance term from the affinity and/or Dice auxiliary loss.

        Args:
            embeddings: pixel embeddings (ExSPATIAL)
            cluster_means: mean embeddings per instance (CxE), indexed by label value
            target: ground truth instance segmentation (SPATIAL)

        Returns:
            The affinity loss, the Dice term, or their weighted sum when both are enabled.
        """
        assert embeddings.size()[1:] == target.size()

        aff_loss = None if self.aff_loss is None else self.aff_loss(embeddings[None], target[None, None])
        dice_loss = None if self.dice_loss is None else self._compute_dice_term(embeddings, cluster_means, target)

        assert not (dice_loss is None and aff_loss is None)
        if dice_loss is None:
            return aff_loss
        if aff_loss is None:
            return dice_loss
        return self.dice_weight * dice_loss + self.aff_weight * aff_loss

    def _compute_dice_term(self, embeddings, cluster_means, target):
        """Mean Dice loss between each instance's kernel pmap (around its mean embedding) and its mask."""
        # permute embedding dimension at the end: SPATIALxE
        embeddings = embeddings.permute(*range(1, embeddings.dim()), 0)

        per_instance = []
        for i in torch.unique(target):
            # label 0 (background / unlabeled) is not treated as an instance
            if i == 0:
                continue
            anchor_emb = cluster_means[i]
            # NOTE: select_stable_anchor could provide the anchor instead of the cluster mean,
            # but it made training extremely slow.
            distance_map = torch.norm(embeddings - anchor_emb, self.norm, dim=-1)
            instance_pmap = self.dist_to_mask(distance_map).unsqueeze(0)
            instance_mask = (target == i).float().unsqueeze(0)
            per_instance.append(self.dice_loss(instance_pmap, instance_mask))

        # torch.stack keeps the autograd graph; torch.tensor(list) would detach every per-instance loss
        return torch.stack(per_instance).mean() if per_instance else 0.0
class SPOCOLoss(ExtendedContrastiveLoss):
    """The full SPOCO Loss for instance segmentation training with sparse instance labels.

    Extends the "classic" contrastive loss with an instance-based term and an unsupervised embedding consistency term.
    An additional background push term can be added. It is disabled by default because we assume sparse instance labels.

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings".

    Args:
        delta_var: The hinge distance for the variance term in the discriminative loss.
        delta_dist: The hinge distance for the distance term in the discriminative loss.
        norm: The norm to use.
        alpha: Weight for the variance term of the discriminative loss.
        beta: Weight for the distance term of the discriminative loss.
        gamma: Weight for the regularization term of the discriminative loss.
        unlabeled_push_weight: The weight term for the unlabeled loss term.
        instance_term_weight: The weight term for the instance loss term.
        consistency_term_weight: The weight term for the embedding consistency term.
        aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
        pmaps_threshold: The instance probability that the Gaussian kernel assigns to pixels at distance
            `delta_var` from the anchor embedding.
        max_anchors: The number of anchors to sample for the consistency term.
        volume_threshold: The minimal fraction of unlabeled (label 0) pixels a sample needs for the
            consistency term to be computed on it.
        kwargs: Additional keyword arguments for other loss terms.
    """
    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        unlabeled_push_weight: float = 0.0,
        instance_term_weight: float = 1.0,
        consistency_term_weight: float = 1.0,
        aux_loss: str = "dice",
        pmaps_threshold: float = 0.9,
        max_anchors: int = 20,
        volume_threshold: float = 0.05,
        **kwargs,
    ):
        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight,
                         aux_loss=aux_loss,
                         pmaps_threshold=pmaps_threshold,
                         **kwargs)

        self.consistency_term_weight = consistency_term_weight
        self.max_anchors = max_anchors
        self.volume_threshold = volume_threshold
        self.consistency_loss = DiceLoss()
        # all constructor arguments, so that the loss can be re-created from init_kwargs
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "consistency_term_weight": consistency_term_weight,
            "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold,
            "max_anchors": max_anchors, "volume_threshold": volume_threshold
        }
        self.init_kwargs.update(kwargs)

    def __str__(self):
        return super().__str__() + f"\nconsistency_term_weight: {self.consistency_term_weight}"

    def _inst_pmap(self, emb, anchor):
        # compute distance map
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert distance map to instance pmaps and return
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k, mask):
        """Consistency loss between the instance pmaps of two embeddings, anchored at random mask pixels.

        Args:
            emb_q: The first embedding (ExSPATIAL).
            emb_k: The second embedding (ExSPATIAL).
            mask: The pixels that may be used as anchors (SPATIAL).

        Returns:
            The consistency loss between the stacked pmaps, or 0.0 if no anchor can be sampled.
        """
        # The mask does not change within this call, so the volume check and the candidate
        # coordinates are computed once instead of once per anchor.
        if self.max_anchors <= 0 or mask.sum() < self.volume_threshold * mask.numel():
            return 0.0
        indices = torch.nonzero(mask, as_tuple=True)
        n_candidates = len(indices[0])
        if n_candidates == 0:
            return 0.0

        inst_q = []
        inst_k = []
        for _ in range(self.max_anchors):
            # get random anchor; the same pixel is used for both embeddings
            ind = np.random.randint(n_candidates)
            inst_q.append(self._extract_pmap(emb_q, mask, indices, ind))
            inst_k.append(self._extract_pmap(emb_k, mask, indices, ind))

        # stack along channel dim
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        return self.consistency_loss(inst_q, inst_k)

    def _extract_pmap(self, emb, mask, indices, ind):
        if mask.dim() == 2:
            y, x = indices
            anchor = emb[:, y[ind], x[ind]]
            emb = emb.permute(1, 2, 0)
        else:
            z, y, x = indices
            anchor = emb[:, z[ind], y[ind], x[ind]]
            emb = emb.permute(1, 2, 3, 0)

        return self._inst_pmap(emb, anchor)

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the SPOCO loss.

        Args:
            input_: A pair (emb_q, emb_k) of predicted embeddings, each of shape NxExSPATIAL.
            target: The segmentation target (Nx1xSPATIAL).

        Returns:
            The SPOCO loss.
        """
        assert len(input_) == 2
        emb_q, emb_k = input_

        # Compute extended contrastive loss only on the embeddings coming from q.
        contrastive_loss = super().forward(emb_q, target)

        # TODO enable computing the consistency on all pixels!
        # Compute consistency term on the unlabeled (label 0) pixels of each sample.
        for e_q, e_k, t in zip(emb_q, emb_k, target):
            unlabeled_mask = (t[0] == 0).int()
            if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel():
                continue
            emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask)
            contrastive_loss = contrastive_loss + self.consistency_term_weight * emb_consistency_loss

        return contrastive_loss
class SPOCOConsistencyLoss(nn.Module):
    """Unsupervised consistency term computed between embeddings.

    For `max_anchors` random pixels, an instance probability map is computed in both embeddings from
    the embedding at that pixel. The `DiceLoss` between the two stacks of maps is the consistency term.

    Args:
        delta_var: The distance from an anchor embedding that the Gaussian kernel maps to `pmaps_threshold`.
        pmaps_threshold: The instance probability at distance `delta_var` from the anchor embedding.
        max_anchors: The maximum number of anchors to compute for the consistency loss.
        norm: The vector norm used. By default the frobenius norm is used.
    """
    def __init__(self, delta_var: float, pmaps_threshold: float, max_anchors: int = 30, norm: str = "fro"):
        super().__init__()
        self.max_anchors = max_anchors
        self.consistency_loss = DiceLoss()
        self.norm = norm
        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
                            "max_anchors": max_anchors, "norm": norm}

    def _inst_pmap(self, emb, anchor):
        # compute distance map
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert distance map to instance pmaps and return
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k):
        """Consistency loss between the instance pmaps of two embeddings (ExSPATIAL), anchored at random pixels."""
        spatial_shape = tuple(emb_q.shape[1:])
        n_pixels = emb_q.shape[1:].numel()
        inst_q = []
        inst_k = []
        for _ in range(self.max_anchors):
            # Every pixel is a valid anchor: draw a random flat index and unravel it (row-major),
            # which yields the same pixel as indexing torch.nonzero of an all-ones mask, without
            # materializing the coordinates of every pixel.
            flat_index = np.random.randint(n_pixels)
            coords = tuple(int(c) for c in np.unravel_index(flat_index, spatial_shape))
            inst_q.append(self._pmap_at(emb_q, coords))
            inst_k.append(self._pmap_at(emb_k, coords))

        # stack along channel dim
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        return self.consistency_loss(inst_q, inst_k)

    def _pmap_at(self, emb, coords):
        """Instance pmap of `emb` (ExSPATIAL) anchored at the embedding of the pixel at `coords`."""
        anchor = emb[(slice(None),) + coords]
        # move the embedding dimension to the end: SPATIALxE
        return self._inst_pmap(emb.permute(*range(1, emb.dim()), 0), anchor)

    def _extract_pmap(self, emb, mask, indices, ind):
        # kept for backward compatibility; `mask` is not needed to locate the anchor
        coords = tuple(int(axis_indices[ind]) for axis_indices in indices)
        return self._pmap_at(emb, coords)

    def forward(self, emb_q: torch.Tensor, emb_k: torch.Tensor) -> torch.Tensor:
        """Compute the consistency loss term between embeddings.

        Args:
            emb_q: The first embedding predictions.
            emb_k: The second embedding predictions.

        Returns:
            The consistency loss, summed over the batch.
        """
        contrastive_loss = 0.0
        # compute consistency term
        for e_q, e_k in zip(emb_q, emb_k):
            contrastive_loss = contrastive_loss + self.emb_consistency(e_q, e_k)
        return contrastive_loss
Compute mean embeddings per instance.

  • embeddings: The tensor of pixel embeddings with shape: ExSPATIAL. E is the embedding dimension.
  • target: Instance label image with shape: SPATIAL; each value is the instance id of the pixel.
  • n_instances: The number of instances.

The cluster means.

Sample anchor embeddings from the object mask.

Given a binary mask of an object (object_mask) and the mean_embedding of the embeddings within the mask, the function visits the pixels of the mask in random order and returns the embedding of the first pixel that is closer than delta_var to the mean_embedding.

  • embeddings: The embeddings, a ExSPATIAL vector field of an image.
  • mean_embedding: The E-dimensional mean of embeddings lying within the object_mask.
  • object_mask: Binary image of a selected object.
  • delta_var: The pull force margin of the contrastive loss.
  • norm: The vector norm used. By default the frobenius norm is used.

Embedding of a selected pixel within the mask, or the mean embedding if no stable anchor could be found.

"""
"""
