torch_em.loss.contrastive
from warnings import warn
from typing import Optional

import torch
import torch.nn as nn
from . import contrastive_impl as impl


def check_consecutive(labels: torch.Tensor) -> bool:
    """Check that the input labels are consecutive and start at zero.

    Args:
        labels: The labels to check.

    Returns:
        Whether the labels are consecutive and start at zero.
    """
    diff = labels[1:] - labels[:-1]
    return (labels[0] == 0) and (diff == 1).all()


# TODO support more sophisticated ignore labels:
# - ignore_dist: ignored in distance term
# - ignore_var: ignored in variance term
class ContrastiveLoss(nn.Module):
    """Implementation of a contrastive segmentation loss.

    From "Semantic Instance Segmentation with a Discriminative Loss Function":
    https://arxiv.org/pdf/1708.02551.pdf

    This class contains different implementations of the discriminative loss:
    - Based on pure pytorch, expanding the instance dimension; this is not memory efficient.
    - Based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter); this is memory efficient.

    Args:
        delta_var: The hinge distance for the variance term.
            The variance term corresponds to the attractive term of the loss function.
        delta_dist: The hinge distance for the distance term.
            The distance term corresponds to the repulsive term of the loss function.
        norm: The norm to use.
        alpha: Weight for the variance term of the loss.
        beta: Weight for the distance term of the loss.
        gamma: Weight for the regularization term of the loss.
        ignore_label: Ignore label to exclude from the loss computation.
        impl: Implementation of the loss to use, either 'scatter' or 'expand'.
    """
    implementations = (None, "scatter", "expand")

    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        ignore_label: Optional[int] = None,
        impl: Optional[str] = None
    ):
        assert ignore_label is None, "Not implemented"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.ignore_label = ignore_label

        assert impl in self.implementations
        has_torch_scatter = self.has_torch_scatter()
        if impl is None:
            if not has_torch_scatter:
                pt_scatter = "https://github.com/rusty1s/pytorch_scatter"
                warn(f"ContrastiveLoss: using pure pytorch implementation. Install {pt_scatter} for memory efficiency.")
            self._contrastive_impl = self._scatter_impl_batch if has_torch_scatter else self._expand_impl_batch
        elif impl == "scatter":
            assert has_torch_scatter
            self._contrastive_impl = self._scatter_impl_batch
        elif impl == "expand":
            self._contrastive_impl = self._expand_impl_batch

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"delta_var": delta_var, "delta_dist": delta_dist, "norm": norm,
                            "alpha": alpha, "beta": beta, "gamma": gamma, "ignore_label": ignore_label,
                            "impl": impl}

    @staticmethod
    def has_torch_scatter():
        """@private
        """
        try:
            import torch_scatter
        except ImportError:
            torch_scatter = None
        return torch_scatter is not None

    # This implementation expands all tensors to match the instance dimensions.
    # Hence it's fast, but has high memory consumption.
    # The implementation does not support masking any instance labels in the loss.
    def _expand_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        # get number of instances in the batch
        instances = torch.unique(target_batch)
        assert check_consecutive(instances), f"{instances}"
        n_instances = instances.size()[0]

        # SPATIAL = D X H X W in 3d case, H X W in 2d case
        # expand each label as a one-hot vector: N x SPATIAL -> N x C x SPATIAL
        target_batch = impl.expand_as_one_hot(target_batch, n_instances)

        cluster_means, embeddings_per_instance = impl._compute_cluster_means(input_batch,
                                                                             target_batch, ndim)
        variance_term = impl._compute_variance_term(cluster_means, embeddings_per_instance,
                                                    target_batch, ndim, self.norm, self.delta_var)
        distance_term = impl._compute_distance_term(cluster_means, n_instances,
                                                    ndim, self.norm, self.delta_dist)
        regularization_term = impl._compute_regularizer_term(cluster_means, n_instances,
                                                             ndim, self.norm)
        # compute total loss
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def _scatter_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        instance_ids, instance_sizes = torch.unique(target_batch, return_counts=True)
        n_instances = len(instance_ids)
        cluster_means = impl._compute_cluster_means_scatter(input_batch, target_batch, ndim, n_lbl=n_instances)

        variance_term = impl._compute_variance_term_scatter(cluster_means, input_batch, target_batch, self.norm,
                                                            self.delta_var, instance_sizes)
        distance_term = impl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist)

        regularization_term = torch.sum(
            torch.norm(cluster_means, p=self.norm, dim=1)
        ).div(n_instances)

        # Compute the combined loss.
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the discriminative loss.

        Args:
            input_: The embedding predictions.
            target: The target segmentation.

        Returns:
            The discriminative loss value.
        """
        n_batches = input_.shape[0]
        assert target.dim() == input_.dim()
        assert target.shape[1] == 1
        assert n_batches == target.shape[0]
        assert input_.size()[2:] == target.size()[2:]

        ndim = input_.dim() - 2
        assert ndim in (2, 3)

        # iterate over the batches
        loss = 0.0
        for input_batch, target_batch in zip(input_, target):
            loss_batch = self._contrastive_impl(input_batch, target_batch, ndim)
            loss += loss_batch

        return loss.div(n_batches)
def
check_consecutive(labels: torch.Tensor) -> bool:
def check_consecutive(labels: torch.Tensor) -> bool:
    """Check that the input labels are consecutive and start at zero.

    Args:
        labels: The labels to check.

    Returns:
        Whether the labels are consecutive and start at zero.
    """
    diff = labels[1:] - labels[:-1]
    return (labels[0] == 0) and (diff == 1).all()
Check that the input labels are consecutive and start at zero.
Arguments:
- labels: The labels to check.
Returns:
Whether the labels are consecutive and start at zero.
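A minimal usage sketch (the label tensors here are made up for illustration):

import torch
from torch_em.loss.contrastive import check_consecutive

print(bool(check_consecutive(torch.tensor([0, 1, 2, 3]))))  # True: starts at zero, no gaps
print(bool(check_consecutive(torch.tensor([0, 2, 3]))))     # False: label 1 is missing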
class
ContrastiveLoss(torch.nn.modules.module.Module):
class ContrastiveLoss(nn.Module):
    """Implementation of a contrastive segmentation loss.

    From "Semantic Instance Segmentation with a Discriminative Loss Function":
    https://arxiv.org/pdf/1708.02551.pdf

    This class contains different implementations of the discriminative loss:
    - Based on pure pytorch, expanding the instance dimension; this is not memory efficient.
    - Based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter); this is memory efficient.

    Args:
        delta_var: The hinge distance for the variance term.
            The variance term corresponds to the attractive term of the loss function.
        delta_dist: The hinge distance for the distance term.
            The distance term corresponds to the repulsive term of the loss function.
        norm: The norm to use.
        alpha: Weight for the variance term of the loss.
        beta: Weight for the distance term of the loss.
        gamma: Weight for the regularization term of the loss.
        ignore_label: Ignore label to exclude from the loss computation.
        impl: Implementation of the loss to use, either 'scatter' or 'expand'.
    """
    implementations = (None, "scatter", "expand")

    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        ignore_label: Optional[int] = None,
        impl: Optional[str] = None
    ):
        assert ignore_label is None, "Not implemented"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.ignore_label = ignore_label

        assert impl in self.implementations
        has_torch_scatter = self.has_torch_scatter()
        if impl is None:
            if not has_torch_scatter:
                pt_scatter = "https://github.com/rusty1s/pytorch_scatter"
                warn(f"ContrastiveLoss: using pure pytorch implementation. Install {pt_scatter} for memory efficiency.")
            self._contrastive_impl = self._scatter_impl_batch if has_torch_scatter else self._expand_impl_batch
        elif impl == "scatter":
            assert has_torch_scatter
            self._contrastive_impl = self._scatter_impl_batch
        elif impl == "expand":
            self._contrastive_impl = self._expand_impl_batch

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"delta_var": delta_var, "delta_dist": delta_dist, "norm": norm,
                            "alpha": alpha, "beta": beta, "gamma": gamma, "ignore_label": ignore_label,
                            "impl": impl}

    @staticmethod
    def has_torch_scatter():
        """@private
        """
        try:
            import torch_scatter
        except ImportError:
            torch_scatter = None
        return torch_scatter is not None

    # This implementation expands all tensors to match the instance dimensions.
    # Hence it's fast, but has high memory consumption.
    # The implementation does not support masking any instance labels in the loss.
    def _expand_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        # get number of instances in the batch
        instances = torch.unique(target_batch)
        assert check_consecutive(instances), f"{instances}"
        n_instances = instances.size()[0]

        # SPATIAL = D X H X W in 3d case, H X W in 2d case
        # expand each label as a one-hot vector: N x SPATIAL -> N x C x SPATIAL
        target_batch = impl.expand_as_one_hot(target_batch, n_instances)

        cluster_means, embeddings_per_instance = impl._compute_cluster_means(input_batch,
                                                                             target_batch, ndim)
        variance_term = impl._compute_variance_term(cluster_means, embeddings_per_instance,
                                                    target_batch, ndim, self.norm, self.delta_var)
        distance_term = impl._compute_distance_term(cluster_means, n_instances,
                                                    ndim, self.norm, self.delta_dist)
        regularization_term = impl._compute_regularizer_term(cluster_means, n_instances,
                                                             ndim, self.norm)
        # compute total loss
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def _scatter_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        instance_ids, instance_sizes = torch.unique(target_batch, return_counts=True)
        n_instances = len(instance_ids)
        cluster_means = impl._compute_cluster_means_scatter(input_batch, target_batch, ndim, n_lbl=n_instances)

        variance_term = impl._compute_variance_term_scatter(cluster_means, input_batch, target_batch, self.norm,
                                                            self.delta_var, instance_sizes)
        distance_term = impl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist)

        regularization_term = torch.sum(
            torch.norm(cluster_means, p=self.norm, dim=1)
        ).div(n_instances)

        # Compute the combined loss.
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the discriminative loss.

        Args:
            input_: The embedding predictions.
            target: The target segmentation.

        Returns:
            The discriminative loss value.
        """
        n_batches = input_.shape[0]
        assert target.dim() == input_.dim()
        assert target.shape[1] == 1
        assert n_batches == target.shape[0]
        assert input_.size()[2:] == target.size()[2:]

        ndim = input_.dim() - 2
        assert ndim in (2, 3)

        # iterate over the batches
        loss = 0.0
        for input_batch, target_batch in zip(input_, target):
            loss_batch = self._contrastive_impl(input_batch, target_batch, ndim)
            loss += loss_batch

        return loss.div(n_batches)
Implementation of a contrastive segmentation loss.
From "Semantic Instance Segmentation with a Discriminative Loss Function": https://arxiv.org/pdf/1708.02551.pdf
This class contains different implementations of the discriminative loss:
- Based on pure pytorch, expanding the instance dimension; this is not memory efficient.
- Based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter); this is memory efficient.
Arguments:
- delta_var: The hinge distance for the variance term. The variance term corresponds to the attractive term of the loss function.
- delta_dist: The hinge distance for the distance term. The distance term corresponds to the repulsive term of the loss function.
- norm: The norm to use.
- alpha: Weight for the variance term of the loss.
- beta: Weight for the distance term of the loss.
- gamma: Weight for the regularization term of the loss.
- ignore_label: Ignore label to exclude from the loss computation.
- impl: Implementation of the loss to use, either 'scatter' or 'expand'.
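The total loss combines the three terms as alpha * variance_term + beta * distance_term + gamma * regularization_term (see the return statements in the source above). A minimal usage sketch on random data follows; the embedding dimension, image size, and delta values are illustrative choices, not recommendations:

import torch
from torch_em.loss.contrastive import ContrastiveLoss

loss_fn = ContrastiveLoss(delta_var=0.75, delta_dist=2.0)

# Embedding predictions: (batch, embedding_dim, height, width).
embeddings = torch.randn(2, 8, 64, 64, requires_grad=True)
# Instance segmentation: (batch, 1, height, width) with ids 0, ..., N-1 per image;
# the 'expand' implementation asserts that the ids are consecutive and start at zero.
target = torch.randint(0, 5, (2, 1, 64, 64))

loss = loss_fn(embeddings, target)  # scalar tensor, averaged over the batch
loss.backward()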
ContrastiveLoss( delta_var: float, delta_dist: float, norm: str = 'fro', alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.001, ignore_label: Optional[int] = None, impl: Optional[str] = None)
    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        ignore_label: Optional[int] = None,
        impl: Optional[str] = None
    ):
        assert ignore_label is None, "Not implemented"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.ignore_label = ignore_label

        assert impl in self.implementations
        has_torch_scatter = self.has_torch_scatter()
        if impl is None:
            if not has_torch_scatter:
                pt_scatter = "https://github.com/rusty1s/pytorch_scatter"
                warn(f"ContrastiveLoss: using pure pytorch implementation. Install {pt_scatter} for memory efficiency.")
            self._contrastive_impl = self._scatter_impl_batch if has_torch_scatter else self._expand_impl_batch
        elif impl == "scatter":
            assert has_torch_scatter
            self._contrastive_impl = self._scatter_impl_batch
        elif impl == "expand":
            self._contrastive_impl = self._expand_impl_batch

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"delta_var": delta_var, "delta_dist": delta_dist, "norm": norm,
                            "alpha": alpha, "beta": beta, "gamma": gamma, "ignore_label": ignore_label,
                            "impl": impl}
Initialize internal Module state, shared by both nn.Module and ScriptModule.
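Because the implementation is selected in __init__ (scatter if torch_scatter is available, otherwise the expanding pure-pytorch variant), it can also be pinned explicitly; a brief sketch:

from torch_em.loss.contrastive import ContrastiveLoss

# Force the pure-pytorch implementation, e.g. to avoid the torch_scatter dependency.
# Passing impl="scatter" instead asserts that torch_scatter is importable.
loss_expand = ContrastiveLoss(delta_var=0.75, delta_dist=2.0, impl="expand")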
def
forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the discriminative loss.

        Args:
            input_: The embedding predictions.
            target: The target segmentation.

        Returns:
            The discriminative loss value.
        """
        n_batches = input_.shape[0]
        assert target.dim() == input_.dim()
        assert target.shape[1] == 1
        assert n_batches == target.shape[0]
        assert input_.size()[2:] == target.size()[2:]

        ndim = input_.dim() - 2
        assert ndim in (2, 3)

        # iterate over the batches
        loss = 0.0
        for input_batch, target_batch in zip(input_, target):
            loss_batch = self._contrastive_impl(input_batch, target_batch, ndim)
            loss += loss_batch

        return loss.div(n_batches)
Compute the discriminative loss.
Arguments:
- input_: The embedding predictions.
- target: The target segmentation.
Returns:
The discriminative loss value.
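forward also accepts volumetric data (ndim == 3). A sketch with a small random 3D volume; all shapes and values here are illustrative:

import torch
from torch_em.loss.contrastive import ContrastiveLoss

loss_fn = ContrastiveLoss(delta_var=0.75, delta_dist=2.0)

# Embeddings: (batch, embedding_dim, depth, height, width).
pred = torch.randn(1, 8, 16, 32, 32, requires_grad=True)
# Segmentation: (batch, 1, depth, height, width).
seg = torch.randint(0, 4, (1, 1, 16, 32, 32))

loss = loss_fn(pred, seg)  # scalar tensor, averaged over the batch
loss.backward()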