torch_em.loss.spoco_loss
import math

import numpy as np
import torch
import torch.nn as nn
try:
    from torch_scatter import scatter_mean
except ImportError:
    scatter_mean = None

from . import contrastive_impl as cimpl
from .affinity_side_loss import AffinitySideLoss
from .dice import DiceLoss


def compute_cluster_means(embeddings, target, n_instances):
    """
    Computes mean embeddings per instance.
    E - embedding dimension

    Args:
        embeddings: tensor of pixel embeddings, shape: ExSPATIAL
        target: instance label image (integer instance ids), shape: SPATIAL
        n_instances: number of instances
    """
    assert scatter_mean is not None, "torch_scatter is required"
    embeddings = embeddings.flatten(1)
    target = target.flatten()
    assert target.min() == 0, \
        "The target min value has to be zero, otherwise this will lead to errors in scatter"
    mean_embeddings = scatter_mean(embeddings, target, dim_size=n_instances)
    return mean_embeddings.transpose(1, 0)


def select_stable_anchor(embeddings, mean_embedding, object_mask, delta_var, norm="fro"):
    """
    Anchor sampling procedure. Given a binary mask of an object (`object_mask`) and a `mean_embedding` vector within
    the mask, the function selects a pixel from the mask at random and returns its embedding only if it's closer than
    `delta_var` to the `mean_embedding`.

    Args:
        embeddings (torch.Tensor): ExSPATIAL vector field of an image
        mean_embedding (torch.Tensor): E-dimensional mean of embeddings lying within the `object_mask`
        object_mask (torch.Tensor): binary image of a selected object
        delta_var (float): contrastive loss, pull force margin
        norm (str): vector norm used, default: Frobenius norm

    Returns:
        embedding of a selected pixel within the mask, or the mean embedding if no stable anchor could be found
    """
    indices = torch.nonzero(object_mask, as_tuple=True)
    # convert to numpy
    indices = [t.cpu().numpy() for t in indices]

    # shuffle the coordinates; using the same seed for each axis keeps the coordinate tuples intact
    seed = np.random.randint(np.iinfo("int32").max)
    for t in indices:
        rs = np.random.RandomState(seed)
        rs.shuffle(t)

    for ind in range(len(indices[0])):
        if object_mask.dim() == 2:
            y, x = indices
            anchor_emb = embeddings[:, y[ind], x[ind]]
            anchor_emb = anchor_emb[..., None, None]
        else:
            z, y, x = indices
            anchor_emb = embeddings[:, z[ind], y[ind], x[ind]]
            anchor_emb = anchor_emb[..., None, None, None]
        dist_to_mean = torch.norm(mean_embedding - anchor_emb, norm)
        if dist_to_mean < delta_var:
            return anchor_emb
    # if no stable anchor has been found, return the mean embedding
    return mean_embedding


class GaussianKernel(nn.Module):
    def __init__(self, delta_var, pmaps_threshold):
        super().__init__()
        self.delta_var = delta_var
        # delta_var^2 = -2*sigma*ln(pmaps_threshold), i.e. the kernel value at distance delta_var is pmaps_threshold
        self.two_sigma = delta_var * delta_var / (-math.log(pmaps_threshold))

    def forward(self, dist_map):
        return torch.exp(- dist_map * dist_map / self.two_sigma)


class CombinedAuxLoss(nn.Module):
    def __init__(self, losses, weights):
        super().__init__()
        self.losses = losses
        self.weights = weights

    def forward(self, embeddings, target, instance_pmaps, instance_masks):
        result = 0.
        for loss, weight in zip(self.losses, self.weights):
            if isinstance(loss, AffinitySideLoss):
                # add batch axis / batch and channel axis for embeddings, target
                result += weight * loss(embeddings[None], target[None, None])
            elif instance_masks is not None:
                result += weight * loss(instance_pmaps, instance_masks).mean()
        return result


class ContrastiveLossBase(nn.Module):
    """Base class for the SPOCO losses.
    """

    def __init__(self, delta_var, delta_dist,
                 norm="fro", alpha=1., beta=1., gamma=0.001, unlabeled_push_weight=0.0,
                 instance_term_weight=1.0, impl=None):
        assert scatter_mean is not None, "The SPOCO losses require torch_scatter"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.unlabeled_push_weight = unlabeled_push_weight
        self.unlabeled_push = unlabeled_push_weight > 0
        self.instance_term_weight = instance_term_weight

    def __str__(self):
        return super().__str__() + f"\ndelta_var: {self.delta_var}\ndelta_dist: {self.delta_dist}" \
            f"\nalpha: {self.alpha}\nbeta: {self.beta}\ngamma: {self.gamma}" \
            f"\nunlabeled_push_weight: {self.unlabeled_push_weight}" \
            f"\ninstance_term_weight: {self.instance_term_weight}"

    def _compute_variance_term(self, cluster_means, embeddings, target, instance_counts, ignore_zero_label):
        """Computes the variance term, i.e. the intra-cluster pull force that draws embeddings towards the mean embedding.

        C - number of clusters (instances)
        E - embedding dimension
        SPATIAL - volume shape, i.e. DxHxW for 3D / HxW for 2D

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE)
            embeddings: embedding vectors per instance, tensor (ExSPATIAL)
            target: label tensor (SPATIAL) containing integer instance ids
            instance_counts: number of voxels per instance
            ignore_zero_label: if True, ignores the cluster corresponding to the 0-label
        """
        assert target.dim() in (2, 3)
        ignore_labels = [0] if ignore_zero_label else None
        return cimpl._compute_variance_term_scatter(
            cluster_means, embeddings.unsqueeze(0), target.unsqueeze(0),
            self.norm, self.delta_var, instance_counts, ignore_labels
        )

    def _compute_unlabeled_push(self, cluster_means, embeddings, target):
        assert target.dim() in (2, 3)
        n_instances = cluster_means.shape[0]

        # permute the embedding dimension to the end
        if target.dim() == 2:
            embeddings = embeddings.permute(1, 2, 0)
        else:
            embeddings = embeddings.permute(1, 2, 3, 0)

        # decrease the number of instances `C`, since we're ignoring the 0-label
        n_instances -= 1
        # if there is only the 0-label in the target, return 0
        if n_instances == 0:
            return 0.0

        background_mask = target == 0
        n_background = background_mask.sum()
        background_push = 0.0
        # skip the mean embedding corresponding to the background pixels
        for cluster_mean in cluster_means[1:]:
            # compute the distances between the embeddings and the given cluster_mean
            dist_to_mean = torch.norm(embeddings - cluster_mean, self.norm, dim=-1)
            # apply the background mask and compute the hinge
            dist_hinged = torch.clamp((self.delta_dist - dist_to_mean) * background_mask, min=0) ** 2
            background_push += torch.sum(dist_hinged) / n_background

        # normalize by the number of instances
        return background_push / n_instances

    def _compute_distance_term(self, cluster_means, ignore_zero_label):
        """
        Computes the distance term, i.e. an inter-cluster push force that pushes clusters away from each other,
        increasing the distance between cluster centers.

        Args:
            cluster_means: mean embedding of each instance, tensor (CxE)
            ignore_zero_label: if True, ignores the cluster corresponding to the 0-label
        """
        ignore_labels = [0] if ignore_zero_label else None
        return cimpl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist, ignore_labels)

    def _compute_regularizer_term(self, cluster_means):
        """
        Computes the regularizer term, i.e. a small pull force that draws all clusters towards the origin to keep
        the network activations bounded.
        """
        # compute the norm of the mean embeddings
        norms = torch.norm(cluster_means, p=self.norm, dim=1)
        # return the average norm per batch
        return torch.sum(norms) / cluster_means.size(0)

    def compute_instance_term(self, embeddings, cluster_means, target):
        """Computes auxiliary loss based on embeddings and a given list of target
        instances together with their mean embeddings.

        Args:
            embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
            cluster_means (torch.tensor): mean embeddings per instance (CxE)
            target (torch.tensor): ground truth instance segmentation (SPATIAL)

        Returns:
            float: value of the instance-based term
        """
        raise NotImplementedError

    def forward(self, input_, target):
        """
        Args:
            input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims)
                expects a float32 tensor
            target (torch.tensor): ground truth instance segmentation (Nx1xDxHxW)
                expects an int64 tensor
        Returns:
            Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term
            + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
        """
        # enable calling this loss from the spoco trainer, which passes a tuple
        if isinstance(input_, tuple):
            assert len(input_) == 2
            input_ = input_[0]

        n_batches = input_.shape[0]
        # compute the loss for each sample in the batch separately and sum it up in `loss`
        loss = 0.0
        for single_input, single_target in zip(input_, target):
            # compare spatial dimensions
            assert single_input.shape[1:] == single_target.shape[1:], f"{single_input.shape}, {single_target.shape}"
            assert single_target.shape[0] == 1
            single_target = single_target[0]

            contains_bg = 0 in single_target
            ignore_zero_label = self.unlabeled_push and contains_bg

            # get the instance ids and per-instance pixel counts for this sample
            instance_ids, instance_counts = torch.unique(single_target, return_counts=True)

            # get the number of instances
            C = instance_ids.size(0)

            # compute mean embeddings (output is of shape CxE)
            cluster_means = compute_cluster_means(single_input, single_target, C)

            # compute the variance term, i.e. the pull force
            variance_term = self._compute_variance_term(
                cluster_means, single_input, single_target, instance_counts, ignore_zero_label
            )

            # compute the unlabeled push force, i.e. the push force between the mean cluster
            # embeddings and the embeddings of background pixels; computed only if
            # ignore_zero_label is True, i.e. if the given patch contains the background label
            unlabeled_push_term = 0.0
            if self.unlabeled_push and contains_bg:
                unlabeled_push_term = self._compute_unlabeled_push(cluster_means, single_input, single_target)

            # compute the instance-based auxiliary loss
            instance_term = self.compute_instance_term(single_input, cluster_means, single_target)

            # compute the distance term, i.e. the push force
            distance_term = self._compute_distance_term(cluster_means, ignore_zero_label)

            # compute the regularization term
            regularization_term = self._compute_regularizer_term(cluster_means)

            # combine the terms for this sample and add the result to the total loss
            sample_loss = self.alpha * variance_term + \
                self.beta * distance_term + \
                self.gamma * regularization_term + \
                self.instance_term_weight * instance_term + \
                self.unlabeled_push_weight * unlabeled_push_term

            loss += sample_loss

        # reduce across the batch dimension
        return loss.div(n_batches)


class ExtendedContrastiveLoss(ContrastiveLossBase):
    """Contrastive loss extended with an instance-based loss term and a background push term.

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
    """

    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
                 unlabeled_push_weight=1.0, instance_term_weight=1.0, aux_loss="dice", pmaps_threshold=0.9, **kwargs):

        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight)

        # initialize the auxiliary loss
        assert aux_loss in ["dice", "affinity", "dice_aff"]
        if aux_loss == "dice":
            self.aff_loss = None
            self.dice_loss = DiceLoss()
        elif aux_loss == "affinity":
            self.aff_loss = AffinitySideLoss(
                delta=delta_dist,
                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
                n_samples=kwargs.get("n_samples", 9)
            )
            self.dice_loss = None
        elif aux_loss == "dice_aff":
            # combine the dice and affinity side losses
            self.dice_weight = kwargs.get("dice_weight", 1.0)
            self.aff_weight = kwargs.get("aff_weight", 1.0)

            self.aff_loss = AffinitySideLoss(
                delta=delta_dist,
                offset_ranges=kwargs.get("offset_ranges", [(-18, 18), (-18, 18)]),
                n_samples=kwargs.get("n_samples", 9)
            )
            self.dice_loss = DiceLoss()

        # initialize the dist_to_mask kernel, which maps the distance to the cluster center
        # to an instance probability map
        self.dist_to_mask = GaussianKernel(delta_var=self.delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold
        }
        self.init_kwargs.update(kwargs)

    # FIXME stacking per instance here makes this very memory hungry;
    # note that the code populating inst_pmaps / inst_masks is missing, so this currently always returns (None, None)
    def _create_instance_pmaps_and_masks(self, embeddings, anchors, target):
        inst_pmaps = []
        inst_masks = []

        if not inst_masks:
            return None, None

        # stack along the batch dimension
        inst_pmaps = torch.stack(inst_pmaps)
        inst_masks = torch.stack(inst_masks)

        return inst_pmaps, inst_masks

    def compute_instance_term(self, embeddings, cluster_means, target):
        assert embeddings.size()[1:] == target.size()

        if self.aff_loss is None:
            aff_loss = None
        else:
            aff_loss = self.aff_loss(embeddings[None], target[None, None])

        if self.dice_loss is None:
            dice_loss = None
        else:
            dice_loss = []

            # permute the embedding dimension to the end
            if target.dim() == 2:
                embeddings = embeddings.permute(1, 2, 0)
            else:
                embeddings = embeddings.permute(1, 2, 3, 0)

            # use the mean embedding of each instance as its anchor
            instances = torch.unique(target)
            for i in instances:
                if i == 0:
                    continue
                anchor_emb = cluster_means[i]
                # FIXME this makes training extremely slow, check with Adrian if this is the latest version
                # anchor_emb = select_stable_anchor(embeddings, cluster_means[i], target == i, self.delta_var)

                distance_map = torch.norm(embeddings - anchor_emb, self.norm, dim=-1)
                instance_pmap = self.dist_to_mask(distance_map).unsqueeze(0)
                instance_mask = (target == i).float().unsqueeze(0)

                dice_loss.append(self.dice_loss(instance_pmap, instance_mask))

            # stack rather than torch.tensor, so that the dice terms keep their gradients and device
            dice_loss = torch.stack(dice_loss).mean() if dice_loss else 0.0

        assert not (dice_loss is None and aff_loss is None)
        if dice_loss is None:
            return aff_loss
        elif aff_loss is None:
            return dice_loss
        else:
            return self.dice_weight * dice_loss + self.aff_weight * aff_loss


class SPOCOLoss(ExtendedContrastiveLoss):
    """The full SPOCO Loss for instance segmentation training with sparse instance labels.

    Extends the "classic" contrastive loss with an instance-based term and an embedding consistency term.
    (The unlabeled push term is turned off by default, since we assume sparse instance labels.)

    Based on:
    "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
    """

    def __init__(self, delta_var, delta_dist, norm="fro", alpha=1.0, beta=1.0, gamma=0.001,
                 unlabeled_push_weight=0.0, instance_term_weight=1.0, consistency_term_weight=1.0,
                 aux_loss="dice", pmaps_threshold=0.9, max_anchors=20, volume_threshold=0.05, **kwargs):

        super().__init__(delta_var, delta_dist, norm=norm, alpha=alpha, beta=beta, gamma=gamma,
                         unlabeled_push_weight=unlabeled_push_weight,
                         instance_term_weight=instance_term_weight,
                         aux_loss=aux_loss,
                         pmaps_threshold=pmaps_threshold,
                         **kwargs)

        self.consistency_term_weight = consistency_term_weight
        self.max_anchors = max_anchors
        self.volume_threshold = volume_threshold
        self.consistency_loss = DiceLoss()
        self.init_kwargs = {
            "delta_var": delta_var, "delta_dist": delta_dist, "norm": norm, "alpha": alpha, "beta": beta,
            "gamma": gamma, "unlabeled_push_weight": unlabeled_push_weight,
            "instance_term_weight": instance_term_weight, "aux_loss": aux_loss, "pmaps_threshold": pmaps_threshold,
            "max_anchors": max_anchors, "volume_threshold": volume_threshold
        }
        self.init_kwargs.update(kwargs)

    def __str__(self):
        return super().__str__() + f"\nconsistency_term_weight: {self.consistency_term_weight}"

    def _inst_pmap(self, emb, anchor):
        # compute the distance map
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert the distance map to an instance pmap and return it
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k, mask):
        inst_q = []
        inst_k = []
        for i in range(self.max_anchors):
            if mask.sum() < self.volume_threshold * mask.numel():
                break

            # get a random anchor
            indices = torch.nonzero(mask, as_tuple=True)
            ind = np.random.randint(len(indices[0]))

            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
            inst_q.append(q_pmap)

            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
            inst_k.append(k_pmap)

        # stack along the channel dimension
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        loss = self.consistency_loss(inst_q, inst_k)
        return loss

    def _extract_pmap(self, emb, mask, indices, ind):
        if mask.dim() == 2:
            y, x = indices
            anchor = emb[:, y[ind], x[ind]]
            emb = emb.permute(1, 2, 0)
        else:
            z, y, x = indices
            anchor = emb[:, z[ind], y[ind], x[ind]]
            emb = emb.permute(1, 2, 3, 0)

        return self._inst_pmap(emb, anchor)

    def forward(self, input, target):
        assert len(input) == 2
        emb_q, emb_k = input

        # compute the extended contrastive loss only on the embeddings coming from q
        contrastive_loss = super().forward(emb_q, target)

        # TODO enable computing the consistency on all pixels!
        # compute the consistency term
        for e_q, e_k, t in zip(emb_q, emb_k, target):
            unlabeled_mask = (t[0] == 0).int()
            if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel():
                continue
            emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask)
            contrastive_loss += self.consistency_term_weight * emb_consistency_loss

        return contrastive_loss


# FIXME clarify what this is!
class SPOCOConsistencyLoss(nn.Module):
    def __init__(self, delta_var, pmaps_threshold, max_anchors=30, norm="fro"):
        super().__init__()
        self.max_anchors = max_anchors
        self.consistency_loss = DiceLoss()
        self.norm = norm
        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
                            "max_anchors": max_anchors, "norm": norm}

    def _inst_pmap(self, emb, anchor):
        # compute the distance map
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert the distance map to an instance pmap and return it
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k):
        inst_q = []
        inst_k = []
        mask = torch.ones(emb_q.shape[1:])
        for i in range(self.max_anchors):
            # get a random anchor
            indices = torch.nonzero(mask, as_tuple=True)
            ind = np.random.randint(len(indices[0]))

            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
            inst_q.append(q_pmap)

            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
            inst_k.append(k_pmap)

        # stack along the channel dimension
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        loss = self.consistency_loss(inst_q, inst_k)
        return loss

    def _extract_pmap(self, emb, mask, indices, ind):
        if mask.dim() == 2:
            y, x = indices
            anchor = emb[:, y[ind], x[ind]]
            emb = emb.permute(1, 2, 0)
        else:
            z, y, x = indices
            anchor = emb[:, z[ind], y[ind], x[ind]]
            emb = emb.permute(1, 2, 3, 0)

        return self._inst_pmap(emb, anchor)

    def forward(self, emb_q, emb_k):
        contrastive_loss = 0.0
        # compute the consistency term
        for e_q, e_k in zip(emb_q, emb_k):
            contrastive_loss += self.emb_consistency(e_q, e_k)
        return contrastive_loss
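For orientation, a minimal end-to-end sketch of the dense (single prediction) loss; shapes and hyperparameters are toy values, and torch_scatter needs to be installed:

import torch
from torch_em.loss.spoco_loss import ExtendedContrastiveLoss

# toy batch: 2 samples, 8 embedding dimensions, 32x32 images
pred = torch.randn(2, 8, 32, 32)                        # N x E x H x W, float32
target = torch.zeros(2, 1, 32, 32, dtype=torch.int64)   # N x 1 x H x W, int64
target[:, :, 8:16, 8:16] = 1                            # two square instances,
target[:, :, 20:28, 4:12] = 2                           # label 0 is background

loss = ExtendedContrastiveLoss(delta_var=0.75, delta_dist=2.0)
value = loss(pred, target)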
def compute_cluster_means(embeddings, target, n_instances):
Computes mean embeddings per instance. E - embedding dimension
Arguments:
- embeddings: tensor of pixel embeddings, shape: ExSPATIAL
- target: instance label image (integer instance ids), shape: SPATIAL
- n_instances: number of instances
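A small usage sketch (toy shapes; assumes torch_scatter is installed). Note that the instance ids have to start at zero and be contiguous, since they are used as scatter indices:

import torch
from torch_em.loss.spoco_loss import compute_cluster_means

emb = torch.randn(8, 4, 4)                 # E x SPATIAL with E = 8
target = torch.tensor([[0, 0, 1, 1],
                       [0, 0, 1, 1],
                       [2, 2, 1, 1],
                       [2, 2, 0, 0]])      # instance ids, 0 = background
means = compute_cluster_means(emb, target, n_instances=3)
print(means.shape)                         # torch.Size([3, 8]), i.e. C x E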
def select_stable_anchor(embeddings, mean_embedding, object_mask, delta_var, norm="fro"):
Anchor sampling procedure. Given a binary mask of an object (`object_mask`) and a `mean_embedding` vector within the mask, the function selects a pixel from the mask at random and returns its embedding only if it's closer than `delta_var` to the `mean_embedding`.
Arguments:
- embeddings (torch.Tensor): ExSPATIAL vector field of an image
- mean_embedding (torch.Tensor): E-dimensional mean of embeddings lying within the `object_mask`
- object_mask (torch.Tensor): binary image of a selected object
- delta_var (float): contrastive loss, pull force margin
- norm (str): vector norm used, default: Frobenius norm
Returns:
embedding of a selected pixel within the mask, or the mean embedding if no stable anchor could be found
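A hedged usage sketch with toy tensors (in `compute_instance_term` below, the call to this function is currently commented out in favor of using the mean embedding directly). `mean_embedding` is reshaped to E x 1 x 1 here, so that it broadcasts correctly against the anchor embedding:

import torch
from torch_em.loss.spoco_loss import select_stable_anchor

emb = torch.randn(8, 16, 16)                         # E x H x W
mask = torch.zeros(16, 16, dtype=torch.bool)
mask[4:9, 4:9] = True                                # binary mask of one object
mean_emb = emb[:, mask].mean(dim=1)[:, None, None]   # E x 1 x 1 mean over the mask
anchor = select_stable_anchor(emb, mean_emb, mask, delta_var=0.75)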
class GaussianKernel(nn.Module):

Maps a distance map to an instance probability map via exp(-d^2 / (2*sigma)), with 2*sigma = delta_var^2 / (-ln(pmaps_threshold)), so that the probability at distance delta_var is exactly pmaps_threshold.
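Equivalently, the kernel computes pmaps_threshold ** (d^2 / delta_var^2). A quick sanity check with toy values:

import torch
from torch_em.loss.spoco_loss import GaussianKernel

kernel = GaussianKernel(delta_var=0.5, pmaps_threshold=0.9)
dist = torch.tensor([0.0, 0.5, 1.0])
print(kernel(dist))   # tensor([1.0000, 0.9000, 0.6561]): 0.9 at d = delta_var, 0.9**4 at d = 2 * delta_var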
class CombinedAuxLoss(nn.Module):

Combines several auxiliary losses, weighting each by the corresponding entry in `weights`. An `AffinitySideLoss` is applied to the embeddings and target (with batch and channel axes added); other losses are applied to the instance probability maps and instance masks.
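A hedged construction sketch, reusing the AffinitySideLoss arguments that appear elsewhere in this module; the weights are placeholders:

from torch_em.loss.spoco_loss import CombinedAuxLoss
from torch_em.loss.affinity_side_loss import AffinitySideLoss
from torch_em.loss.dice import DiceLoss

combined = CombinedAuxLoss(
    losses=[
        AffinitySideLoss(delta=2.0, offset_ranges=[(-18, 18), (-18, 18)], n_samples=9),
        DiceLoss(),
    ],
    weights=[1.0, 1.0],
)
# the forward pass expects: embeddings (ExSPATIAL), target (SPATIAL), instance_pmaps, instance_masks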
class ContrastiveLossBase(nn.Module):

Base class for the SPOCO losses.
def compute_instance_term(self, embeddings, cluster_means, target):
Computes auxiliary loss based on embeddings and a given list of target instances together with their mean embeddings.
Arguments:
- embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
- cluster_means (torch.tensor): mean embeddings per instance (CxE)
- target (torch.tensor): ground truth instance segmentation (SPATIAL)
Returns:
float: value of the instance-based term
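Since `compute_instance_term` is abstract, subclasses have to provide it. A hypothetical minimal subclass (illustration only, not part of the module) that disables the instance term:

from torch_em.loss.spoco_loss import ContrastiveLossBase

class PlainContrastiveLoss(ContrastiveLossBase):
    """Hypothetical subclass that only uses the variance, distance and regularizer terms."""
    def compute_instance_term(self, embeddings, cluster_means, target):
        return 0.0   # no auxiliary instance term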
def forward(self, input_, target):
Arguments:
- input_ (torch.tensor): embeddings predicted by the network (NxExDxHxW) (E - embedding dims); expects a float32 tensor
- target (torch.tensor): ground truth instance segmentation (Nx1xDxHxW); expects an int64 tensor
Returns:
Combined loss defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term
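Calling the loss then looks as follows, continuing the hypothetical `PlainContrastiveLoss` from above (toy shapes; assumes torch_scatter is installed):

import torch

loss = PlainContrastiveLoss(delta_var=0.75, delta_dist=2.0)
pred = torch.randn(2, 8, 32, 32)                        # N x E x H x W, float32
target = torch.zeros(2, 1, 32, 32, dtype=torch.int64)   # N x 1 x H x W, int64
target[:, :, 8:16, 8:16] = 1                            # instance ids start at 0 and are contiguous
value = loss(pred, target)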
class ExtendedContrastiveLoss(ContrastiveLossBase):

Contrastive loss extended with an instance-based loss term and a background push term.
Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
def compute_instance_term(self, embeddings, cluster_means, target):
Computes auxiliary loss based on embeddings and a given list of target instances together with their mean embeddings.
Arguments:
- embeddings (torch.tensor): pixel embeddings (ExSPATIAL)
- cluster_means (torch.tensor): mean embeddings per instance (CxE)
- target (torch.tensor): ground truth instance segmentation (SPATIAL)
Returns:
float: value of the instance-based term
class SPOCOLoss(ExtendedContrastiveLoss):
The full SPOCO Loss for instance segmentation training with sparse instance labels.
Extends the "classic" contrastive loss with an instance-based term and a embedding consistency term. (The unlabeled push term is turned off by default, since we assume sparse instance labels).
Based on: "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
440 def emb_consistency(self, emb_q, emb_k, mask): 441 inst_q = [] 442 inst_k = [] 443 for i in range(self.max_anchors): 444 if mask.sum() < self.volume_threshold * mask.numel(): 445 break 446 447 # get random anchor 448 indices = torch.nonzero(mask, as_tuple=True) 449 ind = np.random.randint(len(indices[0])) 450 451 q_pmap = self._extract_pmap(emb_q, mask, indices, ind) 452 inst_q.append(q_pmap) 453 454 k_pmap = self._extract_pmap(emb_k, mask, indices, ind) 455 inst_k.append(k_pmap) 456 457 # stack along channel dim 458 inst_q = torch.stack(inst_q) 459 inst_k = torch.stack(inst_k) 460 461 loss = self.consistency_loss(inst_q, inst_k) 462 return loss
476 def forward(self, input, target): 477 assert len(input) == 2 478 emb_q, emb_k = input 479 480 # compute extended contrastive loss only on the embeddings coming from q 481 contrastive_loss = super().forward(emb_q, target) 482 483 # TODO enable computing the consistency on all pixels! 484 # compute consistency term 485 for e_q, e_k, t in zip(emb_q, emb_k, target): 486 unlabeled_mask = (t[0] == 0).int() 487 if unlabeled_mask.sum() < self.volume_threshold * unlabeled_mask.numel(): 488 continue 489 emb_consistency_loss = self.emb_consistency(e_q, e_k, unlabeled_mask) 490 contrastive_loss += self.consistency_term_weight * emb_consistency_loss 491 492 return contrastive_loss
Arguments (to forward):
- input (torch.Tensor): pair of embedding tensors (emb_q, emb_k) predicted by the two networks, each of shape NxExDxHxW for 3D or NxExHxW for 2D (E - embedding dimension); expects float32 tensors
- target (torch.Tensor): ground-truth instance segmentation of shape Nx1xDxHxW (or Nx1xHxW); expects an int64 tensor

Returns:
Combined loss, defined as: alpha * variance_term + beta * distance_term + gamma * regularization_term + instance_term_weight * instance_term + unlabeled_push_weight * unlabeled_push_term, plus consistency_term_weight * consistency_term for this loss.
class SPOCOConsistencyLoss(nn.Module):
    """Unsupervised embedding consistency term of the SPOCO loss.

    Samples random anchor pixels, converts the embedding distance to each anchor
    into an instance probability map for both predictions, and compares the two
    maps with a Dice loss. Does not require any labels.
    """

    def __init__(self, delta_var, pmaps_threshold, max_anchors=30, norm="fro"):
        super().__init__()
        self.max_anchors = max_anchors
        self.consistency_loss = DiceLoss()
        self.norm = norm
        self.dist_to_mask = GaussianKernel(delta_var=delta_var, pmaps_threshold=pmaps_threshold)
        self.init_kwargs = {"delta_var": delta_var, "pmaps_threshold": pmaps_threshold,
                            "max_anchors": max_anchors, "norm": norm}

    def _inst_pmap(self, emb, anchor):
        # compute the distance map to the anchor embedding
        distance_map = torch.norm(emb - anchor, self.norm, dim=-1)
        # convert the distance map to an instance probability map and return it
        return self.dist_to_mask(distance_map)

    def emb_consistency(self, emb_q, emb_k):
        inst_q = []
        inst_k = []
        # anchors are sampled from all pixels (the mask covers the full image)
        mask = torch.ones(emb_q.shape[1:])
        for i in range(self.max_anchors):
            # get a random anchor pixel
            indices = torch.nonzero(mask, as_tuple=True)
            ind = np.random.randint(len(indices[0]))

            q_pmap = self._extract_pmap(emb_q, mask, indices, ind)
            inst_q.append(q_pmap)

            k_pmap = self._extract_pmap(emb_k, mask, indices, ind)
            inst_k.append(k_pmap)

        # stack along channel dim
        inst_q = torch.stack(inst_q)
        inst_k = torch.stack(inst_k)

        loss = self.consistency_loss(inst_q, inst_k)
        return loss

    def _extract_pmap(self, emb, mask, indices, ind):
        if mask.dim() == 2:
            y, x = indices
            anchor = emb[:, y[ind], x[ind]]
            emb = emb.permute(1, 2, 0)
        else:
            z, y, x = indices
            anchor = emb[:, z[ind], y[ind], x[ind]]
            emb = emb.permute(1, 2, 3, 0)

        return self._inst_pmap(emb, anchor)

    def forward(self, emb_q, emb_k):
        loss = 0.0
        # accumulate the consistency term per sample in the batch
        for e_q, e_k in zip(emb_q, emb_k):
            loss += self.emb_consistency(e_q, e_k)
        return loss
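Unlike SPOCOLoss, this loss takes the two embedding predictions directly and needs no target, so it can be used for purely unsupervised consistency training. A minimal sketch (shapes illustrative, momentum pairing assumed as above):

    import torch
    from torch_em.loss.spoco_loss import SPOCOConsistencyLoss

    loss_fn = SPOCOConsistencyLoss(delta_var=0.75, pmaps_threshold=0.9)
    emb_q = torch.randn(2, 8, 64, 64, requires_grad=True)  # main network output
    emb_k = torch.randn(2, 8, 64, 64)                       # momentum network output

    loss = loss_fn(emb_q, emb_k)
    loss.backward()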