torch_em.loss.spoco_loss

import math

import numpy as np
import torch
import torch.nn as nn
try:
    from torch_scatter import scatter_mean
except ImportError:
    scatter_mean = None

from . import contrastive_impl as cimpl
from .affinity_side_loss import AffinitySideLoss
from .dice import DiceLoss

def compute_cluster_means(embeddings, target, n_instances):
    """
    Computes mean embeddings per instance.
    E - embedding dimension

    Args:
        embeddings: tensor of pixel embeddings, shape: ExSPATIAL
        target: instance label tensor (integer instance ids), shape: SPATIAL
        n_instances: number of instances
    """
    assert scatter_mean is not None, "torch_scatter is required"
    embeddings = embeddings.flatten(1)
    target = target.flatten()
    assert target.min() == 0, \
        "The target min value has to be zero, otherwise this will lead to errors in scatter"
    mean_embeddings = scatter_mean(embeddings, target, dim_size=n_instances)
    return mean_embeddings.transpose(1, 0)


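A minimal usage sketch (tensor sizes are illustrative and torch_scatter must be installed): eight-dimensional embeddings for a 32x32 image whose instance ids are contiguous and start at 0.

import torch

embeddings = torch.randn(8, 32, 32)            # E x SPATIAL
target = torch.randint(0, 3, (32, 32))         # instance ids 0, 1, 2
cluster_means = compute_cluster_means(embeddings, target, n_instances=3)
print(cluster_means.shape)                     # torch.Size([3, 8]), i.e. C x E
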
def select_stable_anchor(embeddings, mean_embedding, object_mask, delta_var, norm="fro"):
    """
    Anchor sampling procedure. Given a binary mask of an object (`object_mask`) and a `mean_embedding` vector within
    the mask, the function selects a pixel from the mask at random and returns its embedding only if it's closer than
    `delta_var` from the `mean_embedding`.

    Args:
        embeddings (torch.Tensor): ExSPATIAL vector field of an image
        mean_embedding (torch.Tensor): E-dimensional mean of embeddings lying within the `object_mask`
        object_mask (torch.Tensor): binary image of a selected object
        delta_var (float): contrastive loss, pull force margin
        norm (str): vector norm used, default: Frobenius norm

    Returns:
        embedding of a selected pixel within the mask, or the mean embedding if no stable anchor could be found
    """
    indices = torch.nonzero(object_mask, as_tuple=True)
    # convert to numpy
    indices = [t.cpu().numpy() for t in indices]

    # randomize coordinates
    seed = np.random.randint(np.iinfo("int32").max)
    for t in indices:
        rs = np.random.RandomState(seed)
        rs.shuffle(t)

    for ind in range(len(indices[0])):
        if object_mask.dim() == 2:
            y, x = indices
            anchor_emb = embeddings[:, y[ind], x[ind]]
            anchor_emb = anchor_emb[..., None, None]
        else:
            z, y, x = indices
            anchor_emb = embeddings[:, z[ind], y[ind], x[ind]]
            anchor_emb = anchor_emb[..., None, None, None]
        dist_to_mean = torch.norm(mean_embedding - anchor_emb, norm)
        if dist_to_mean < delta_var:
            return anchor_emb
    # if no stable anchor has been found, return the mean_embedding
    return mean_embedding


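A small sketch of the anchor sampling (shapes and the margin are chosen for illustration): try to draw an anchor pixel for a single object and fall back to the mean embedding if no pixel lies within delta_var of it.

import torch

emb = torch.randn(8, 32, 32)                   # E x SPATIAL embeddings
mask = torch.zeros(32, 32, dtype=torch.bool)   # binary mask of one object
mask[10:20, 10:20] = True
mean_emb = emb[:, mask].mean(dim=1)[:, None, None]
anchor = select_stable_anchor(emb, mean_emb, mask, delta_var=0.5)
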
class GaussianKernel(nn.Module):
    def __init__(self, delta_var, pmaps_threshold):
        super().__init__()
        self.delta_var = delta_var
        # delta_var^2 = -2*sigma*ln(pmaps_threshold)
        self.two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))

    def forward(self, dist_map):
        return torch.exp(- dist_map * dist_map / self.two_sigma)


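A quick numerical check of the kernel (values are approximate): a distance of zero maps to probability 1, and a distance of exactly delta_var maps to pmaps_threshold.

import torch

kernel = GaussianKernel(delta_var=0.5, pmaps_threshold=0.9)
dist_map = torch.tensor([0.0, 0.5, 1.5])
print(kernel(dist_map))  # approximately [1.00, 0.90, 0.39]
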
class CombinedAuxLoss(nn.Module):
    def __init__(self, losses, weights):
        super().__init__()
        self.losses = losses
        self.weights = weights

    def forward(self, embeddings, target, instance_pmaps, instance_masks):
        result = 0.
        for loss, weight in zip(self.losses, self.weights):
            if isinstance(loss, AffinitySideLoss):
                # add batch axis / batch and channel axis for embeddings, target
                result += weight * loss(embeddings[None], target[None, None])
            elif instance_masks is not None:
                result += weight * loss(instance_pmaps, instance_masks).mean()
        return result


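A hedged construction sketch (the weights and offset ranges are illustrative): combine an affinity side loss, which consumes the raw embeddings and target, with a dice loss, which consumes precomputed instance probability maps and masks.

combined_aux = CombinedAuxLoss(
    losses=[
        AffinitySideLoss(delta=2.0, offset_ranges=[(-18, 18), (-18, 18)], n_samples=9),
        DiceLoss(),
    ],
    weights=[1.0, 1.0],
)
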
class ContrastiveLossBase(nn.Module):
    """Base class for the spoco losses.
    """

    def __init__(self, delta_var, delta_dist,
                 norm="fro", alpha=1., beta=1., gamma=0.001, unlabeled_push_weight=0.0,
                 instance_term_weight=1.0, impl=None):
        assert scatter_mean is not None, "The spoco loss requires torch_scatter"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.unlabeled_push_weight = unlabeled_push_weight
        self.unlabeled_push = unlabeled_push_weight > 0
        self.instance_term_weight = instance_term_weight

    def __str__(self):
        return super().__str__() + f"\ndelta_var: {self.delta_var}\ndelta_dist: {self.delta_dist}" \
                                   f"\nalpha: {self.alpha}\nbeta: {self.beta}\ngamma: {self.gamma}" \
                                   f"\nunlabeled_push_weight: {self.unlabeled_push_weight}" \
                                   f"\ninstance_term_weight: {self.instance_term_weight}"

    def _compute_variance_term(self, cluster_means, embeddings, target, instance_counts, ignore_zero_label):
        """Computes the variance term, i.e. the intra-cluster pull force that draws embeddings towards the mean embedding.

        C - number of clusters (instances)
        E - embedding dimension
        SPATIAL - volume shape, i.e. DxHxW for 3D / HxW for 2D

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE)
            embeddings: embedding vectors per pixel, tensor (ExSPATIAL)
            target: label tensor (SPATIAL) with integer instance ids
            instance_counts: number of voxels per instance
            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
        """
        assert target.dim() in (2, 3)
        ignore_labels = [0] if ignore_zero_label else None
        return cimpl._compute_variance_term_scatter(
            cluster_means, embeddings.unsqueeze(0), target.unsqueeze(0),
            self.norm, self.delta_var, instance_counts, ignore_labels
        )

151    def _compute_unlabeled_push(self, cluster_means, embeddings, target):
152        assert target.dim() in (2, 3)
153        n_instances = cluster_means.shape[0]
154
155        # permute embedding dimension at the end
156        if target.dim() == 2:
157            embeddings = embeddings.permute(1, 2, 0)
158        else:
159            embeddings = embeddings.permute(1, 2, 3, 0)
160
161        # decrease number of instances `C` since we're ignoring 0-label
162        n_instances -= 1
163        # if there is only 0-label in the target return 0
164        if n_instances == 0:
165            return 0.0
166
167        background_mask = target == 0
168        n_background = background_mask.sum()
169        background_push = 0.0
170        # skip embedding corresponding to the background pixels
171        for cluster_mean in cluster_means[1:]:
172            # compute distances between embeddings and a given cluster_mean
173            dist_to_mean = torch.norm(embeddings - cluster_mean, self.norm, dim=-1)
174            # apply background mask and compute hinge
175            dist_hinged = torch.clamp((self.delta_dist - dist_to_mean) * background_mask, min=0) ** 2
176            background_push += torch.sum(dist_hinged) / n_background
177
178        # normalize by the number of instances
179        return background_push / n_instances
180
181    # def _compute_distance_term_scatter(cluster_means, norm, delta_dist):
182    def _compute_distance_term(self, cluster_means, ignore_zero_label):
183        """
184        Compute the distance term, i.e an inter-cluster push-force that pushes clusters away from each other, increasing
185        the distance between cluster centers
186
187        Args:
188            cluster_means: mean embedding of each instance, tensor (CxE)
189            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
190        """
191        ignore_labels = [0] if ignore_zero_label else None
192        return cimpl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist, ignore_labels)
193
194    def _compute_regularizer_term(self, cluster_means):
195        """
196        Computes the regularizer term, i.e. a small pull-force that draws all clusters towards origin to keep
197        the network activations bounded
198        """
199        # compute the norm of the mean embeddings
200        norms = torch.norm(cluster_means, p=self.norm, dim=1)
201        # return the average norm per batch
202        return torch.sum(norms) / cluster_means.size(0)
203
204    def compute_instance_term(self, embeddings, cluster_means, target):
205        """Computes auxiliary loss based on embeddings and a given list of target
206        instances together with their mean embeddings.
207
208        Args:
209            embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
210            cluster_means (torch.tensor): mean embeddings per instance (CxExSINGLETON_SPATIAL)
211            target (torch.tensor): ground truth instance segmentation (SPATIAL)
212
213        Returns:
214            float: value of the instance-based term
215        """
216        raise NotImplementedError
217
218    def forward(self, input_, target):
219        """
220        Args:
221             input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
222                expects float32 tensor
223             target (torch.tensor): ground truth instance segmentation (Nx1DxHxW)
224                expects int64 tensor
225        Returns:
226            Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
227                + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
228        """
229        # enable calling this loss from the spoco trainer, which passes a tuple
230        if isinstance(input_, tuple):
231            assert len(input_) == 2
232            input_ = input_[0]
233
234        n_batches = input_.shape[0]
235        # compute the loss per each instance in the batch separately
236        # and sum it up in the per_instance variable
237        loss = 0.0
238        for single_input, single_target in zip(input_, target):
239            # compare spatial dimensions
240            assert single_input.shape[1:] == single_target.shape[1:], f"{single_input.shape}, {single_target.shape}"
241            assert single_target.shape[0] == 1
242            single_target = single_target[0]
243
244            contains_bg = 0 in single_target
245            ignore_zero_label = self.unlabeled_push and contains_bg
246
247            # get number of instances in the batch instance
248            instance_ids, instance_counts = torch.unique(single_target, return_counts=True)
249
250            # get the number of instances
251            C = instance_ids.size(0)
252
253            # compute mean embeddings (output is of shape CxE)
254            cluster_means = compute_cluster_means(single_input, single_target, C)
255
256            # compute variance term, i.e. pull force
257            variance_term = self._compute_variance_term(
258                cluster_means, single_input, single_target, instance_counts, ignore_zero_label
259            )
260
261            # compute unlabeled push force, i.e. push force between
262            # the mean cluster embeddings and embeddings of background pixels
263            # compute only ignore_zero_label is True, i.e. a given patch contains background label
264            unlabeled_push_term = 0.0
265            if self.unlabeled_push and contains_bg:
266                unlabeled_push_term = self._compute_unlabeled_push(cluster_means, single_input, single_target)
267
268            # compute the instance-based auxiliary loss
269            instance_term = self.compute_instance_term(single_input, cluster_means, single_target)
270
271            # compute distance term, i.e. push force
272            distance_term = self._compute_distance_term(cluster_means, ignore_zero_label)
273
274            # compute regularization term
275            regularization_term = self._compute_regularizer_term(cluster_means)
276
277            # compute total loss and sum it up
278            loss = self.alpha * variance_term + \
279                self.beta * distance_term + \
280                self.gamma * regularization_term + \
281                self.instance_term_weight * instance_term + \
282                self.unlabeled_push_weight * unlabeled_push_term
283
284            loss += loss
285
286        # reduce across the batch dimension
287        return loss.div(n_batches)
288
289
class ExtendedContrastiveLoss(ContrastiveLossBase):
    """Contrastive loss extended with an instance-based loss term and a background push term.

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
    """

    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
                 unlabeled_push_weight=1.0, instance_term_weight=1.0, aux_loss="dice", pmaps_threshold=0.9, **kwargs):

        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight)

        # init the auxiliary loss
        assert aux_loss in ["dice", "affinity", "dice_aff"]
        if aux_loss == "dice":
            self.aff_loss = None
            self.dice_loss = DiceLoss()
        # additional auxiliary losses
        elif aux_loss == "affinity":
            self.aff_loss = AffinitySideLoss(
                delta=delta_dist,
                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
                n_samples=kwargs.get("n_samples", 9)
            )
            self.dice_loss = None
        elif aux_loss == "dice_aff":
            # combine dice and affinity side loss
            self.dice_weight = kwargs.get("dice_weight", 1.0)
            self.aff_weight = kwargs.get("aff_weight", 1.0)

            self.aff_loss = AffinitySideLoss(
                delta=delta_dist,
                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
                n_samples=kwargs.get("n_samples", 9)
            )
            self.dice_loss = DiceLoss()

        # init the dist_to_mask kernel, which maps the distance to the cluster center to an instance probability map
        self.dist_to_mask = GaussianKernel(delta_var=self.delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold
        }
        self.init_kwargs.update(kwargs)

338    # FIXME stacking per instance here makes this very memory hungry,
339    def _create_instance_pmaps_and_masks(self, embeddings, anchors, target):
340        inst_pmaps = []
341        inst_masks = []
342
343        if not inst_masks:
344            return None, None
345
346        # stack along batch dimension
347        inst_pmaps = torch.stack(inst_pmaps)
348        inst_masks = torch.stack(inst_masks)
349
350        return inst_pmaps, inst_masks
351
352    def compute_instance_term(self, embeddings, cluster_means, target):
353        assert embeddings.size()[1:] == target.size()
354
355        if self.aff_loss is None:
356            aff_loss = None
357        else:
358            aff_loss = self.aff_loss(embeddings[None], target[None, None])
359
360        if self.dice_loss is None:
361            dice_loss = None
362        else:
363            dice_loss = []
364
365            # permute embedding dimension at the end
366            if target.dim() == 2:
367                embeddings = embeddings.permute(1, 2, 0)
368            else:
369                embeddings = embeddings.permute(1, 2, 3, 0)
370
371            # compute random anchors per instance
372            instances = torch.unique(target)
373            for i in instances:
374                if i == 0:
375                    continue
376                anchor_emb = cluster_means[i]
377                # FIXME this makes training extremely slow, check with Adrian if this is the latest version
378                # anchor_emb = select_stable_anchor(embeddings, cluster_means[i], target == i, self.delta_var)
379
380                distance_map = torch.norm(embeddings - anchor_emb, self.norm, dim=-1)
381                instance_pmap = self.dist_to_mask(distance_map).unsqueeze(0)
382                instance_mask = (target == i).float().unsqueeze(0)
383
384                dice_loss.append(self.dice_loss(instance_pmap, instance_mask))
385
386            dice_loss = torch.tensor(dice_loss).to(embeddings.device).mean() if dice_loss else 0.0
387
388        assert not (dice_loss is None and aff_loss is None)
389        if dice_loss is None and aff_loss is not None:
390            return aff_loss
391        if dice_loss is not None and aff_loss is None:
392            return dice_loss
393        else:
394            return self.dice_weight * dice_loss + self.aff_weight * aff_loss
395
396
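A minimal end-to-end sketch (hyperparameters and shapes are illustrative, torch_scatter is required): a batch of one 2D embedding prediction and a toy label image with background (id 0) and two instances.

import torch

loss_fn = ExtendedContrastiveLoss(delta_var=0.5, delta_dist=2.0, aux_loss="dice")
pred = torch.randn(1, 8, 64, 64, requires_grad=True)      # N x E x H x W
labels = torch.zeros(1, 1, 64, 64, dtype=torch.int64)     # N x 1 x H x W
labels[0, 0, 5:25, 5:25] = 1
labels[0, 0, 35:60, 35:60] = 2
loss = loss_fn(pred, labels)
loss.backward()
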
class SPOCOLoss(ExtendedContrastiveLoss):
    """The full SPOCO loss for instance segmentation training with sparse instance labels.

    Extends the "classic" contrastive loss with an instance-based term and an embedding consistency term.
    (The unlabeled push term is turned off by default, since we assume sparse instance labels.)

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
    """

    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
                 unlabeled_push_weight=0.0, instance_term_weight=1.0, consistency_term_weight=1.0,
                 aux_loss="dice", pmaps_threshold=0.9, max_anchors=20, volume_threshold=0.05, **kwargs):

        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight,
                         aux_loss=aux_loss,
                         pmaps_threshold=pmaps_threshold,
                         **kwargs)

        self.consistency_term_weight = consistency_term_weight
        self.max_anchors = max_anchors
        self.volume_threshold = volume_threshold
        self.consistency_loss = DiceLoss()
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold,
            "max_anchors": max_anchors, "volume_threshold": volume_threshold
        }
        self.init_kwargs.update(kwargs)

430    def __str__(self):
431        return super().__str__() + f"\nconsistency_term_weight: {self.consistency_term_weight}"
432
433    def _inst_pmap(self, emb, anchor):
434        # compute distance map
435        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
436        # convert distance map to instance pmaps and return
437        return self.dist_to_mask(distance_map)
438
439    def emb_consistency(self, emb_q, emb_k, mask):
440        inst_q = []
441        inst_k = []
442        for i in range(self.max_anchors):
443            if mask.sum() < self.volume_threshold * mask.numel():
444                break
445
446            # get random anchor
447            indices = torch.nonzero(mask, as_tuple=True)
448            ind = np.random.randint(len(indices[0]))
449
450            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
451            inst_q.append(q_pmap)
452
453            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
454            inst_k.append(k_pmap)
455
456        # stack along channel dim
457        inst_q = torch.stack(inst_q)
458        inst_k = torch.stack(inst_k)
459
460        loss = self.consistency_loss(inst_q, inst_k)
461        return loss
462
463    def _extract_pmap(self, emb, mask, indices, ind):
464        if mask.dim() == 2:
465            y, x = indices
466            anchor = emb[:, y[ind], x[ind]]
467            emb = emb.permute(1, 2, 0)
468        else:
469            z, y, x = indices
470            anchor = emb[:, z[ind], y[ind], x[ind]]
471            emb = emb.permute(1, 2, 3, 0)
472
473        return self._inst_pmap(emb, anchor)
474
475    def forward(self, input, target):
476        assert len(input) == 2
477        emb_q, emb_k = input
478
479        # compute extended contrastive loss only on the embeddings coming from q
480        contrastive_loss = super().forward(emb_q, target)
481
482        # TODO enable computing the consistency on all pixels!
483        # compute consistency term
484        for e_q, e_k, t in zip(emb_q, emb_k, target):
485            unlabeled_mask = (t[0] == 0).int()
486            if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel():
487                continue
488            emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask)
489            contrastive_loss += self.consistency_term_weight * emb_consistency_loss
490
491        return contrastive_loss
492
493
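A minimal sketch of the full SPOCO loss (shapes are illustrative): emb_q and emb_k are the embeddings predicted by the two model branches, and the sparse target labels only one instance while the remaining pixels stay 0 (unlabeled).

import torch

spoco_loss = SPOCOLoss(delta_var=0.5, delta_dist=2.0)
emb_q = torch.randn(1, 8, 64, 64, requires_grad=True)
emb_k = torch.randn(1, 8, 64, 64)
labels = torch.zeros(1, 1, 64, 64, dtype=torch.int64)
labels[0, 0, 5:25, 5:25] = 1
loss = spoco_loss((emb_q, emb_k), labels)
loss.backward()
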
# FIXME clarify what this is!
class SPOCOConsistencyLoss(nn.Module):
    def __init__(self, delta_var, pmaps_threshold, max_anchors=30, norm="fro"):
        super().__init__()
        self.max_anchors = max_anchors
        self.consistency_loss = DiceLoss()
        self.norm = norm
        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
                            "max_anchors": max_anchors, "norm": norm}

    def _inst_pmap(self, emb, anchor):
        # compute distance map
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert distance map to instance pmaps and return
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k):
        inst_q = []
        inst_k = []
        mask = torch.ones(emb_q.shape[1:])
        for i in range(self.max_anchors):
            # get random anchor
            indices = torch.nonzero(mask, as_tuple=True)
            ind = np.random.randint(len(indices[0]))

            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
            inst_q.append(q_pmap)

            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
            inst_k.append(k_pmap)

        # stack along channel dim
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        loss = self.consistency_loss(inst_q, inst_k)
        return loss

    def _extract_pmap(self, emb, mask, indices, ind):
        if mask.dim() == 2:
            y, x = indices
            anchor = emb[:, y[ind], x[ind]]
            emb = emb.permute(1, 2, 0)
        else:
            z, y, x = indices
            anchor = emb[:, z[ind], y[ind], x[ind]]
            emb = emb.permute(1, 2, 3, 0)

        return self._inst_pmap(emb, anchor)

    def forward(self, emb_q, emb_k):
        contrastive_loss = 0.0
        # compute consistency term
        for e_q, e_k in zip(emb_q, emb_k):
            contrastive_loss += self.emb_consistency(e_q, e_k)
        return contrastive_loss
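
An illustrative sketch of the pure consistency loss (no labels involved): it compares instance probability maps derived from randomly sampled anchors in the two embedding predictions.

import torch

consistency = SPOCOConsistencyLoss(delta_var=0.75, pmaps_threshold=0.9)
emb_q = torch.randn(2, 8, 32, 32)
emb_k = torch.randn(2, 8, 32, 32)
loss = consistency(emb_q, emb_k)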
def compute_cluster_means(embeddings, target, n_instances):
17def compute_cluster_means(embeddings, target, n_instances):
18    """
19    Computes mean embeddings per instance.
20    E - embedding dimension
21
22    Args:
23        embeddings: tensor of pixel embeddings, shape: ExSPATIAL
24        target: one-hot encoded target instances, shape: SPATIAL
25        n_instances: number of instances
26    """
27    assert scatter_mean is not None, "torch_scatter is required"
28    embeddings = embeddings.flatten(1)
29    target = target.flatten()
30    assert target.min() == 0,\
31        "The target min value has to be zero, otherwise this will lead to errors in scatter"
32    mean_embeddings = scatter_mean(embeddings, target, dim_size=n_instances)
33    return mean_embeddings.transpose(1, 0)

Computes mean embeddings per instance. E - embedding dimension

Arguments:
  • embeddings: tensor of pixel embeddings, shape: ExSPATIAL
  • target: one-hot encoded target instances, shape: SPATIAL
  • n_instances: number of instances
def select_stable_anchor(embeddings, mean_embedding, object_mask, delta_var, norm='fro'):
36def select_stable_anchor(embeddings, mean_embedding, object_mask, delta_var, norm="fro"):
37    """
38    Anchor sampling procedure. Given a binary mask of an object (`object_mask`) and a `mean_embedding` vector within
39    the mask, the function selects a pixel from the mask at random and returns its embedding only if it"s closer than
40    `delta_var` from the `mean_embedding`.
41
42    Args:
43        embeddings (torch.Tensor): ExSpatial vector field of an image
44        mean_embedding (torch.Tensor): E-dimensional mean of embeddings lying within the `object_mask`
45        object_mask (torch.Tensor): binary image of a selected object
46        delta_var (float): contrastive loss, pull force margin
47        norm (str): vector norm used, default: Frobenius norm
48
49    Returns:
50        embedding of a selected pixel within the mask or the mean embedding if stable anchor could be found
51    """
52    indices = torch.nonzero(object_mask, as_tuple=True)
53    # convert to numpy
54    indices = [t.cpu().numpy() for t in indices]
55
56    # randomize coordinates
57    seed = np.random.randint(np.iinfo("int32").max)
58    for t in indices:
59        rs = np.random.RandomState(seed)
60        rs.shuffle(t)
61
62    for ind in range(len(indices[0])):
63        if object_mask.dim() == 2:
64            y, x = indices
65            anchor_emb = embeddings[:, y[ind], x[ind]]
66            anchor_emb = anchor_emb[..., None, None]
67        else:
68            z, y, x = indices
69            anchor_emb = embeddings[:, z[ind], y[ind], x[ind]]
70            anchor_emb = anchor_emb[..., None, None, None]
71        dist_to_mean = torch.norm(mean_embedding - anchor_emb, norm)
72        if dist_to_mean < delta_var:
73            return anchor_emb
74    # if stable anchor has not been found, return mean_embedding
75    return mean_embedding

Anchor sampling procedure. Given a binary mask of an object (object_mask) and a mean_embedding vector within the mask, the function selects a pixel from the mask at random and returns its embedding only if it"s closer than delta_var from the mean_embedding.

Arguments:
  • embeddings (torch.Tensor): ExSpatial vector field of an image
  • mean_embedding (torch.Tensor): E-dimensional mean of embeddings lying within the object_mask
  • object_mask (torch.Tensor): binary image of a selected object
  • delta_var (float): contrastive loss, pull force margin
  • norm (str): vector norm used, default: Frobenius norm
Returns:

embedding of a selected pixel within the mask or the mean embedding if stable anchor could be found

class GaussianKernel(torch.nn.modules.module.Module):
78class GaussianKernel(nn.Module):
79    def __init__(self, delta_var, pmaps_threshold):
80        super().__init__()
81        self.delta_var = delta_var
82        # dist_var^2 = -2*sigma*ln(pmaps_threshold)
83        self.two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))
84
85    def forward(self, dist_map):
86        return torch.exp(- dist_map * dist_map / self.two_sigma)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

GaussianKernel(delta_var, pmaps_threshold)
79    def __init__(self, delta_var, pmaps_threshold):
80        super().__init__()
81        self.delta_var = delta_var
82        # dist_var^2 = -2*sigma*ln(pmaps_threshold)
83        self.two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))

Initializes internal Module state, shared by both nn.Module and ScriptModule.

delta_var
two_sigma
def forward(self, dist_map):
85    def forward(self, dist_map):
86        return torch.exp(- dist_map * dist_map / self.two_sigma)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class CombinedAuxLoss(torch.nn.modules.module.Module):
 89class CombinedAuxLoss(nn.Module):
 90    def __init__(self, losses, weights):
 91        super().__init__()
 92        self.losses = losses
 93        self.weights = weights
 94
 95    def forward(self, embeddings, target, instance_pmaps, instance_masks):
 96        result = 0.
 97        for loss, weight in zip(self.losses, self.weights):
 98            if isinstance(loss, AffinitySideLoss):
 99                # add batch axis / batch and channel axis for embeddings, target
100                result += weight * loss(embeddings[None], target[None, None])
101            elif instance_masks is not None:
102                result += weight * loss(instance_pmaps, instance_masks).mean()
103        return result

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

CombinedAuxLoss(losses, weights)
90    def __init__(self, losses, weights):
91        super().__init__()
92        self.losses = losses
93        self.weights = weights

Initializes internal Module state, shared by both nn.Module and ScriptModule.

losses
weights
def forward(self, embeddings, target, instance_pmaps, instance_masks):
 95    def forward(self, embeddings, target, instance_pmaps, instance_masks):
 96        result = 0.
 97        for loss, weight in zip(self.losses, self.weights):
 98            if isinstance(loss, AffinitySideLoss):
 99                # add batch axis / batch and channel axis for embeddings, target
100                result += weight * loss(embeddings[None], target[None, None])
101            elif instance_masks is not None:
102                result += weight * loss(instance_pmaps, instance_masks).mean()
103        return result

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class ContrastiveLossBase(torch.nn.modules.module.Module):
106class ContrastiveLossBase(nn.Module):
107    """Base class for the spoco losses.
108    """
109
110    def __init__(self, delta_var, delta_dist,
111                 norm="fro", alpha=1., beta=1., gamma=0.001, unlabeled_push_weight=0.0,
112                 instance_term_weight=1.0, impl=None):
113        assert scatter_mean is not None, "Spoco loss requires pytorch_scatter"
114        super().__init__()
115        self.delta_var = delta_var
116        self.delta_dist = delta_dist
117        self.norm = norm
118        self.alpha = alpha
119        self.beta = beta
120        self.gamma = gamma
121        self.unlabeled_push_weight = unlabeled_push_weight
122        self.unlabeled_push = unlabeled_push_weight > 0
123        self.instance_term_weight = instance_term_weight
124
125    def __str__(self):
126        return super().__str__() + f"\ndelta_var: {self.delta_var}\ndelta_dist: {self.delta_dist}" \
127                                   f"\nalpha: {self.alpha}\nbeta: {self.beta}\ngamma: {self.gamma}" \
128                                   f"\nunlabeled_push_weight: {self.unlabeled_push_weight}" \
129                                   f"\ninstance_term_weight: {self.instance_term_weight}"
130
131    def _compute_variance_term(self, cluster_means, embeddings, target, instance_counts, ignore_zero_label):
132        """Computes the variance term, i.e. intra-cluster pull force that draws embeddings towards the mean embedding
133
134        C - number of clusters (instances)
135        E - embedding dimension
136        SPATIAL - volume shape, i.e. DxHxW for 3D/ HxW for 2D
137
138        Args:
139            cluster_means: mean embedding of each instance, tensor (CxE)
140            embeddings: embeddings vectors per instance, tensor (ExSPATIAL)
141            target: label tensor (1xSPATIAL); each label is represented as one-hot vector
142            instance_counts: number of voxels per instance
143            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
144        """
145        assert target.dim() in (2, 3)
146        ignore_labels = [0] if ignore_zero_label else None
147        return cimpl._compute_variance_term_scatter(
148            cluster_means, embeddings.unsqueeze(0), target.unsqueeze(0),
149            self.norm, self.delta_var, instance_counts, ignore_labels
150        )
151
152    def _compute_unlabeled_push(self, cluster_means, embeddings, target):
153        assert target.dim() in (2, 3)
154        n_instances = cluster_means.shape[0]
155
156        # permute embedding dimension at the end
157        if target.dim() == 2:
158            embeddings = embeddings.permute(1, 2, 0)
159        else:
160            embeddings = embeddings.permute(1, 2, 3, 0)
161
162        # decrease number of instances `C` since we're ignoring 0-label
163        n_instances -= 1
164        # if there is only 0-label in the target return 0
165        if n_instances == 0:
166            return 0.0
167
168        background_mask = target == 0
169        n_background = background_mask.sum()
170        background_push = 0.0
171        # skip embedding corresponding to the background pixels
172        for cluster_mean in cluster_means[1:]:
173            # compute distances between embeddings and a given cluster_mean
174            dist_to_mean = torch.norm(embeddings - cluster_mean, self.norm, dim=-1)
175            # apply background mask and compute hinge
176            dist_hinged = torch.clamp((self.delta_dist - dist_to_mean) * background_mask, min=0) ** 2
177            background_push += torch.sum(dist_hinged) / n_background
178
179        # normalize by the number of instances
180        return background_push / n_instances
181
182    # def _compute_distance_term_scatter(cluster_means, norm, delta_dist):
183    def _compute_distance_term(self, cluster_means, ignore_zero_label):
184        """
185        Compute the distance term, i.e an inter-cluster push-force that pushes clusters away from each other, increasing
186        the distance between cluster centers
187
188        Args:
189            cluster_means: mean embedding of each instance, tensor (CxE)
190            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
191        """
192        ignore_labels = [0] if ignore_zero_label else None
193        return cimpl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist, ignore_labels)
194
195    def _compute_regularizer_term(self, cluster_means):
196        """
197        Computes the regularizer term, i.e. a small pull-force that draws all clusters towards origin to keep
198        the network activations bounded
199        """
200        # compute the norm of the mean embeddings
201        norms = torch.norm(cluster_means, p=self.norm, dim=1)
202        # return the average norm per batch
203        return torch.sum(norms) / cluster_means.size(0)
204
205    def compute_instance_term(self, embeddings, cluster_means, target):
206        """Computes auxiliary loss based on embeddings and a given list of target
207        instances together with their mean embeddings.
208
209        Args:
210            embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
211            cluster_means (torch.tensor): mean embeddings per instance (CxExSINGLETON_SPATIAL)
212            target (torch.tensor): ground truth instance segmentation (SPATIAL)
213
214        Returns:
215            float: value of the instance-based term
216        """
217        raise NotImplementedError
218
219    def forward(self, input_, target):
220        """
221        Args:
222             input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
223                expects float32 tensor
224             target (torch.tensor): ground truth instance segmentation (Nx1DxHxW)
225                expects int64 tensor
226        Returns:
227            Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
228                + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
229        """
230        # enable calling this loss from the spoco trainer, which passes a tuple
231        if isinstance(input_, tuple):
232            assert len(input_) == 2
233            input_ = input_[0]
234
235        n_batches = input_.shape[0]
236        # compute the loss per each instance in the batch separately
237        # and sum it up in the per_instance variable
238        loss = 0.0
239        for single_input, single_target in zip(input_, target):
240            # compare spatial dimensions
241            assert single_input.shape[1:] == single_target.shape[1:], f"{single_input.shape}, {single_target.shape}"
242            assert single_target.shape[0] == 1
243            single_target = single_target[0]
244
245            contains_bg = 0 in single_target
246            ignore_zero_label = self.unlabeled_push and contains_bg
247
248            # get number of instances in the batch instance
249            instance_ids, instance_counts = torch.unique(single_target, return_counts=True)
250
251            # get the number of instances
252            C = instance_ids.size(0)
253
254            # compute mean embeddings (output is of shape CxE)
255            cluster_means = compute_cluster_means(single_input, single_target, C)
256
257            # compute variance term, i.e. pull force
258            variance_term = self._compute_variance_term(
259                cluster_means, single_input, single_target, instance_counts, ignore_zero_label
260            )
261
262            # compute unlabeled push force, i.e. push force between
263            # the mean cluster embeddings and embeddings of background pixels
264            # compute only ignore_zero_label is True, i.e. a given patch contains background label
265            unlabeled_push_term = 0.0
266            if self.unlabeled_push and contains_bg:
267                unlabeled_push_term = self._compute_unlabeled_push(cluster_means, single_input, single_target)
268
269            # compute the instance-based auxiliary loss
270            instance_term = self.compute_instance_term(single_input, cluster_means, single_target)
271
272            # compute distance term, i.e. push force
273            distance_term = self._compute_distance_term(cluster_means, ignore_zero_label)
274
275            # compute regularization term
276            regularization_term = self._compute_regularizer_term(cluster_means)
277
278            # compute total loss and sum it up
279            loss = self.alpha * variance_term + \
280                self.beta * distance_term + \
281                self.gamma * regularization_term + \
282                self.instance_term_weight * instance_term + \
283                self.unlabeled_push_weight * unlabeled_push_term
284
285            loss += loss
286
287        # reduce across the batch dimension
288        return loss.div(n_batches)

Base class for the spoco losses.

ContrastiveLossBase( delta_var, delta_dist, norm='fro', alpha=1.0, beta=1.0, gamma=0.001, unlabeled_push_weight=0.0, instance_term_weight=1.0, impl=None)
110    def __init__(self, delta_var, delta_dist,
111                 norm="fro", alpha=1., beta=1., gamma=0.001, unlabeled_push_weight=0.0,
112                 instance_term_weight=1.0, impl=None):
113        assert scatter_mean is not None, "Spoco loss requires pytorch_scatter"
114        super().__init__()
115        self.delta_var = delta_var
116        self.delta_dist = delta_dist
117        self.norm = norm
118        self.alpha = alpha
119        self.beta = beta
120        self.gamma = gamma
121        self.unlabeled_push_weight = unlabeled_push_weight
122        self.unlabeled_push = unlabeled_push_weight > 0
123        self.instance_term_weight = instance_term_weight

Initializes internal Module state, shared by both nn.Module and ScriptModule.

delta_var
delta_dist
norm
alpha
beta
gamma
unlabeled_push_weight
unlabeled_push
instance_term_weight
def compute_instance_term(self, embeddings, cluster_means, target):
205    def compute_instance_term(self, embeddings, cluster_means, target):
206        """Computes auxiliary loss based on embeddings and a given list of target
207        instances together with their mean embeddings.
208
209        Args:
210            embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
211            cluster_means (torch.tensor): mean embeddings per instance (CxExSINGLETON_SPATIAL)
212            target (torch.tensor): ground truth instance segmentation (SPATIAL)
213
214        Returns:
215            float: value of the instance-based term
216        """
217        raise NotImplementedError

Computes auxiliary loss based on embeddings and a given list of target instances together with their mean embeddings.

Arguments:
  • embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
  • cluster_means (torch.tensor): mean embeddings per instance (CxExSINGLETON_SPATIAL)
  • target (torch.tensor): ground truth instance segmentation (SPATIAL)
Returns:

float: value of the instance-based term

def forward(self, input_, target):
219    def forward(self, input_, target):
220        """
221        Args:
222             input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
223                expects float32 tensor
224             target (torch.tensor): ground truth instance segmentation (Nx1DxHxW)
225                expects int64 tensor
226        Returns:
227            Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
228                + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
229        """
230        # enable calling this loss from the spoco trainer, which passes a tuple
231        if isinstance(input_, tuple):
232            assert len(input_) == 2
233            input_ = input_[0]
234
235        n_batches = input_.shape[0]
236        # compute the loss per each instance in the batch separately
237        # and sum it up in the per_instance variable
238        loss = 0.0
239        for single_input, single_target in zip(input_, target):
240            # compare spatial dimensions
241            assert single_input.shape[1:] == single_target.shape[1:], f"{single_input.shape}, {single_target.shape}"
242            assert single_target.shape[0] == 1
243            single_target = single_target[0]
244
245            contains_bg = 0 in single_target
246            ignore_zero_label = self.unlabeled_push and contains_bg
247
248            # get number of instances in the batch instance
249            instance_ids, instance_counts = torch.unique(single_target, return_counts=True)
250
251            # get the number of instances
252            C = instance_ids.size(0)
253
254            # compute mean embeddings (output is of shape CxE)
255            cluster_means = compute_cluster_means(single_input, single_target, C)
256
257            # compute variance term, i.e. pull force
258            variance_term = self._compute_variance_term(
259                cluster_means, single_input, single_target, instance_counts, ignore_zero_label
260            )
261
262            # compute unlabeled push force, i.e. push force between
263            # the mean cluster embeddings and embeddings of background pixels
264            # compute only ignore_zero_label is True, i.e. a given patch contains background label
265            unlabeled_push_term = 0.0
266            if self.unlabeled_push and contains_bg:
267                unlabeled_push_term = self._compute_unlabeled_push(cluster_means, single_input, single_target)
268
269            # compute the instance-based auxiliary loss
270            instance_term = self.compute_instance_term(single_input, cluster_means, single_target)
271
272            # compute distance term, i.e. push force
273            distance_term = self._compute_distance_term(cluster_means, ignore_zero_label)
274
275            # compute regularization term
276            regularization_term = self._compute_regularizer_term(cluster_means)
277
278            # compute total loss and sum it up
279            loss = self.alpha * variance_term + \
280                self.beta * distance_term + \
281                self.gamma * regularization_term + \
282                self.instance_term_weight * instance_term + \
283                self.unlabeled_push_weight * unlabeled_push_term
284
285            loss += loss
286
287        # reduce across the batch dimension
288        return loss.div(n_batches)
Arguments:
  • input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims) expects float32 tensor
  • target (torch.tensor): ground truth instance segmentation (Nx1DxHxW) expects int64 tensor
Returns:

Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class ExtendedContrastiveLoss(ContrastiveLossBase):
291class ExtendedContrastiveLoss(ContrastiveLossBase):
292    """Contrastive loss extended with instance-based loss term and background push term.
293
294    Based on:
295    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
296    """
297
298    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
299                 unlabeled_push_weight=1.0, instance_term_weight=1.0, aux_loss="dice", pmaps_threshold=0.9, **kwargs):
300
301        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
302                         unlabeled_push_weight=unlabeled_push_weight,
303                         instance_term_weight=instance_term_weight)
304
305        # init auxiliary loss
306        assert aux_loss in ["dice", "affinity", "dice_aff"]
307        if aux_loss == "dice":
308            self.aff_loss = None
309            self.dice_loss = DiceLoss()
310        # additional auxiliary losses
311        elif aux_loss == "affinity":
312            self.aff_loss = AffinitySideLoss(
313                delta=delta_dist,
314                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
315                n_samples=kwargs.get("n_samples", 9)
316            )
317            self.dice_loss = None
318        elif aux_loss == "dice_aff":
319            # combine dice and affinity side loss
320            self.dice_weight = kwargs.get("dice_weight", 1.0)
321            self.aff_weight = kwargs.get("aff_weight", 1.0)
322
323            self.aff_loss = AffinitySideLoss(
324                delta=delta_dist,
325                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
326                n_samples=kwargs.get("n_samples", 9)
327            )
328            self.dice_loss = DiceLoss()
329
330        # init dist_to_mask kernel which maps distance to the cluster center to instance probability map
331        self.dist_to_mask = GaussianKernel(delta_var=self.delta_var, pmaps_threshold=pmaps_threshold)
332        self.init_kwargs = {
333            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
334            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
335            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold
336        }
337        self.init_kwargs.update(kwargs)
338
339    # FIXME stacking per instance here makes this very memory hungry,
340    def _create_instance_pmaps_and_masks(self, embeddings, anchors, target):
341        inst_pmaps = []
342        inst_masks = []
343
344        if not inst_masks:
345            return None, None
346
347        # stack along batch dimension
348        inst_pmaps = torch.stack(inst_pmaps)
349        inst_masks = torch.stack(inst_masks)
350
351        return inst_pmaps, inst_masks
352
353    def compute_instance_term(self, embeddings, cluster_means, target):
354        assert embeddings.size()[1:] == target.size()
355
356        if self.aff_loss is None:
357            aff_loss = None
358        else:
359            aff_loss = self.aff_loss(embeddings[None], target[None, None])
360
361        if self.dice_loss is None:
362            dice_loss = None
363        else:
364            dice_loss = []
365
366            # permute embedding dimension at the end
367            if target.dim() == 2:
368                embeddings = embeddings.permute(1, 2, 0)
369            else:
370                embeddings = embeddings.permute(1, 2, 3, 0)
371
372            # compute random anchors per instance
373            instances = torch.unique(target)
374            for i in instances:
375                if i == 0:
376                    continue
377                anchor_emb = cluster_means[i]
378                # FIXME this makes training extremely slow, check with Adrian if this is the latest version
379                # anchor_emb = select_stable_anchor(embeddings, cluster_means[i], target == i, self.delta_var)
380
381                distance_map = torch.norm(embeddings - anchor_emb, self.norm, dim=-1)
382                instance_pmap = self.dist_to_mask(distance_map).unsqueeze(0)
383                instance_mask = (target == i).float().unsqueeze(0)
384
385                dice_loss.append(self.dice_loss(instance_pmap, instance_mask))
386
387            dice_loss = torch.tensor(dice_loss).to(embeddings.device).mean() if dice_loss else 0.0
388
389        assert not (dice_loss is None and aff_loss is None)
390        if dice_loss is None and aff_loss is not None:
391            return aff_loss
392        if dice_loss is not None and aff_loss is None:
393            return dice_loss
394        else:
395            return self.dice_weight * dice_loss + self.aff_weight * aff_loss

Contrastive loss extended with instance-based loss term and background push term.

Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572

ExtendedContrastiveLoss( delta_var, delta_dist, norm='fro', alpha=1.0, beta=1.0, gamma=0.001, unlabeled_push_weight=1.0, instance_term_weight=1.0, aux_loss='dice', pmaps_threshold=0.9, **kwargs)
298    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
299                 unlabeled_push_weight=1.0, instance_term_weight=1.0, aux_loss="dice", pmaps_threshold=0.9, **kwargs):
300
301        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
302                         unlabeled_push_weight=unlabeled_push_weight,
303                         instance_term_weight=instance_term_weight)
304
305        # init auxiliary loss
306        assert aux_loss in ["dice", "affinity", "dice_aff"]
307        if aux_loss == "dice":
308            self.aff_loss = None
309            self.dice_loss = DiceLoss()
310        # additional auxiliary losses
311        elif aux_loss == "affinity":
312            self.aff_loss = AffinitySideLoss(
313                delta=delta_dist,
314                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
315                n_samples=kwargs.get("n_samples", 9)
316            )
317            self.dice_loss = None
318        elif aux_loss == "dice_aff":
319            # combine dice and affinity side loss
320            self.dice_weight = kwargs.get("dice_weight", 1.0)
321            self.aff_weight = kwargs.get("aff_weight", 1.0)
322
323            self.aff_loss = AffinitySideLoss(
324                delta=delta_dist,
325                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
326                n_samples=kwargs.get("n_samples", 9)
327            )
328            self.dice_loss = DiceLoss()
329
330        # init dist_to_mask kernel which maps distance to the cluster center to instance probability map
331        self.dist_to_mask = GaussianKernel(delta_var=self.delta_var, pmaps_threshold=pmaps_threshold)
332        self.init_kwargs = {
333            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
334            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
335            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold
336        }
337        self.init_kwargs.update(kwargs)


dist_to_mask
init_kwargs
def compute_instance_term(self, embeddings, cluster_means, target):
353    def compute_instance_term(self, embeddings, cluster_means, target):
354        assert embeddings.size()[1:] == target.size()
355
356        if self.aff_loss is None:
357            aff_loss = None
358        else:
359            aff_loss = self.aff_loss(embeddings[None], target[None, None])
360
361        if self.dice_loss is None:
362            dice_loss = None
363        else:
364            dice_loss = []
365
366            # permute embedding dimension at the end
367            if target.dim() == 2:
368                embeddings = embeddings.permute(1, 2, 0)
369            else:
370                embeddings = embeddings.permute(1, 2, 3, 0)
371
372            # compute random anchors per instance
373            instances = torch.unique(target)
374            for i in instances:
375                if i == 0:
376                    continue
377                anchor_emb = cluster_means[i]
378                # FIXME this makes training extremely slow, check with Adrian if this is the latest version
379                # anchor_emb = select_stable_anchor(embeddings, cluster_means[i], target == i, self.delta_var)
380
381                distance_map = torch.norm(embeddings - anchor_emb, self.norm, dim=-1)
382                instance_pmap = self.dist_to_mask(distance_map).unsqueeze(0)
383                instance_mask = (target == i).float().unsqueeze(0)
384
385                dice_loss.append(self.dice_loss(instance_pmap, instance_mask))
386
387            dice_loss = torch.tensor(dice_loss).to(embeddings.device).mean() if dice_loss else 0.0
388
389        assert not (dice_loss is None and aff_loss is None)
390        if dice_loss is None and aff_loss is not None:
391            return aff_loss
392        if dice_loss is not None and aff_loss is None:
393            return dice_loss
394        else:
395            return self.dice_weight * dice_loss + self.aff_weight * aff_loss

Computes the auxiliary (instance-based) loss term from the embeddings and the target instances together with their mean embeddings.

Arguments:
  • embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
  • cluster_means (torch.tensor): mean embeddings per instance (CxExSINGLETON_SPATIAL)
  • target (torch.tensor): ground truth instance segmentation (SPATIAL)
Returns:

float: value of the instance-based term
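For illustration, a hedged sketch of calling compute_instance_term directly on a single, unbatched embedding (the loss's forward normally does this per sample). The per-instance means are obtained with compute_cluster_means from this module; their C x E shape broadcasts against the spatial-last view of the embeddings used inside the method. Shapes and values are assumptions:

    import torch
    from torch_em.loss.spoco_loss import ExtendedContrastiveLoss, compute_cluster_means

    loss_fn = ExtendedContrastiveLoss(delta_var=0.75, delta_dist=2.0, aux_loss="dice")

    emb = torch.randn(8, 64, 64)                     # E x H x W, a single sample
    target = torch.zeros(64, 64, dtype=torch.int64)  # H x W instance labels, 0 = background
    target[10:30, 10:30] = 1
    target[40:60, 5:25] = 2

    # per-instance mean embeddings, shape: n_instances x E (needs torch_scatter)
    cluster_means = compute_cluster_means(emb, target, n_instances=3)

    instance_term = loss_fn.compute_instance_term(emb, cluster_means, target)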

Inherited Members
ContrastiveLossBase
delta_var
delta_dist
norm
alpha
beta
gamma
unlabeled_push_weight
unlabeled_push
instance_term_weight
forward
class SPOCOLoss(ExtendedContrastiveLoss):
398class SPOCOLoss(ExtendedContrastiveLoss):
399    """The full SPOCO Loss for instance segmentation training with sparse instance labels.
400
401    Extends the "classic" contrastive loss with an instance-based term and an embedding consistency term.
402    (The unlabeled push term is turned off by default, since we assume sparse instance labels).
403
404    Based on:
405    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
406    """
407
408    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
409                 unlabeled_push_weight=0.0, instance_term_weight=1.0, consistency_term_weight=1.0,
410                 aux_loss="dice", pmaps_threshold=0.9, max_anchors=20, volume_threshold=0.05, **kwargs):
411
412        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
413                         unlabeled_push_weight=unlabeled_push_weight,
414                         instance_term_weight=instance_term_weight,
415                         aux_loss=aux_loss,
416                         pmaps_threshold=pmaps_threshold,
417                         **kwargs)
418
419        self.consistency_term_weight = consistency_term_weight
420        self.max_anchors = max_anchors
421        self.volume_threshold = volume_threshold
422        self.consistency_loss = DiceLoss()
423        self.init_kwargs = {
424            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
425            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
426            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold,
427            "max_anchors": max_anchors, "volume_threshold": volume_threshold
428        }
429        self.init_kwargs.update(kwargs)
430
431    def __str__(self):
432        return super().__str__() + f"\nconsistency_term_weight: {self.consistency_term_weight}"
433
434    def _inst_pmap(self, emb, anchor):
435        # compute distance map
436        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
437        # convert distance map to instance pmaps and return
438        return self.dist_to_mask(distance_map)
439
440    def emb_consistency(self, emb_q, emb_k, mask):
441        inst_q = []
442        inst_k = []
443        for i in range(self.max_anchors):
444            if mask.sum() < self.volume_threshold * mask.numel():
445                break
446
447            # get random anchor
448            indices = torch.nonzero(mask, as_tuple=True)
449            ind = np.random.randint(len(indices[0]))
450
451            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
452            inst_q.append(q_pmap)
453
454            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
455            inst_k.append(k_pmap)
456
457        # stack along channel dim
458        inst_q = torch.stack(inst_q)
459        inst_k = torch.stack(inst_k)
460
461        loss = self.consistency_loss(inst_q, inst_k)
462        return loss
463
464    def _extract_pmap(self, emb, mask, indices, ind):
465        if mask.dim() == 2:
466            y, x = indices
467            anchor = emb[:, y[ind], x[ind]]
468            emb = emb.permute(1, 2, 0)
469        else:
470            z, y, x = indices
471            anchor = emb[:, z[ind], y[ind], x[ind]]
472            emb = emb.permute(1, 2, 3, 0)
473
474        return self._inst_pmap(emb, anchor)
475
476    def forward(self, input, target):
477        assert len(input) == 2
478        emb_q, emb_k = input
479
480        # compute extended contrastive loss only on the embeddings coming from q
481        contrastive_loss = super().forward(emb_q, target)
482
483        # TODO enable computing the consistency on all pixels!
484        # compute consistency term
485        for e_q, e_k, t in zip(emb_q, emb_k, target):
486            unlabeled_mask = (t[0] == 0).int()
487            if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel():
488                continue
489            emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask)
490            contrastive_loss += self.consistency_term_weight * emb_consistency_loss
491
492        return contrastive_loss

The full SPOCO Loss for instance segmentation training with sparse instance labels.

Extends the "classic" contrastive loss with an instance-based term and a embedding consistency term. (The unlabeled push term is turned off by default, since we assume sparse instance labels).

Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572

SPOCOLoss(delta_var, delta_dist, norm='fro', alpha=1.0, beta=1.0, gamma=0.001, unlabeled_push_weight=0.0, instance_term_weight=1.0, consistency_term_weight=1.0, aux_loss='dice', pmaps_threshold=0.9, max_anchors=20, volume_threshold=0.05, **kwargs)
408    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
409                 unlabeled_push_weight=0.0, instance_term_weight=1.0, consistency_term_weight=1.0,
410                 aux_loss="dice", pmaps_threshold=0.9, max_anchors=20, volume_threshold=0.05, **kwargs):
411
412        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
413                         unlabeled_push_weight=unlabeled_push_weight,
414                         instance_term_weight=instance_term_weight,
415                         aux_loss=aux_loss,
416                         pmaps_threshold=pmaps_threshold,
417                         **kwargs)
418
419        self.consistency_term_weight = consistency_term_weight
420        self.max_anchors = max_anchors
421        self.volume_threshold = volume_threshold
422        self.consistency_loss = DiceLoss()
423        self.init_kwargs = {
424            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
425            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
426            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold,
427            "max_anchors": max_anchors, "volume_threshold": volume_threshold
428        }
429        self.init_kwargs.update(kwargs)


consistency_term_weight
max_anchors
volume_threshold
consistency_loss
init_kwargs
def emb_consistency(self, emb_q, emb_k, mask):
440    def emb_consistency(self, emb_q, emb_k, mask):
441        inst_q = []
442        inst_k = []
443        for i in range(self.max_anchors):
444            if mask.sum() < self.volume_threshold * mask.numel():
445                break
446
447            # get random anchor
448            indices = torch.nonzero(mask, as_tuple=True)
449            ind = np.random.randint(len(indices[0]))
450
451            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
452            inst_q.append(q_pmap)
453
454            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
455            inst_k.append(k_pmap)
456
457        # stack along channel dim
458        inst_q = torch.stack(inst_q)
459        inst_k = torch.stack(inst_k)
460
461        loss = self.consistency_loss(inst_q, inst_k)
462        return loss
def forward(self, input, target):
476    def forward(self, input, target):
477        assert len(input) == 2
478        emb_q, emb_k = input
479
480        # compute extended contrastive loss only on the embeddings coming from q
481        contrastive_loss = super().forward(emb_q, target)
482
483        # TODO enable computing the consistency on all pixels!
484        # compute consistency term
485        for e_q, e_k, t in zip(emb_q, emb_k, target):
486            unlabeled_mask = (t[0] == 0).int()
487            if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel():
488                continue
489            emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask)
490            contrastive_loss += self.consistency_term_weight * emb_consistency_loss
491
492        return contrastive_loss
Arguments:
  • input (torch.tensor): embeddings predicted by the network, shape NxExDxHxW (E - embedding dimension); expects a float32 tensor. For SPOCOLoss, input is a pair (emb_q, emb_k) of such embeddings (see forward above).
  • target (torch.tensor): ground truth instance segmentation, shape Nx1xDxHxW; expects an int64 tensor
Returns:

Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term. For SPOCOLoss, the embedding consistency term, weighted by consistency_term_weight, is added on top.
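In the SPOCO setup, emb_k typically comes from a momentum-averaged copy of the trained network, as described in the paper referenced above. A hedged sketch of such an update; net_q, net_k and momentum are illustrative names, not part of this module:

    import torch

    @torch.no_grad()
    def momentum_update(net_q, net_k, momentum=0.999):
        # exponential moving average: net_k <- momentum * net_k + (1 - momentum) * net_q
        for p_q, p_k in zip(net_q.parameters(), net_k.parameters()):
            p_k.mul_(momentum).add_(p_q, alpha=1 - momentum)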

Inherited Members
ExtendedContrastiveLoss
dist_to_mask
compute_instance_term
ContrastiveLossBase
delta_var
delta_dist
norm
alpha
beta
gamma
unlabeled_push_weight
unlabeled_push
instance_term_weight
class SPOCOConsistencyLoss(torch.nn.modules.module.Module):
496class SPOCOConsistencyLoss(nn.Module):
497    def __init__(self, delta_var, pmaps_threshold, max_anchors=30, norm="fro"):
498        super().__init__()
499        self.max_anchors = max_anchors
500        self.consistency_loss = DiceLoss()
501        self.norm = norm
502        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
503        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
504                            "max_anchors": max_anchors, "norm": norm}
505
506    def _inst_pmap(self, emb, anchor):
507        # compute distance map
508        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
509        # convert distance map to instance pmaps and return
510        return self.dist_to_mask(distance_map)
511
512    def emb_consistency(self, emb_q, emb_k):
513        inst_q = []
514        inst_k = []
515        mask = torch.ones(emb_q.shape[1:])
516        for i in range(self.max_anchors):
517            # get random anchor
518            indices = torch.nonzero(mask, as_tuple=True)
519            ind = np.random.randint(len(indices[0]))
520
521            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
522            inst_q.append(q_pmap)
523
524            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
525            inst_k.append(k_pmap)
526
527        # stack along channel dim
528        inst_q = torch.stack(inst_q)
529        inst_k = torch.stack(inst_k)
530
531        loss = self.consistency_loss(inst_q, inst_k)
532        return loss
533
534    def _extract_pmap(self, emb, mask, indices, ind):
535        if mask.dim() == 2:
536            y, x = indices
537            anchor = emb[:, y[ind], x[ind]]
538            emb = emb.permute(1, 2, 0)
539        else:
540            z, y, x = indices
541            anchor = emb[:, z[ind], y[ind], x[ind]]
542            emb = emb.permute(1, 2, 3, 0)
543
544        return self._inst_pmap(emb, anchor)
545
546    def forward(self, emb_q, emb_k):
547        contrastive_loss = 0.0
548        # compute consistency term
549        for e_q, e_k in zip(emb_q, emb_k):
550            contrastive_loss += self.emb_consistency(e_q, e_k)
551        return contrastive_loss

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.
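SPOCOConsistencyLoss computes only the embedding consistency term of SPOCO: it samples up to max_anchors random anchor pixels over the whole image, turns the distance map to each anchor into an instance probability map with the Gaussian kernel, and compares the maps obtained from the two embeddings with a Dice loss. A minimal usage sketch (shapes and hyperparameter values are illustrative):

    import torch
    from torch_em.loss.spoco_loss import SPOCOConsistencyLoss

    loss_fn = SPOCOConsistencyLoss(delta_var=0.75, pmaps_threshold=0.9)

    # paired embeddings, e.g. from the two SPOCO networks: N x E x H x W
    emb_q = torch.randn(2, 8, 64, 64, requires_grad=True)
    emb_k = torch.randn(2, 8, 64, 64)

    loss = loss_fn(emb_q, emb_k)
    loss.backward()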

SPOCOConsistencyLoss(delta_var, pmaps_threshold, max_anchors=30, norm='fro')
497    def __init__(self, delta_var, pmaps_threshold, max_anchors=30, norm="fro"):
498        super().__init__()
499        self.max_anchors = max_anchors
500        self.consistency_loss = DiceLoss()
501        self.norm = norm
502        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
503        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
504                            "max_anchors": max_anchors, "norm": norm}


max_anchors
consistency_loss
norm
dist_to_mask
init_kwargs
def emb_consistency(self, emb_q, emb_k):
512    def emb_consistency(self, emb_q, emb_k):
513        inst_q = []
514        inst_k = []
515        mask = torch.ones(emb_q.shape[1:])
516        for i in range(self.max_anchors):
517            # get random anchor
518            indices = torch.nonzero(mask, as_tuple=True)
519            ind = np.random.randint(len(indices[0]))
520
521            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
522            inst_q.append(q_pmap)
523
524            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
525            inst_k.append(k_pmap)
526
527        # stack along channel dim
528        inst_q = torch.stack(inst_q)
529        inst_k = torch.stack(inst_k)
530
531        loss = self.consistency_loss(inst_q, inst_k)
532        return loss
def forward(self, emb_q, emb_k):
546    def forward(self, emb_q, emb_k):
547        contrastive_loss = 0.0
548        # compute consistency term
549        for e_q, e_k in zip(emb_q, emb_k):
550            contrastive_loss += self.emb_consistency(e_q, e_k)
551        return contrastive_loss

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
