torch_em.loss.spoco_loss
Module source:

```python
import math

import numpy as np
import torch
import torch.nn as nn
try:
    from torch_scatter import scatter_mean
except ImportError:
    scatter_mean = None

from . import contrastive_impl as cimpl
from .affinity_side_loss import AffinitySideLoss
from .dice import DiceLoss


def compute_cluster_means(embeddings: torch.Tensor, target: torch.Tensor, n_instances: int) -> torch.Tensor:
    """Compute mean embeddings per instance.

    Args:
        embeddings: The tensor of pixel embeddings with shape: ExSPATIAL. E is the embedding dimension.
        target: The target instance segmentation with shape: SPATIAL, given as integer instance ids per pixel.
        n_instances: The number of instances.

    Returns:
        The cluster means.
    """
    assert scatter_mean is not None, "torch_scatter is required"
    embeddings = embeddings.flatten(1)
    target = target.flatten()
    assert target.min() == 0, \
        "The target min value has to be zero, otherwise this will lead to errors in scatter."
    mean_embeddings = scatter_mean(embeddings, target, dim_size=n_instances)
    return mean_embeddings.transpose(1, 0)


def select_stable_anchor(
    embeddings: torch.Tensor,
    mean_embedding: torch.Tensor,
    object_mask: torch.Tensor,
    delta_var: float,
    norm: str = "fro"
) -> torch.Tensor:
    """Sample anchor embeddings from the object mask.

    Given a binary mask of an object (`object_mask`) and a `mean_embedding` vector within the mask,
    the function selects a pixel from the mask at random and returns its embedding only if it's closer than
    `delta_var` to the `mean_embedding`.

    Args:
        embeddings: The embeddings, an ExSPATIAL vector field of an image.
        mean_embedding: The E-dimensional mean of embeddings lying within the `object_mask`.
        object_mask: Binary image of a selected object.
        delta_var: The pull force margin of the contrastive loss.
        norm: The vector norm used. By default the Frobenius norm is used.

    Returns:
        Embedding of a selected pixel within the mask, or the mean embedding if no stable anchor could be found.
    """
    indices = torch.nonzero(object_mask, as_tuple=True)
    # convert to numpy
    indices = [t.cpu().numpy() for t in indices]

    # randomize coordinates
    seed = np.random.randint(np.iinfo("int32").max)
    for t in indices:
        rs = np.random.RandomState(seed)
        rs.shuffle(t)

    for ind in range(len(indices[0])):
        if object_mask.dim() == 2:
            y, x = indices
            anchor_emb = embeddings[:, y[ind], x[ind]]
            anchor_emb = anchor_emb[..., None, None]
        else:
            z, y, x = indices
            anchor_emb = embeddings[:, z[ind], y[ind], x[ind]]
            anchor_emb = anchor_emb[..., None, None, None]
        dist_to_mean = torch.norm(mean_embedding - anchor_emb, norm)
        if dist_to_mean < delta_var:
            return anchor_emb
    # if no stable anchor has been found, return the mean_embedding
    return mean_embedding


class GaussianKernel(nn.Module):
    """@private
    """
    def __init__(self, delta_var, pmaps_threshold):
        super().__init__()
        self.delta_var = delta_var
        # delta_var^2 = -2*sigma^2*ln(pmaps_threshold), i.e. a pixel at distance delta_var gets probability pmaps_threshold
        self.two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))

    def forward(self, dist_map):
        return torch.exp(- dist_map * dist_map / self.two_sigma)


class CombinedAuxLoss(nn.Module):
    """@private
    """
    def __init__(self, losses, weights):
        super().__init__()
        self.losses = losses
        self.weights = weights

    def forward(self, embeddings, target, instance_pmaps, instance_masks):
        result = 0.
        for loss, weight in zip(self.losses, self.weights):
            if isinstance(loss, AffinitySideLoss):
                # add batch axis / batch and channel axis for embeddings, target
                result += weight * loss(embeddings[None], target[None, None])
            elif instance_masks is not None:
                result += weight * loss(instance_pmaps, instance_masks).mean()
        return result


class ContrastiveLossBase(nn.Module):
    """@private
    """
    def __init__(self, delta_var, delta_dist,
                 norm="fro", alpha=1., beta=1., gamma=0.001, unlabeled_push_weight=0.0,
                 instance_term_weight=1.0, impl=None):
        assert scatter_mean is not None, "Spoco loss requires pytorch_scatter"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.unlabeled_push_weight = unlabeled_push_weight
        self.unlabeled_push = unlabeled_push_weight > 0
        self.instance_term_weight = instance_term_weight

    def __str__(self):
        return super().__str__() + f"\ndelta_var: {self.delta_var}\ndelta_dist: {self.delta_dist}" \
            f"\nalpha: {self.alpha}\nbeta: {self.beta}\ngamma: {self.gamma}" \
            f"\nunlabeled_push_weight: {self.unlabeled_push_weight}" \
            f"\ninstance_term_weight: {self.instance_term_weight}"

    def _compute_variance_term(self, cluster_means, embeddings, target, instance_counts, ignore_zero_label):
        """Computes the variance term, i.e. the intra-cluster pull force that draws embeddings towards the mean embedding.

        C - number of clusters (instances)
        E - embedding dimension
        SPATIAL - volume shape, i.e. DxHxW for 3D / HxW for 2D

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE)
            embeddings: embedding vectors per instance, tensor (ExSPATIAL)
            target: label tensor (1xSPATIAL); each label is represented as one-hot vector
            instance_counts: number of voxels per instance
            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
        """
        assert target.dim() in (2, 3)
        ignore_labels = [0] if ignore_zero_label else None
        return cimpl._compute_variance_term_scatter(
            cluster_means, embeddings.unsqueeze(0), target.unsqueeze(0),
            self.norm, self.delta_var, instance_counts, ignore_labels
        )

    def _compute_unlabeled_push(self, cluster_means, embeddings, target):
        assert target.dim() in (2, 3)
        n_instances = cluster_means.shape[0]

        # permute embedding dimension to the end
        if target.dim() == 2:
            embeddings = embeddings.permute(1, 2, 0)
        else:
            embeddings = embeddings.permute(1, 2, 3, 0)

        # decrease the number of instances `C` since we're ignoring the 0-label
        n_instances -= 1
        # if there is only the 0-label in the target return 0
        if n_instances == 0:
            return 0.0

        background_mask = target == 0
        n_background = background_mask.sum()
        background_push = 0.0
        # skip embeddings corresponding to the background pixels
        for cluster_mean in cluster_means[1:]:
            # compute distances between embeddings and a given cluster_mean
            dist_to_mean = torch.norm(embeddings - cluster_mean, self.norm, dim=-1)
            # apply background mask and compute hinge
            dist_hinged = torch.clamp((self.delta_dist - dist_to_mean) * background_mask, min=0) ** 2
            background_push += torch.sum(dist_hinged) / n_background

        # normalize by the number of instances
        return background_push / n_instances

    # def _compute_distance_term_scatter(cluster_means, norm, delta_dist):
    def _compute_distance_term(self, cluster_means, ignore_zero_label):
        """Compute the distance term, i.e. an inter-cluster push force that pushes clusters away from each other,
        increasing the distance between cluster centers.

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE)
            ignore_zero_label: if True ignores the cluster corresponding to the 0-label
        """
        ignore_labels = [0] if ignore_zero_label else None
        return cimpl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist, ignore_labels)

    def _compute_regularizer_term(self, cluster_means):
        """Computes the regularizer term, i.e. a small pull force that draws all clusters towards the origin
        to keep the network activations bounded.
        """
        # compute the norm of the mean embeddings
        norms = torch.norm(cluster_means, p=self.norm, dim=1)
        # return the average norm per batch
        return torch.sum(norms) / cluster_means.size(0)

    def compute_instance_term(self, embeddings, cluster_means, target):
        """Computes the auxiliary loss based on embeddings and a given list of target
        instances together with their mean embeddings.

        Args:
            embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
            cluster_means (torch.tensor): mean embeddings per instance (CxExSINGLETON_SPATIAL)
            target (torch.tensor): ground truth instance segmentation (SPATIAL)

        Returns:
            float: value of the instance-based term
        """
        raise NotImplementedError

    def forward(self, input_, target):
        """
        Args:
            input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
                expects float32 tensor
            target (torch.tensor): ground truth instance segmentation (Nx1xDxHxW)
                expects int64 tensor

        Returns:
            Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
            + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
        """
        # enable calling this loss from the spoco trainer, which passes a tuple
        if isinstance(input_, tuple):
            assert len(input_) == 2
            input_ = input_[0]

        n_batches = input_.shape[0]
        # compute the loss for each sample in the batch separately
        # and sum it up in the per_instance_loss variable
        per_instance_loss = 0.0
        for single_input, single_target in zip(input_, target):
            # compare spatial dimensions
            assert single_input.shape[1:] == single_target.shape[1:], f"{single_input.shape}, {single_target.shape}"
            assert single_target.shape[0] == 1
            single_target = single_target[0]

            contains_bg = 0 in single_target
            ignore_zero_label = self.unlabeled_push and contains_bg

            # get the ids and pixel counts of the instances in this sample
            instance_ids, instance_counts = torch.unique(single_target, return_counts=True)

            # get the number of instances
            C = instance_ids.size(0)

            # compute mean embeddings (output is of shape CxE)
            cluster_means = compute_cluster_means(single_input, single_target, C)

            # compute variance term, i.e. pull force
            variance_term = self._compute_variance_term(
                cluster_means, single_input, single_target, instance_counts, ignore_zero_label
            )

            # compute the unlabeled push force, i.e. the push force between
            # the mean cluster embeddings and the embeddings of background pixels;
            # computed only if ignore_zero_label is True, i.e. if a given patch contains the background label
            unlabeled_push_term = 0.0
            if self.unlabeled_push and contains_bg:
                unlabeled_push_term = self._compute_unlabeled_push(cluster_means, single_input, single_target)

            # compute the instance-based auxiliary loss
            instance_term = self.compute_instance_term(single_input, cluster_means, single_target)

            # compute distance term, i.e. push force
            distance_term = self._compute_distance_term(cluster_means, ignore_zero_label)

            # compute regularization term
            regularization_term = self._compute_regularizer_term(cluster_means)

            # compute the total loss for this sample and add it to the accumulator
            loss = self.alpha * variance_term + \
                self.beta * distance_term + \
                self.gamma * regularization_term + \
                self.instance_term_weight * instance_term + \
                self.unlabeled_push_weight * unlabeled_push_term

            per_instance_loss += loss

        # reduce across the batch dimension
        return per_instance_loss.div(n_batches)


class ExtendedContrastiveLoss(ContrastiveLossBase):
    """Contrastive loss extended with an instance-based loss term and a background push term.

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings":
    https://arxiv.org/abs/2103.14572

    Args:
        delta_var: The hinge distance for the variance term in the discriminative loss.
        delta_dist: The hinge distance for the distance term in the discriminative loss.
        norm: The norm to use.
        alpha: Weight for the variance term of the discriminative loss.
        beta: Weight for the distance term of the discriminative loss.
        gamma: Weight for the regularization term of the discriminative loss.
        unlabeled_push_weight: The weight term for the unlabeled loss term.
        instance_term_weight: The weight term for the instance loss term.
        aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
        pmaps_threshold: The probability threshold used by the Gaussian kernel that maps embedding distances
            to instance probability maps.
        kwargs: Additional keyword arguments for other loss terms.
    """
    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        unlabeled_push_weight: float = 1.0,
        instance_term_weight: float = 1.0,
        aux_loss: str = "dice",
        pmaps_threshold: float = 0.9,
        **kwargs,
    ):
        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight)
        # Init auxiliary loss.
        assert aux_loss in ["dice", "affinity", "dice_aff"]
        if aux_loss == "dice":
            self.aff_loss = None
            self.dice_loss = DiceLoss()
        # Additional auxiliary losses.
        elif aux_loss == "affinity":
            self.aff_loss = AffinitySideLoss(
                delta=delta_dist,
                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
                n_samples=kwargs.get("n_samples", 9)
            )
            self.dice_loss = None
        elif aux_loss == "dice_aff":
            # combine dice and affinity side loss
            self.dice_weight = kwargs.get("dice_weight", 1.0)
            self.aff_weight = kwargs.get("aff_weight", 1.0)

            self.aff_loss = AffinitySideLoss(
                delta=delta_dist,
                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
                n_samples=kwargs.get("n_samples", 9)
            )
            self.dice_loss = DiceLoss()

        # Init the dist_to_mask kernel, which maps the distance to the cluster center to an instance probability map.
        self.dist_to_mask = GaussianKernel(delta_var=self.delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold
        }
        self.init_kwargs.update(kwargs)

    # FIXME stacking per instance here makes this very memory hungry,
    def _create_instance_pmaps_and_masks(self, embeddings, anchors, target):
        inst_pmaps = []
        inst_masks = []

        if not inst_masks:
            return None, None

        # stack along batch dimension
        inst_pmaps = torch.stack(inst_pmaps)
        inst_masks = torch.stack(inst_masks)

        return inst_pmaps, inst_masks

    def compute_instance_term(self, embeddings, cluster_means, target):
        """@private
        """
        assert embeddings.size()[1:] == target.size()

        if self.aff_loss is None:
            aff_loss = None
        else:
            aff_loss = self.aff_loss(embeddings[None], target[None, None])

        if self.dice_loss is None:
            dice_loss = None
        else:
            dice_loss = []

            # permute embedding dimension to the end
            if target.dim() == 2:
                embeddings = embeddings.permute(1, 2, 0)
            else:
                embeddings = embeddings.permute(1, 2, 3, 0)

            # compute random anchors per instance
            instances = torch.unique(target)
            for i in instances:
                if i == 0:
                    continue
                anchor_emb = cluster_means[i]
                # FIXME this makes training extremely slow, check with Adrian if this is the latest version
                # anchor_emb = select_stable_anchor(embeddings, cluster_means[i], target == i, self.delta_var)

                distance_map = torch.norm(embeddings - anchor_emb, self.norm, dim=-1)
                instance_pmap = self.dist_to_mask(distance_map).unsqueeze(0)
                instance_mask = (target == i).float().unsqueeze(0)

                dice_loss.append(self.dice_loss(instance_pmap, instance_mask))

            dice_loss = torch.tensor(dice_loss).to(embeddings.device).mean() if dice_loss else 0.0

        assert not (dice_loss is None and aff_loss is None)
        if dice_loss is None and aff_loss is not None:
            return aff_loss
        if dice_loss is not None and aff_loss is None:
            return dice_loss
        else:
            return self.dice_weight * dice_loss + self.aff_weight * aff_loss


class SPOCOLoss(ExtendedContrastiveLoss):
    """The full SPOCO loss for instance segmentation training with sparse instance labels.

    Extends the "classic" contrastive loss with an instance-based term and an unsupervised embedding consistency term.
    An additional background push term can be added. It is disabled by default, because we assume sparse instance labels.

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings":
    https://arxiv.org/abs/2103.14572

    Args:
        delta_var: The hinge distance for the variance term in the discriminative loss.
        delta_dist: The hinge distance for the distance term in the discriminative loss.
        norm: The norm to use.
        alpha: Weight for the variance term of the discriminative loss.
        beta: Weight for the distance term of the discriminative loss.
        gamma: Weight for the regularization term of the discriminative loss.
        unlabeled_push_weight: The weight term for the unlabeled loss term.
        instance_term_weight: The weight term for the instance loss term.
        consistency_term_weight: The weight term for the embedding consistency loss term.
        aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
        pmaps_threshold: The probability threshold used by the Gaussian kernel that maps embedding distances
            to instance probability maps.
        max_anchors: The number of anchors to sample for the consistency term.
        volume_threshold: The minimum fraction of unlabeled pixels required to compute the consistency term.
        kwargs: Additional keyword arguments for other loss terms.
    """
    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        unlabeled_push_weight: float = 0.0,
        instance_term_weight: float = 1.0,
        consistency_term_weight: float = 1.0,
        aux_loss: str = "dice",
        pmaps_threshold: float = 0.9,
        max_anchors: int = 20,
        volume_threshold: float = 0.05,
        **kwargs,
    ):
        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight,
                         aux_loss=aux_loss,
                         pmaps_threshold=pmaps_threshold,
                         **kwargs)

        self.consistency_term_weight = consistency_term_weight
        self.max_anchors = max_anchors
        self.volume_threshold = volume_threshold
        self.consistency_loss = DiceLoss()
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold,
            "max_anchors": max_anchors, "volume_threshold": volume_threshold
        }
        self.init_kwargs.update(kwargs)

    def __str__(self):
        return super().__str__() + f"\nconsistency_term_weight: {self.consistency_term_weight}"

    def _inst_pmap(self, emb, anchor):
        # compute distance map
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert distance map to instance pmaps and return
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k, mask):
        """@private
        """
        inst_q = []
        inst_k = []
        for i in range(self.max_anchors):
            if mask.sum() < self.volume_threshold * mask.numel():
                break

            # get a random anchor
            indices = torch.nonzero(mask, as_tuple=True)
            ind = np.random.randint(len(indices[0]))

            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
            inst_q.append(q_pmap)

            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
            inst_k.append(k_pmap)

        # stack along channel dim
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        loss = self.consistency_loss(inst_q, inst_k)
        return loss

    def _extract_pmap(self, emb, mask, indices, ind):
        if mask.dim() == 2:
            y, x = indices
            anchor = emb[:, y[ind], x[ind]]
            emb = emb.permute(1, 2, 0)
        else:
            z, y, x = indices
            anchor = emb[:, z[ind], y[ind], x[ind]]
            emb = emb.permute(1, 2, 3, 0)

        return self._inst_pmap(emb, anchor)

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the SPOCO loss.

        Args:
            input_: The predicted embeddings, given as a pair (emb_q, emb_k) of embedding predictions.
            target: The segmentation target.

        Returns:
            The SPOCO loss.
        """
        assert len(input_) == 2
        emb_q, emb_k = input_

        # Compute the extended contrastive loss only on the embeddings coming from q.
        contrastive_loss = super().forward(emb_q, target)

        # TODO enable computing the consistency on all pixels!
        # Compute the consistency term.
        for e_q, e_k, t in zip(emb_q, emb_k, target):
            unlabeled_mask = (t[0] == 0).int()
            if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel():
                continue
            emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask)
            contrastive_loss += self.consistency_term_weight * emb_consistency_loss

        return contrastive_loss


class SPOCOConsistencyLoss(nn.Module):
    """Unsupervised consistency term computed between embeddings.

    Args:
        delta_var: The hinge distance used to set the width of the Gaussian kernel that maps
            embedding distances to probabilities.
        pmaps_threshold: The probability threshold used by the Gaussian kernel that maps embedding distances
            to instance probability maps.
        max_anchors: The maximum number of anchors to compute for the consistency loss.
        norm: The vector norm used. By default the Frobenius norm is used.
    """
    def __init__(self, delta_var: float, pmaps_threshold: float, max_anchors: int = 30, norm: str = "fro"):
        super().__init__()
        self.max_anchors = max_anchors
        self.consistency_loss = DiceLoss()
        self.norm = norm
        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
                            "max_anchors": max_anchors, "norm": norm}

    def _inst_pmap(self, emb, anchor):
        # compute distance map
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert distance map to instance pmaps and return
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k):
        """@private
        """
        inst_q = []
        inst_k = []
        mask = torch.ones(emb_q.shape[1:])
        for i in range(self.max_anchors):
            # get a random anchor
            indices = torch.nonzero(mask, as_tuple=True)
            ind = np.random.randint(len(indices[0]))

            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
            inst_q.append(q_pmap)

            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
            inst_k.append(k_pmap)

        # stack along channel dim
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        loss = self.consistency_loss(inst_q, inst_k)
        return loss

    def _extract_pmap(self, emb, mask, indices, ind):
        if mask.dim() == 2:
            y, x = indices
            anchor = emb[:, y[ind], x[ind]]
            emb = emb.permute(1, 2, 0)
        else:
            z, y, x = indices
            anchor = emb[:, z[ind], y[ind], x[ind]]
            emb = emb.permute(1, 2, 3, 0)

        return self._inst_pmap(emb, anchor)

    def forward(self, emb_q: torch.Tensor, emb_k: torch.Tensor) -> torch.Tensor:
        """Compute the consistency loss term between embeddings.

        Args:
            emb_q: The first embedding predictions.
            emb_k: The second embedding predictions.

        Returns:
            The consistency loss.
        """
        contrastive_loss = 0.0
        # compute the consistency term
        for e_q, e_k in zip(emb_q, emb_k):
            contrastive_loss += self.emb_consistency(e_q, e_k)
        return contrastive_loss
```
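The GaussianKernel defined above converts a map of distances to a cluster center into a soft instance probability map via exp(-d^2 / two_sigma), with two_sigma chosen such that a pixel at distance delta_var receives exactly the probability pmaps_threshold. A quick numeric check of this parameterization (a standalone sketch; the values are illustrative):

```python
import math

import torch

# Reproduce the kernel parameterization: two_sigma = delta_var^2 / (-ln(pmaps_threshold)),
# so that exp(-delta_var^2 / two_sigma) == pmaps_threshold by construction.
delta_var, pmaps_threshold = 0.75, 0.9
two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))

dist = torch.tensor([0.0, delta_var, 2 * delta_var])
pmap = torch.exp(-dist * dist / two_sigma)
print(pmap)  # tensor([1.0000, 0.9000, 0.6561]), i.e. 0.9 ** 4 at twice the hinge distance
```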
def compute_cluster_means(embeddings: torch.Tensor, target: torch.Tensor, n_instances: int) -> torch.Tensor:
Compute mean embeddings per instance.
Arguments:
- embeddings: The tensor of pixel embeddings with shape: ExSPATIAL. E is the embedding dimension.
- target: The target instance segmentation with shape: SPATIAL, given as integer instance ids per pixel.
- n_instances: The number of instances.
Returns:
The cluster means.
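A minimal usage sketch (torch_scatter must be installed; the shapes below are only illustrative):

```python
import torch

from torch_em.loss.spoco_loss import compute_cluster_means

embeddings = torch.randn(8, 64, 64)        # E x SPATIAL with E = 8
target = torch.randint(0, 5, (64, 64))     # integer instance ids per pixel, including 0
n_instances = int(target.max()) + 1

# mean embedding per instance id, output shape C x E = 5 x 8
cluster_means = compute_cluster_means(embeddings, target, n_instances)
print(cluster_means.shape)
```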
def select_stable_anchor(embeddings: torch.Tensor, mean_embedding: torch.Tensor, object_mask: torch.Tensor, delta_var: float, norm: str = "fro") -> torch.Tensor:
Sample anchor embeddings from the object mask.
Given a binary mask of an object (object_mask) and a mean_embedding vector within the mask, the function selects a pixel from the mask at random and returns its embedding only if it's closer than delta_var to the mean_embedding.
Arguments:
- embeddings: The embeddings, an ExSPATIAL vector field of an image.
- mean_embedding: The E-dimensional mean of embeddings lying within the object_mask.
- object_mask: Binary image of a selected object.
- delta_var: The pull force margin of the contrastive loss.
- norm: The vector norm used. By default the Frobenius norm is used.
Returns:
Embedding of a selected pixel within the mask, or the mean embedding if no stable anchor could be found.
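A minimal usage sketch (the shapes and the delta_var value are illustrative):

```python
import torch

from torch_em.loss.spoco_loss import select_stable_anchor

embeddings = torch.randn(8, 64, 64)                  # E x H x W embedding field
object_mask = torch.zeros(64, 64, dtype=torch.bool)  # binary mask of one object
object_mask[20:40, 20:40] = True

# mean embedding within the mask, broadcastable against the embedding field
mean_embedding = embeddings[:, object_mask].mean(dim=1)[:, None, None]

anchor = select_stable_anchor(embeddings, mean_embedding, object_mask, delta_var=0.75)
print(anchor.shape)  # torch.Size([8, 1, 1]): a sampled pixel embedding or the mean embedding
```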
class ExtendedContrastiveLoss(ContrastiveLossBase):
Contrastive loss extended with an instance-based loss term and a background push term.
Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
Arguments:
- delta_var: The hinge distance for the variance term in the discriminative loss.
- delta_dist: The hinge distance for the distance term in the discriminative loss.
- norm: The norm to use.
- alpha: Weight for the variance term of the discriminative loss.
- beta: Weight for the distance term of the discriminative loss.
- gamma: Weight for the regularization term of the discriminative loss.
- unlabeled_push_weight: The weight term for the unlabeled loss term.
- instance_term_weight: The weight term for the instance loss term.
- aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
- pmaps_threshold: The probability threshold used by the Gaussian kernel that maps embedding distances to instance probability maps.
- kwargs: Additional keyword arguments for other loss terms.
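A minimal usage sketch (hyperparameters and shapes are illustrative; torch_scatter is required):

```python
import torch

from torch_em.loss.spoco_loss import ExtendedContrastiveLoss

loss_fn = ExtendedContrastiveLoss(delta_var=0.75, delta_dist=2.0, aux_loss="dice")

embeddings = torch.randn(2, 8, 64, 64)          # N x E x H x W embeddings predicted by the network
target = torch.randint(0, 5, (2, 1, 64, 64))    # N x 1 x H x W instance ids, 0 = background
loss = loss_fn(embeddings, target)
print(loss)
```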
class SPOCOLoss(ExtendedContrastiveLoss):
The full SPOCO Loss for instance segmentation training with sparse instance labels.
Extends the "classic" contrastive loss with an instance-based term and a unsupervised embedding consistency term. An additional background push term can be added. It is disabled by default because we assume sparse instance labels.
Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
Arguments:
- delta_var: The hinge distance for the variance term in the discriminative loss.
- delta_dist: The hinge distance for the distance term in the discriminative loss.
- norm: The norm to use.
- alpha: Weight for the variance term of the discriminative loss.
- beta: Weight for the distance term of the discriminative loss.
- gamma: Weight for the regularization term of the discriminative loss.
- unlabeled_push_weight: The weight term for the unlabeled loss term.
- instance_term_weight: The weight term for the instance loss term.
- consistency_term_weight: The weight term for the embedding consistency loss term.
- aux_loss: The auxiliary loss term to use. One of 'dice', 'affinity', 'dice_aff'.
- pmaps_threshold: The probability threshold used by the Gaussian kernel that maps embedding distances to instance probability maps.
- max_anchors: The number of anchors to sample for the consistency term.
- volume_threshold: The minimum fraction of unlabeled pixels required to compute the consistency term.
- kwargs: Additional keyword arguments for other loss terms.
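A minimal usage sketch: the loss expects a pair of embedding predictions and computes the extended contrastive term only on the first one (hyperparameters and shapes are illustrative; torch_scatter is required):

```python
import torch

from torch_em.loss.spoco_loss import SPOCOLoss

loss_fn = SPOCOLoss(delta_var=0.75, delta_dist=2.0)

emb_q = torch.randn(2, 8, 64, 64)              # embeddings of the trained model, N x E x H x W
emb_k = torch.randn(2, 8, 64, 64)              # embeddings of the second model, e.g. a momentum-averaged copy
target = torch.randint(0, 5, (2, 1, 64, 64))   # sparse instance labels, 0 = unlabeled
loss = loss_fn((emb_q, emb_k), target)
print(loss)
```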
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Compute the SPOCO loss.
Arguments:
- input_: The predicted embeddings, given as a pair (emb_q, emb_k) of embedding predictions.
- target: The segmentation target.
Returns:
The SPOCO loss.
class SPOCOConsistencyLoss(nn.Module):
Unsupervised consistency term computed between embeddings.
Arguments:
- delta_var: The hinge distance used to set the width of the Gaussian kernel that maps embedding distances to probabilities.
- pmaps_threshold: The probability threshold used by the Gaussian kernel that maps embedding distances to instance probability maps.
- max_anchors: The maximum number of anchors to compute for the consistency loss.
- norm: The vector norm used. By default the Frobenius norm is used.
def forward(self, emb_q: torch.Tensor, emb_k: torch.Tensor) -> torch.Tensor:
Compute the consistency loss term between embeddings.
Arguments:
- emb_q: The first embedding predictions.
- emb_k: The second embedding predictions.
Returns:
The consistency loss.
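A minimal usage sketch (shapes and parameter values are illustrative):

```python
import torch

from torch_em.loss.spoco_loss import SPOCOConsistencyLoss

loss_fn = SPOCOConsistencyLoss(delta_var=0.75, pmaps_threshold=0.9, max_anchors=10)

emb_q = torch.randn(2, 8, 64, 64)   # N x E x H x W
emb_k = torch.randn(2, 8, 64, 64)
loss = loss_fn(emb_q, emb_k)
print(loss)
```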