torch_em.loss.contrastive
```python
from warnings import warn

import torch
import torch.nn as nn
from . import contrastive_impl as impl


def check_consecutive(labels):
    """Check that the input labels are consecutive and start at zero.
    """
    diff = labels[1:] - labels[:-1]
    return (labels[0] == 0) and (diff == 1).all()


# TODO support more sophisticated ignore labels:
# - ignore_dist: ignored in distance term
# - ignore_var: ignored in variance term
class ContrastiveLoss(nn.Module):
    """Implementation of the contrastive loss defined in https://arxiv.org/pdf/1708.02551.pdf
    (Semantic Instance Segmentation with a Discriminative Loss Function).

    This class contains different implementations of the discriminative loss:
    - based on pure pytorch, expanding the instance dimension; this is not memory efficient
    - based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter); this is memory efficient

    Arguments:
        delta_var [float] - margin for the variance term; embeddings within this
            distance of their instance mean are not penalized
        delta_dist [float] - margin for the distance term; instance means further
            apart than twice this distance are not penalized
        norm [str] - norm used to compute distances (default: "fro")
        alpha [float] - weight of the variance term (default: 1.0)
        beta [float] - weight of the distance term (default: 1.0)
        gamma [float] - weight of the regularization term (default: 0.001)
        ignore_label [int] - label to ignore in the loss; not implemented yet (default: None)
        impl [str] - implementation to use, "scatter" or "expand"; by default the
            scatter implementation is used if pytorch_scatter is available (default: None)
    """
    implementations = (None, "scatter", "expand")

    def __init__(self, delta_var, delta_dist, norm="fro",
                 alpha=1.0, beta=1.0, gamma=0.001,
                 ignore_label=None, impl=None):
        assert ignore_label is None, "Not implemented"  # TODO
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.ignore_label = ignore_label

        assert impl in self.implementations
        has_torch_scatter = self.has_torch_scatter()
        if impl is None:
            if not has_torch_scatter:
                pt_scatter = "https://github.com/rusty1s/pytorch_scatter"
                warn(f"ContrastiveLoss: using pure pytorch implementation. Install {pt_scatter} for memory efficiency.")
            self._contrastive_impl = self._scatter_impl_batch if has_torch_scatter else self._expand_impl_batch
        elif impl == "scatter":
            assert has_torch_scatter
            self._contrastive_impl = self._scatter_impl_batch
        elif impl == "expand":
            self._contrastive_impl = self._expand_impl_batch

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"delta_var": delta_var, "delta_dist": delta_dist, "norm": norm,
                            "alpha": alpha, "beta": beta, "gamma": gamma, "ignore_label": ignore_label,
                            "impl": impl}

    @staticmethod
    def has_torch_scatter():
        try:
            import torch_scatter
        except ImportError:
            torch_scatter = None
        return torch_scatter is not None

    # This implementation expands all tensors to match the instance dimensions.
    # Hence it's fast, but has high memory consumption.
    # The implementation does not support masking any instance labels in the loss.
    def _expand_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        # get number of instances in the batch
        instances = torch.unique(target_batch)
        assert check_consecutive(instances), f"{instances}"
        n_instances = instances.size()[0]

        # SPATIAL = D X H X W in 3d case, H X W in 2d case
        # expand each label as a one-hot vector: N x SPATIAL -> N x C x SPATIAL
        target_batch = impl.expand_as_one_hot(target_batch, n_instances)

        cluster_means, embeddings_per_instance = impl._compute_cluster_means(input_batch,
                                                                             target_batch, ndim)
        variance_term = impl._compute_variance_term(cluster_means, embeddings_per_instance,
                                                    target_batch, ndim, self.norm, self.delta_var)
        distance_term = impl._compute_distance_term(cluster_means, n_instances,
                                                    ndim, self.norm, self.delta_dist)
        regularization_term = impl._compute_regularizer_term(cluster_means, n_instances,
                                                             ndim, self.norm)
        # compute total loss
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def _scatter_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        instance_ids, instance_sizes = torch.unique(target_batch, return_counts=True)
        n_instances = len(instance_ids)
        cluster_means = impl._compute_cluster_means_scatter(input_batch, target_batch, ndim, n_lbl=n_instances)

        variance_term = impl._compute_variance_term_scatter(cluster_means, input_batch, target_batch, self.norm,
                                                            self.delta_var, instance_sizes)
        distance_term = impl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist)

        regularization_term = torch.sum(
            torch.norm(cluster_means, p=self.norm, dim=1)
        ).div(n_instances)

        # compute total loss
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def forward(self, input_, target):
        n_batches = input_.shape[0]
        assert target.dim() == input_.dim()
        assert target.shape[1] == 1
        assert n_batches == target.shape[0]
        assert input_.size()[2:] == target.size()[2:]

        ndim = input_.dim() - 2
        assert ndim in (2, 3)

        # iterate over the batches
        loss = 0.0
        for input_batch, target_batch in zip(input_, target):
            loss_batch = self._contrastive_impl(input_batch, target_batch, ndim)
            loss += loss_batch

        return loss.div(n_batches)
```
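The implementation can be chosen explicitly via the `impl` argument of `ContrastiveLoss` (documented below); by default the scatter variant is used when pytorch_scatter is installed, otherwise the expand variant is used and a warning is emitted. A minimal sketch, where the delta values are arbitrary example choices:

```python
from torch_em.loss.contrastive import ContrastiveLoss

# force the pure-pytorch variant (no extra dependency, but memory hungry)
loss_expand = ContrastiveLoss(delta_var=0.75, delta_dist=2.0, impl="expand")

# force the memory-efficient variant; asserts that pytorch_scatter is installed
loss_scatter = ContrastiveLoss(delta_var=0.75, delta_dist=2.0, impl="scatter")
```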
def check_consecutive(labels):
Check that the input labels are consecutive and start at zero.
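A quick illustration of the check with hypothetical label tensors:

```python
import torch
from torch_em.loss.contrastive import check_consecutive

print(bool(check_consecutive(torch.tensor([0, 1, 2, 3]))))  # True: starts at zero, no gaps
print(bool(check_consecutive(torch.tensor([1, 2, 3]))))     # False: does not start at zero
print(bool(check_consecutive(torch.tensor([0, 2, 3]))))     # False: label 1 is missing
```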
class ContrastiveLoss(torch.nn.modules.module.Module):
Implementation of the contrastive loss defined in https://arxiv.org/pdf/1708.02551.pdf (Semantic Instance Segmentation with a Discriminative Loss Function).
This class contains different implementations of the discriminative loss:
- based on pure pytorch, expanding the instance dimension; this is not memory efficient
- based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter); this is memory efficient
Arguments:
- delta_var [float] - margin for the variance term; embeddings within this distance of their instance mean are not penalized
- delta_dist [float] - margin for the distance term; instance means further apart than twice this distance are not penalized
- norm [str] - norm used to compute distances (default: "fro")
- alpha [float] - weight of the variance term (default: 1.0)
- beta [float] - weight of the distance term (default: 1.0)
- gamma [float] - weight of the regularization term (default: 0.001)
- ignore_label [int] - label to ignore in the loss; not implemented yet (default: None)
- impl [str] - implementation to use, "scatter" or "expand"; by default the scatter implementation is used if pytorch_scatter is available (default: None)
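For orientation, the referenced paper defines the loss as a weighted sum of three terms, which this class combines with the weights alpha, beta and gamma from the arguments above:

$$
L = \alpha\, L_{\mathrm{var}} + \beta\, L_{\mathrm{dist}} + \gamma\, L_{\mathrm{reg}}
$$

$$
L_{\mathrm{var}} = \frac{1}{C}\sum_{c=1}^{C}\frac{1}{N_c}\sum_{i=1}^{N_c}\big[\lVert\mu_c - x_i\rVert - \delta_{\mathrm{var}}\big]_+^2,
\qquad
L_{\mathrm{dist}} = \frac{1}{C(C-1)}\sum_{c_A \neq c_B}\big[2\,\delta_{\mathrm{dist}} - \lVert\mu_{c_A} - \mu_{c_B}\rVert\big]_+^2,
\qquad
L_{\mathrm{reg}} = \frac{1}{C}\sum_{c=1}^{C}\lVert\mu_c\rVert
$$

Here C is the number of instances, x_i are the pixel embeddings of instance c, mu_c is their mean, and [.]_+ = max(0, .) denotes the hinge.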
ContrastiveLoss(delta_var, delta_dist, norm='fro', alpha=1.0, beta=1.0, gamma=0.001, ignore_label=None, impl=None)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
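Beyond that, the constructor stores its arguments in init_kwargs so that an equivalent instance can easily be recreated from the stored values; a small sketch, with arbitrary delta values:

```python
from torch_em.loss.contrastive import ContrastiveLoss

loss = ContrastiveLoss(delta_var=0.75, delta_dist=2.0)
print(loss.init_kwargs)
# {'delta_var': 0.75, 'delta_dist': 2.0, 'norm': 'fro', 'alpha': 1.0,
#  'beta': 1.0, 'gamma': 0.001, 'ignore_label': None, 'impl': None}

# recreate an equivalent instance by replaying the init call
loss_copy = ContrastiveLoss(**loss.init_kwargs)
```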
def forward(self, input_, target):
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
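Per the shape checks in the source above, input_ is expected to hold pixel-wise embeddings of shape (batch, embedding_dim, spatial...) and target the instance labels of shape (batch, 1, spatial...), with two or three spatial dimensions that match between the two tensors. A minimal sketch with made-up sizes:

```python
import torch
from torch_em.loss.contrastive import ContrastiveLoss

loss_fn = ContrastiveLoss(delta_var=0.75, delta_dist=2.0)

# batch of 2 images, 8-dimensional embedding per pixel, 32 x 32 spatial shape
embeddings = torch.randn(2, 8, 32, 32, requires_grad=True)
# instance labels in a single channel; ids should be consecutive and start at 0
labels = torch.randint(0, 4, (2, 1, 32, 32))

loss = loss_fn(embeddings, labels)
loss.backward()
```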