torch_em.loss.spoco_loss

  1import math
  2
  3import numpy as np
  4import torch
  5import torch.nn as nn
  6try:
  7    from torch_scatter import scatter_mean
  8except ImportError:
  9    scatter_mean = None
 10
 11from . import contrastive_impl as cimpl
 12from .affinity_side_loss import AffinitySideLoss
 13from .dice import DiceLoss
 14
 15
 16def compute_cluster_means(embeddings: torch.Tensor, target: torch.Tensor, n_instances: int) -> torch.Tensor:
 17    """Compute mean embeddings per instance.
 18
 19    Args:
 20        embeddings: The tensor of pixel embeddings with shape: ExSPATIAL. E is the embedding dimension.
 21        target: The target instance labels with shape: SPATIAL, given as integer instance ids.
 22        n_instances: The number of instances.
 23
 24    Returns:
 25        The cluster means.
 26    """
 27    assert scatter_mean is not None, "torch_scatter is required"
 28    embeddings = embeddings.flatten(1)
 29    target = target.flatten()
 30    assert target.min() == 0, \
 31        "The target min value has to be zero, otherwise this will lead to errors in scatter."
 32    mean_embeddings = scatter_mean(embeddings, target, dim_size=n_instances)
 33    return mean_embeddings.transpose(1, 0)
 34
 35
 36def select_stable_anchor(
 37    embeddings: torch.Tensor,
 38    mean_embedding: torch.Tensor,
 39    object_mask: torch.Tensor,
 40    delta_var: float,
 41    norm: str = "fro"
 42) -> torch.Tensor:
 43    """Sample anchor embeddings from the object mask.
 44
 45    Given a binary mask of an object (`object_mask`) and a `mean_embedding` vector within the mask,
 46    the function selects a pixel from the mask at random and returns its embedding only if it's closer than
 47    `delta_var` from the `mean_embedding`.
 48
 49    Args:
 50        embeddings: The embeddings, a ExSPATIAL vector field of an image.
 51        mean_embedding: The E-dimensional mean of embeddings lying within the `object_mask`.
 52        object_mask: Binary image of a selected object.
 53        delta_var: The pull force margin of the contrastive loss.
 54        norm: The vector norm used. By default the frobenius norm is used.
 55
 56    Returns:
 57        Embedding of a selected pixel within the mask or the mean embedding if no stable anchor could be found.
 58    """
 59    indices = torch.nonzero(object_mask, as_tuple=True)
 60    # convert to numpy
 61    indices = [t.cpu().numpy() for t in indices]
 62
 63    # randomize coordinates
 64    seed = np.random.randint(np.iinfo("int32").max)
 65    for t in indices:
 66        rs = np.random.RandomState(seed)
 67        rs.shuffle(t)
 68
 69    for ind in range(len(indices[0])):
 70        if object_mask.dim() == 2:
 71            y, x = indices
 72            anchor_emb = embeddings[:, y[ind], x[ind]]
 73            anchor_emb = anchor_emb[..., None, None]
 74        else:
 75            z, y, x = indices
 76            anchor_emb = embeddings[:, z[ind], y[ind], x[ind]]
 77            anchor_emb = anchor_emb[..., None, None, None]
 78        dist_to_mean = torch.norm(mean_embedding - anchor_emb, norm)
 79        if dist_to_mean < delta_var:
 80            return anchor_emb
 81    # if stable anchor has not been found, return mean_embedding
 82    return mean_embedding
 83
 84
 85class GaussianKernel(nn.Module):
 86    """@private
 87    """
 88    def __init__(self, delta_var, pmaps_threshold):
 89        super().__init__()
 90        self.delta_var = delta_var
 91        # dist_var^2 = -2*sigma*ln(pmaps_threshold)
 92        self.two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))
 93
 94    def forward(self, dist_map):
 95        return torch.exp(- dist_map * dist_map / self.two_sigma)
 96
 97
 98class CombinedAuxLoss(nn.Module):
 99    """@private
100    """
101    def __init__(self, losses, weights):
102        super().__init__()
103        self.losses = losses
104        self.weights = weights
105
106    def forward(self, embeddings, target, instance_pmaps, instance_masks):
107        result = 0.
108        for loss, weight in zip(self.losses, self.weights):
109            if isinstance(loss, AffinitySideLoss):
110                # add batch axis / batch and channel axis for embeddings, target
111                result += weight * loss(embeddings[None], target[None, None])
112            elif instance_masks is not None:
113                result += weight * loss(instance_pmaps, instance_masks).mean()
114        return result
115
116
117class ContrastiveLossBase(nn.Module):
118    """@private
119    """
120    def __init__(self, delta_var, delta_dist,
121                 norm="fro", alpha=1., beta=1., gamma=0.001, unlabeled_push_weight=0.0,
122                 instance_term_weight=1.0, impl=None):
123        assert scatter_mean is not None, "Spoco loss requires pytorch_scatter"
124        super().__init__()
125        self.delta_var = delta_var
126        self.delta_dist = delta_dist
127        self.norm = norm
128        self.alpha = alpha
129        self.beta = beta
130        self.gamma = gamma
131        self.unlabeled_push_weight = unlabeled_push_weight
132        self.unlabeled_push = unlabeled_push_weight > 0
133        self.instance_term_weight = instance_term_weight
134
135    def __str__(self):
136        return super().__str__() + f"\ndelta_var: {self.delta_var}\ndelta_dist: {self.delta_dist}" \
137                                   f"\nalpha: {self.alpha}\nbeta: {self.beta}\ngamma: {self.gamma}" \
138                                   f"\nunlabeled_push_weight: {self.unlabeled_push_weight}" \
139                                   f"\ninstance_term_weight: {self.instance_term_weight}"
140
141    def _compute_variance_term(self, cluster_means, embeddings, target, instance_counts, ignore_zero_label):
142        """Computes the variance term, i.e. intra-cluster pull force that draws embeddings towards the mean embedding
143
144        C - number of clusters (instances)
145        E - embedding dimension
146        SPATIAL - volume shape, i.e. DxHxW for 3D/ HxW for 2D
147
148        Args:
149            cluster_means: mean embedding of each instance, tensor (CxE)
150            embeddings: embeddings vectors per instance, tensor (ExSPATIAL)
151            target: ground truth instance labels, tensor (SPATIAL)
152            instance_counts: number of voxels per instance
153            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
154        """
155        assert target.dim() in (2, 3)
156        ignore_labels = [0] if ignore_zero_label else None
157        return cimpl._compute_variance_term_scatter(
158            cluster_means, embeddings.unsqueeze(0), target.unsqueeze(0),
159            self.norm, self.delta_var, instance_counts, ignore_labels
160        )
161
162    def _compute_unlabeled_push(self, cluster_means, embeddings, target):
163        assert target.dim() in (2, 3)
164        n_instances = cluster_means.shape[0]
165
166        # permute embedding dimension at the end
167        if target.dim() == 2:
168            embeddings = embeddings.permute(1, 2, 0)
169        else:
170            embeddings = embeddings.permute(1, 2, 3, 0)
171
172        # decrease number of instances `C` since we're ignoring 0-label
173        n_instances -= 1
174        # if there is only 0-label in the target return 0
175        if n_instances == 0:
176            return 0.0
177
178        background_mask = target == 0
179        n_background = background_mask.sum()
180        background_push = 0.0
181        # skip embedding corresponding to the background pixels
182        for cluster_mean in cluster_means[1:]:
183            # compute distances between embeddings and a given cluster_mean
184            dist_to_mean = torch.norm(embeddings - cluster_mean, self.norm, dim=-1)
185            # apply background mask and compute hinge
186            dist_hinged = torch.clamp((self.delta_dist - dist_to_mean) * background_mask, min=0) ** 2
187            background_push += torch.sum(dist_hinged) / n_background
188
189        # normalize by the number of instances
190        return background_push / n_instances
191
192    # def _compute_distance_term_scatter(cluster_means, norm, delta_dist):
193    def _compute_distance_term(self, cluster_means, ignore_zero_label):
194        """
195        Compute the distance term, i.e. an inter-cluster push force that pushes clusters away from each other,
196        increasing the distance between cluster centers.
197
198        Args:
199            cluster_means: mean embedding of each instance, tensor (CxE)
200            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
201        """
202        ignore_labels = [0] if ignore_zero_label else None
203        return cimpl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist, ignore_labels)
204
205    def _compute_regularizer_term(self, cluster_means):
206        """
207        Computes the regularizer term, i.e. a small pull-force that draws all clusters towards origin to keep
208        the network activations bounded
209        """
210        # compute the norm of the mean embeddings
211        norms = torch.norm(cluster_means, p=self.norm, dim=1)
212        # return the average norm per batch
213        return torch.sum(norms) / cluster_means.size(0)
214
215    def compute_instance_term(self, embeddings, cluster_means, target):
216        """Computes auxiliary loss based on embeddings and a given list of target
217        instances together with their mean embeddings.
218
219        Args:
220            embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
221            cluster_means (torch.tensor): mean embeddings per instance (CxExSINGLETON_SPATIAL)
222            target (torch.tensor): ground truth instance segmentation (SPATIAL)
223
224        Returns:
225            float: value of the instance-based term
226        """
227        raise NotImplementedError
228
229    def forward(self, input_, target):
230        """
231        Args:
232             input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
233                expects float32 tensor
234             target (torch.tensor): ground truth instance segmentation (Nx1xDxHxW)
235                expects int64 tensor
236        Returns:
237            Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
238                + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
239        """
240        # enable calling this loss from the spoco trainer, which passes a tuple
241        if isinstance(input_, tuple):
242            assert len(input_) == 2
243            input_ = input_[0]
244
245        n_batches = input_.shape[0]
246        # compute the loss for each sample in the batch separately
247        # and accumulate it in the loss variable
248        loss = 0.0
249        for single_input, single_target in zip(input_, target):
250            # compare spatial dimensions
251            assert single_input.shape[1:] == single_target.shape[1:], f"{single_input.shape}, {single_target.shape}"
252            assert single_target.shape[0] == 1
253            single_target = single_target[0]
254
255            contains_bg = 0 in single_target
256            ignore_zero_label = self.unlabeled_push and contains_bg
257
258            # get the ids and counts of the instances in this sample
259            instance_ids, instance_counts = torch.unique(single_target, return_counts=True)
260
261            # get the number of instances
262            C = instance_ids.size(0)
263
264            # compute mean embeddings (output is of shape CxE)
265            cluster_means = compute_cluster_means(single_input, single_target, C)
266
267            # compute variance term, i.e. pull force
268            variance_term = self._compute_variance_term(
269                cluster_means, single_input, single_target, instance_counts, ignore_zero_label
270            )
271
272            # compute unlabeled push force, i.e. push force between
273            # the mean cluster embeddings and embeddings of background pixels
274            # computed only if ignore_zero_label is True, i.e. the given patch contains the background label
275            unlabeled_push_term = 0.0
276            if self.unlabeled_push and contains_bg:
277                unlabeled_push_term = self._compute_unlabeled_push(cluster_means, single_input, single_target)
278
279            # compute the instance-based auxiliary loss
280            instance_term = self.compute_instance_term(single_input, cluster_means, single_target)
281
282            # compute distance term, i.e. push force
283            distance_term = self._compute_distance_term(cluster_means, ignore_zero_label)
284
285            # compute regularization term
286            regularization_term = self._compute_regularizer_term(cluster_means)
287
288            # compute the loss for this sample and add it to the total loss
289            sample_loss = self.alpha * variance_term + \
290                self.beta * distance_term + \
291                self.gamma * regularization_term + \
292                self.instance_term_weight * instance_term + \
293                self.unlabeled_push_weight * unlabeled_push_term
294
295            loss += sample_loss
296
297        # reduce across the batch dimension
298        return loss.div(n_batches)
299
300
301class ExtendedContrastiveLoss(ContrastiveLossBase):
302    """Contrastive loss extended with instance-based loss term and background push term.
303
304    Based on:
305    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings":
306    https://arxiv.org/abs/2103.14572
307
308    Args:
309        delta_var: The hinge distance for the variance term in the discriminative loss.
310        delta_dist: The hinge distance for the distance term in the discriminative loss.
311        norm: The norm to use.
312        alpha: Weight for the variance term of the discriminative loss.
313        beta: Weight for the distance term of the discriminative loss.
314        gamma: Weight for the regularization term of the discriminative loss.
315        unlabeled_push_weight: The weight term for the unlabeled loss term.
316        instance_term_weight: The weight term for the instance loss term.
317        aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
318        pmaps_threshold: The probability threshold of the kernel that maps embedding distances to instance probability maps.
319        kwargs: Additional keyword arguments for other loss terms.
320    """
321    def __init__(
322        self,
323        delta_var: float,
324        delta_dist: float,
325        norm: str = "fro",
326        alpha: float = 1.0,
327        beta: float = 1.0,
328        gamma: float = 0.001,
329        unlabeled_push_weight: float = 1.0,
330        instance_term_weight: float = 1.0,
331        aux_loss: str = "dice",
332        pmaps_threshold: float = 0.9,
333        **kwargs,
334    ):
335        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
336                         unlabeled_push_weight=unlabeled_push_weight,
337                         instance_term_weight=instance_term_weight)
338        # Init auxiliary loss.
339        assert aux_loss in ["dice", "affinity", "dice_aff"]
340        if aux_loss == "dice":
341            self.aff_loss = None
342            self.dice_loss = DiceLoss()
343        # Additional auxiliary losses.
344        elif aux_loss == "affinity":
345            self.aff_loss = AffinitySideLoss(
346                delta=delta_dist,
347                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
348                n_samples=kwargs.get("n_samples", 9)
349            )
350            self.dice_loss = None
351        elif aux_loss == "dice_aff":
352            # combine dice and affinity side loss
353            self.dice_weight = kwargs.get("dice_weight", 1.0)
354            self.aff_weight = kwargs.get("aff_weight", 1.0)
355
356            self.aff_loss = AffinitySideLoss(
357                delta=delta_dist,
358                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
359                n_samples=kwargs.get("n_samples", 9)
360            )
361            self.dice_loss = DiceLoss()
362
363        # Init dist_to_mask kernel which maps distance to the cluster center to instance probability map.
364        self.dist_to_mask = GaussianKernel(delta_var=self.delta_var, pmaps_threshold=pmaps_threshold)
365        self.init_kwargs = {
366            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
367            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
368            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold
369        }
370        self.init_kwargs.update(kwargs)
371
372    # FIXME stacking per instance here makes this very memory hungry,
373    def _create_instance_pmaps_and_masks(self, embeddings, anchors, target):
374        inst_pmaps = []
375        inst_masks = []
376
377        if not inst_masks:
378            return None, None
379
380        # stack along batch dimension
381        inst_pmaps = torch.stack(inst_pmaps)
382        inst_masks = torch.stack(inst_masks)
383
384        return inst_pmaps, inst_masks
385
386    def compute_instance_term(self, embeddings, cluster_means, target):
387        """@private
388        """
389        assert embeddings.size()[1:] == target.size()
390
391        if self.aff_loss is None:
392            aff_loss = None
393        else:
394            aff_loss = self.aff_loss(embeddings[None], target[None, None])
395
396        if self.dice_loss is None:
397            dice_loss = None
398        else:
399            dice_loss = []
400
401            # permute embedding dimension at the end
402            if target.dim() == 2:
403                embeddings = embeddings.permute(1, 2, 0)
404            else:
405                embeddings = embeddings.permute(1, 2, 3, 0)
406
407            # compute random anchors per instance
408            instances = torch.unique(target)
409            for i in instances:
410                if i == 0:
411                    continue
412                anchor_emb = cluster_means[i]
413                # FIXME this makes training extremely slow, check with Adrian if this is the latest version
414                # anchor_emb = select_stable_anchor(embeddings, cluster_means[i], target == i, self.delta_var)
415
416                distance_map = torch.norm(embeddings - anchor_emb, self.norm, dim=-1)
417                instance_pmap = self.dist_to_mask(distance_map).unsqueeze(0)
418                instance_mask = (target == i).float().unsqueeze(0)
419
420                dice_loss.append(self.dice_loss(instance_pmap, instance_mask))
421
422            dice_loss = torch.tensor(dice_loss).to(embeddings.device).mean() if dice_loss else 0.0
423
424        assert not (dice_loss is None and aff_loss is None)
425        if dice_loss is None and aff_loss is not None:
426            return aff_loss
427        if dice_loss is not None and aff_loss is None:
428            return dice_loss
429        else:
430            return self.dice_weight * dice_loss + self.aff_weight * aff_loss
431
432
433class SPOCOLoss(ExtendedContrastiveLoss):
434    """The full SPOCO Loss for instance segmentation training with sparse instance labels.
435
436    Extends the "classic" contrastive loss with an instance-based term and a unsupervised embedding consistency term.
437    An additional background push term can be added. It is disabled by default because we assume sparse instance labels.
438
439    Based on:
440    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings":
441    https://arxiv.org/abs/2103.14572
442
443    Args:
444        delta_var: The hinge distance for the variance term in the discriminative loss.
445        delta_dist: The hinge distance for the distance term in the discriminative loss.
446        norm: The norm to use.
447        alpha: Weight for the variance term of the discriminative loss.
448        beta: Weight for the distance term of the discriminative loss.
449        gamma: Weight for the regularization term of the discriminative loss.
450        unlabeled_push_weight: The weight term for the unlabeled loss term.
451        instance_term_weight: The weight term for the instance loss term.
452        aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
453        pmaps_threshold: The probability threshold of the kernel that maps embedding distances to instance probability maps.
454        max_anchors: The number of anchors to sample for the consistency term.
455        volume_threshold: The minimum fraction of unlabeled pixels in a sample required to compute the consistency term.
456        kwargs: Additional keyword arguments for other loss terms.
457    """
458    def __init__(
459        self,
460        delta_var: float,
461        delta_dist: float,
462        norm: str = "fro",
463        alpha: float = 1.0,
464        beta: float = 1.0,
465        gamma: float = 0.001,
466        unlabeled_push_weight: float = 0.0,
467        instance_term_weight: float = 1.0,
468        consistency_term_weight: float = 1.0,
469        aux_loss: str = "dice",
470        pmaps_threshold: float = 0.9,
471        max_anchors: int = 20,
472        volume_threshold: float = 0.05,
473        **kwargs,
474    ):
475        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
476                         unlabeled_push_weight=unlabeled_push_weight,
477                         instance_term_weight=instance_term_weight,
478                         aux_loss=aux_loss,
479                         pmaps_threshold=pmaps_threshold,
480                         **kwargs)
481
482        self.consistency_term_weight = consistency_term_weight
483        self.max_anchors = max_anchors
484        self.volume_threshold = volume_threshold
485        self.consistency_loss = DiceLoss()
486        self.init_kwargs = {
487            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
488            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
489            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold,
490            "max_anchors": max_anchors, "volume_threshold": volume_threshold
491        }
492        self.init_kwargs.update(kwargs)
493
494    def __str__(self):
495        return super().__str__() + f"\nconsistency_term_weight: {self.consistency_term_weight}"
496
497    def _inst_pmap(self, emb, anchor):
498        # compute distance map
499        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
500        # convert distance map to instance pmaps and return
501        return self.dist_to_mask(distance_map)
502
503    def emb_consistency(self, emb_q, emb_k, mask):
504        """@private
505        """
506        inst_q = []
507        inst_k = []
508        for i in range(self.max_anchors):
509            if mask.sum() < self.volume_threshold * mask.numel():
510                break
511
512            # get random anchor
513            indices = torch.nonzero(mask, as_tuple=True)
514            ind = np.random.randint(len(indices[0]))
515
516            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
517            inst_q.append(q_pmap)
518
519            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
520            inst_k.append(k_pmap)
521
522        # stack along channel dim
523        inst_q = torch.stack(inst_q)
524        inst_k = torch.stack(inst_k)
525
526        loss = self.consistency_loss(inst_q, inst_k)
527        return loss
528
529    def _extract_pmap(self, emb, mask, indices, ind):
530        if mask.dim() == 2:
531            y, x = indices
532            anchor = emb[:, y[ind], x[ind]]
533            emb = emb.permute(1, 2, 0)
534        else:
535            z, y, x = indices
536            anchor = emb[:, z[ind], y[ind], x[ind]]
537            emb = emb.permute(1, 2, 3, 0)
538
539        return self._inst_pmap(emb, anchor)
540
541    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
542        """Compute the SPOCO loss.
543
544        Args:
545            input_: The predicted embeddings.
546            target: The segmentation target.
547
548        Returns:
549            The SPOCO loss.
550        """
551        assert len(input_) == 2
552        emb_q, emb_k = input_
553
554        # Compute extended contrastive loss only on the embeddings coming from q.
555        contrastive_loss = super().forward(emb_q, target)
556
557        # TODO enable computing the consistency on all pixels!
558        # Compute consistency term.
559        for e_q, e_k, t in zip(emb_q, emb_k, target):
560            unlabeled_mask = (t[0] == 0).int()
561            if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel():
562                continue
563            emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask)
564            contrastive_loss += self.consistency_term_weight * emb_consistency_loss
565
566        return contrastive_loss
567
568
569class SPOCOConsistencyLoss(nn.Module):
570    """Unsupervised consistency term computed between embeddings.
571
572    Args:
573        delta_var: The hinge distance of the variance term, used to parameterize the distance-to-probability kernel.
574        pmaps_threshold: The probability threshold of the kernel that maps embedding distances to instance probability maps.
575        max_anchors: The maximum number of anchors to compute for the consistency loss.
576        norm: The vector norm used. By default the frobenius norm is used.
577    """
578    def __init__(self, delta_var: float, pmaps_threshold: float, max_anchors: int = 30, norm: str = "fro"):
579        super().__init__()
580        self.max_anchors = max_anchors
581        self.consistency_loss = DiceLoss()
582        self.norm = norm
583        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
584        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
585                            "max_anchors": max_anchors, "norm": norm}
586
587    def _inst_pmap(self, emb, anchor):
588        # compute distance map
589        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
590        # convert distance map to instance pmaps and return
591        return self.dist_to_mask(distance_map)
592
593    def emb_consistency(self, emb_q, emb_k):
594        """@private
595        """
596        inst_q = []
597        inst_k = []
598        mask = torch.ones(emb_q.shape[1:])
599        for i in range(self.max_anchors):
600            # get random anchor
601            indices = torch.nonzero(mask, as_tuple=True)
602            ind = np.random.randint(len(indices[0]))
603
604            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
605            inst_q.append(q_pmap)
606
607            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
608            inst_k.append(k_pmap)
609
610        # stack along channel dim
611        inst_q = torch.stack(inst_q)
612        inst_k = torch.stack(inst_k)
613
614        loss = self.consistency_loss(inst_q, inst_k)
615        return loss
616
617    def _extract_pmap(self, emb, mask, indices, ind):
618        if mask.dim() == 2:
619            y, x = indices
620            anchor = emb[:, y[ind], x[ind]]
621            emb = emb.permute(1, 2, 0)
622        else:
623            z, y, x = indices
624            anchor = emb[:, z[ind], y[ind], x[ind]]
625            emb = emb.permute(1, 2, 3, 0)
626
627        return self._inst_pmap(emb, anchor)
628
629    def forward(self, emb_q: torch.Tensor, emb_k: torch.Tensor) -> torch.Tensor:
630        """Compute the consistency loss term between embeddings.
631
632        Args:
633            emb_q: The first embedding predictions.
634            emb_k: The second embedding predictions.
635
636        Returns:
637            The consistency loss.
638        """
639        contrastive_loss = 0.0
640        # compute consistency term
641        for e_q, e_k in zip(emb_q, emb_k):
642            contrastive_loss += self.emb_consistency(e_q, e_k)
643        return contrastive_loss
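
The GaussianKernel defined above turns distances to a cluster mean into instance probabilities. A quick numerical check of its parameterization (a minimal sketch, not part of the module): two_sigma is chosen as delta_var**2 / (-log(pmaps_threshold)), so a distance of exactly delta_var maps to pmaps_threshold.

    import torch
    from torch_em.loss.spoco_loss import GaussianKernel

    delta_var, pmaps_threshold = 0.5, 0.9
    kernel = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
    distances = torch.tensor([0.0, delta_var, 2 * delta_var])
    # -> approximately [1.0, 0.9, 0.656]: zero distance gives probability 1,
    # the hinge distance delta_var gives exactly pmaps_threshold, and larger distances decay towards 0.
    print(kernel(distances))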
def compute_cluster_means( embeddings: torch.Tensor, target: torch.Tensor, n_instances: int) -> torch.Tensor:

Compute mean embeddings per instance.

Arguments:
  • embeddings: The tensor of pixel embeddings with shape: ExSPATIAL. E is the embedding dimension.
  • target: The target instance labels with shape: SPATIAL, given as integer instance ids.
  • n_instances: The number of instances.
Returns:

The cluster means.
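
A minimal usage sketch (illustrative shapes; torch_scatter must be installed):

    import torch
    from torch_em.loss.spoco_loss import compute_cluster_means

    embeddings = torch.randn(8, 4, 4)     # E=8 embedding channels for a 4x4 image
    target = torch.randint(0, 3, (4, 4))  # integer instance ids 0, 1, 2
    target[0, 0] = 0                      # the minimum id must be 0 (see the assert in the source)
    means = compute_cluster_means(embeddings, target, n_instances=3)
    print(means.shape)                    # torch.Size([3, 8]), i.e. CxE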

def select_stable_anchor( embeddings: torch.Tensor, mean_embedding: torch.Tensor, object_mask: torch.Tensor, delta_var: float, norm: str = 'fro') -> torch.Tensor:

Sample anchor embeddings from the object mask.

Given a binary mask of an object (object_mask) and a mean_embedding vector within the mask, the function selects a pixel from the mask at random and returns its embedding only if it's closer than delta_var from the mean_embedding.

Arguments:
  • embeddings: The embeddings, a ExSPATIAL vector field of an image.
  • mean_embedding: The E-dimensional mean of embeddings lying within the object_mask.
  • object_mask: Binary image of a selected object.
  • delta_var: The pull force margin of the contrastive loss.
  • norm: The vector norm used. By default the frobenius norm is used.
Returns:

Embedding of a selected pixel within the mask or the mean embedding if no stable anchor could be found.
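
A minimal usage sketch (illustrative shapes; passing the mean embedding with singleton spatial axes so it broadcasts against the anchor is an assumption of this example). If none of the masked pixels lies within delta_var of the mean, the mean embedding itself is returned.

    import torch
    from torch_em.loss.spoco_loss import compute_cluster_means, select_stable_anchor

    embeddings = torch.randn(8, 4, 4)
    target = torch.zeros(4, 4, dtype=torch.int64)
    target[1:3, 1:3] = 1                      # a single foreground object with id 1
    means = compute_cluster_means(embeddings, target, n_instances=2)
    mean_embedding = means[1][:, None, None]  # shape Ex1x1
    anchor = select_stable_anchor(embeddings, mean_embedding, target == 1, delta_var=0.5)
    print(anchor.shape)                       # torch.Size([8, 1, 1])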

class ExtendedContrastiveLoss(ContrastiveLossBase):

Contrastive loss extended with an instance-based loss term and a background push term.

Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572

Arguments:
  • delta_var: The hinge distance for the variance term in the discriminative loss.
  • delta_dist: The hinge distance for the distance term in the discriminative loss.
  • norm: The norm to use.
  • alpha: Weight for the variance term of the discriminative loss.
  • beta: Weight for the distance term of the discriminative loss.
  • gamma: Weight for the regularization term of the discriminative loss.
  • unlabeled_push_weight: The weight term for the unlabeled loss term.
  • instance_term_weight: The weight term for the instance loss term.
  • aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
  • pmaps_threshold: The probability threshold of the kernel that maps embedding distances to instance probability maps.
  • kwargs: Additional keyword arguments for other loss terms.
ExtendedContrastiveLoss( delta_var: float, delta_dist: float, norm: str = 'fro', alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.001, unlabeled_push_weight: float = 1.0, instance_term_weight: float = 1.0, aux_loss: str = 'dice', pmaps_threshold: float = 0.9, **kwargs)

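A minimal usage sketch for ExtendedContrastiveLoss (hedged: the shapes are illustrative and torch_scatter is required):

    import torch
    from torch_em.loss.spoco_loss import ExtendedContrastiveLoss

    loss_fn = ExtendedContrastiveLoss(delta_var=0.75, delta_dist=2.0, aux_loss="dice")
    pred = torch.randn(2, 8, 32, 32, requires_grad=True)  # NxExHxW embeddings
    target = torch.randint(0, 4, (2, 1, 32, 32))          # Nx1xHxW integer instance ids
    target[:, :, 0, 0] = 0                                # id 0 denotes background / unlabeled pixels
    loss = loss_fn(pred, target)
    loss.backward()
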
class SPOCOLoss(ExtendedContrastiveLoss):

The full SPOCO Loss for instance segmentation training with sparse instance labels.

Extends the "classic" contrastive loss with an instance-based term and a unsupervised embedding consistency term. An additional background push term can be added. It is disabled by default because we assume sparse instance labels.

Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572

Arguments:
  • delta_var: The hinge distance for the variance term in the discriminative loss.
  • delta_dist: The hinge distance for the distance term in the discriminative loss.
  • norm: The norm to use.
  • alpha: Weight for the variance term of the discriminative loss.
  • beta: Weight for the distance term of the discriminative loss.
  • gamma: Weight for the regularization term of the discriminative loss.
  • unlabeled_push_weight: The weight term for the unlabeled loss term.
  • instance_term_weight: The weight term for the instance loss term.
  • aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
  • pmaps_threshold: The probability threshold of the kernel that maps embedding distances to instance probability maps.
  • max_anchors: The number of anchors to sample for the consistency term.
  • volume_threshold: The minimum fraction of unlabeled pixels in a sample required to compute the consistency term.
  • kwargs: Additional keyword arguments for other loss terms.
SPOCOLoss( delta_var: float, delta_dist: float, norm: str = 'fro', alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.001, unlabeled_push_weight: float = 0.0, instance_term_weight: float = 1.0, consistency_term_weight: float = 1.0, aux_loss: str = 'dice', pmaps_threshold: float = 0.9, max_anchors: int = 20, volume_threshold: float = 0.05, **kwargs)

def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

Compute the SPOCO loss.

Arguments:
  • input_: The predicted embeddings.
  • target: The segmentation target.
Returns:

The SPOCO loss.
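
A minimal usage sketch for SPOCOLoss (hedged: shapes are illustrative and torch_scatter is required). The loss expects a pair of embedding predictions, e.g. from the trained model and its momentum copy:

    import torch
    from torch_em.loss.spoco_loss import SPOCOLoss

    loss_fn = SPOCOLoss(delta_var=0.75, delta_dist=2.0)
    emb_q = torch.randn(2, 8, 32, 32, requires_grad=True)  # embeddings of the trained model
    emb_k = torch.randn(2, 8, 32, 32)                       # embeddings of the second (momentum) model
    target = torch.randint(0, 4, (2, 1, 32, 32))            # sparse instance labels, 0 = unlabeled
    target[:, :, 0, 0] = 0
    loss = loss_fn((emb_q, emb_k), target)
    loss.backward()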

class SPOCOConsistencyLoss(torch.nn.modules.module.Module):

Unsupervised consistency term computed between embeddings.

Arguments:
  • delta_var: The hinge distance of the variance term, used to parameterize the distance-to-probability kernel.
  • pmaps_threshold: The probability threshold of the kernel that maps embedding distances to instance probability maps.
  • max_anchors: The maximum number of anchors to compute for the consistency loss.
  • norm: The vector norm used. By default the frobenius norm is used.
SPOCOConsistencyLoss( delta_var: float, pmaps_threshold: float, max_anchors: int = 30, norm: str = 'fro')

def forward(self, emb_q: torch.Tensor, emb_k: torch.Tensor) -> torch.Tensor:

Compute the consistency loss term between embeddings.

Arguments:
  • emb_q: The first embedding predictions.
  • emb_k: The second embedding predictions.
Returns:

The consistency loss.
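
A minimal usage sketch (hedged; shapes are illustrative):

    import torch
    from torch_em.loss.spoco_loss import SPOCOConsistencyLoss

    loss_fn = SPOCOConsistencyLoss(delta_var=0.75, pmaps_threshold=0.9)
    emb_q = torch.randn(2, 8, 32, 32, requires_grad=True)
    emb_k = torch.randn(2, 8, 32, 32)
    # anchors are sampled over all pixels, so no segmentation target is needed
    loss = loss_fn(emb_q, emb_k)
    loss.backward()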