torch_em.loss.contrastive

from warnings import warn
from typing import Optional

import torch
import torch.nn as nn
from . import contrastive_impl as impl


def check_consecutive(labels: torch.Tensor) -> bool:
    """Check that the input labels are consecutive and start at zero.

    Args:
        labels: The labels to check.

    Returns:
        Whether the labels are consecutive.
    """
    diff = labels[1:] - labels[:-1]
    return (labels[0] == 0) and (diff == 1).all()


# TODO support more sophisticated ignore labels:
# - ignore_dist: ignored in distance term
# - ignore_var: ignored in variance term
class ContrastiveLoss(nn.Module):
    """Implementation of a contrastive segmentation loss.

    From "Semantic Instance Segmentation with a Discriminative Loss Function":
    https://arxiv.org/pdf/1708.02551.pdf

    This class contains different implementations for the discriminative loss:
    - Based on pure pytorch, expanding the instance dimension, this is not memory efficient.
    - Based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter), this is memory efficient.

    Args:
        delta_var: The hinge distance for the variance term.
            The variance term corresponds to the attractive term of the loss function.
        delta_dist: The hinge distance for the distance term.
            The distance term corresponds to the repulsive term of the loss function.
        norm: The norm to use.
        alpha: Weight for the variance term of the loss.
        beta: Weight for the distance term of the loss.
        gamma: Weight for the regularization term of the loss.
        ignore_label: Ignore label to exclude from the loss computation.
        impl: Implementation of the loss to use, either 'scatter' or 'expand'.
    """
    implementations = (None, "scatter", "expand")

    def __init__(
        self,
        delta_var: float,
        delta_dist: float,
        norm: str = "fro",
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 0.001,
        ignore_label: Optional[int] = None,
        impl: Optional[str] = None
    ):
        assert ignore_label is None, "Not implemented"
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.ignore_label = ignore_label

        assert impl in self.implementations
        has_torch_scatter = self.has_torch_scatter()
        if impl is None:
            if not has_torch_scatter:
                pt_scatter = "https://github.com/rusty1s/pytorch_scatter"
                warn(f"ContrastiveLoss: using pure pytorch implementation. Install {pt_scatter} for memory efficiency.")
            self._contrastive_impl = self._scatter_impl_batch if has_torch_scatter else self._expand_impl_batch
        elif impl == "scatter":
            assert has_torch_scatter
            self._contrastive_impl = self._scatter_impl_batch
        elif impl == "expand":
            self._contrastive_impl = self._expand_impl_batch

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"delta_var": delta_var, "delta_dist": delta_dist, "norm": norm,
                            "alpha": alpha, "beta": beta, "gamma": gamma, "ignore_label": ignore_label,
                            "impl": impl}

    @staticmethod
    def has_torch_scatter():
        """@private
        """
        try:
            import torch_scatter
        except ImportError:
            torch_scatter = None
        return torch_scatter is not None

    # This implementation expands all tensors to match the instance dimensions.
    # Hence it's fast, but has high memory consumption.
    # The implementation does not support masking any instance labels in the loss.
    def _expand_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        # get number of instances in the batch
        instances = torch.unique(target_batch)
        assert check_consecutive(instances), f"{instances}"
        n_instances = instances.size()[0]

        # SPATIAL = D X H X W in 3d case, H X W in 2d case
        # expand each label as a one-hot vector: N x SPATIAL -> N x C x SPATIAL
        target_batch = impl.expand_as_one_hot(target_batch, n_instances)

        cluster_means, embeddings_per_instance = impl._compute_cluster_means(input_batch,
                                                                             target_batch, ndim)
        variance_term = impl._compute_variance_term(cluster_means, embeddings_per_instance,
                                                    target_batch, ndim, self.norm, self.delta_var)
        distance_term = impl._compute_distance_term(cluster_means, n_instances,
                                                    ndim, self.norm, self.delta_dist)
        regularization_term = impl._compute_regularizer_term(cluster_means, n_instances,
                                                             ndim, self.norm)
        # compute total loss
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def _scatter_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        instance_ids, instance_sizes = torch.unique(target_batch, return_counts=True)
        n_instances = len(instance_ids)
        cluster_means = impl._compute_cluster_means_scatter(input_batch, target_batch, ndim, n_lbl=n_instances)

        variance_term = impl._compute_variance_term_scatter(cluster_means, input_batch, target_batch, self.norm,
                                                            self.delta_var, instance_sizes)
        distance_term = impl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist)

        regularization_term = torch.sum(
            torch.norm(cluster_means, p=self.norm, dim=1)
        ).div(n_instances)

        # Compute the combined loss.
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the discriminative loss.

        Args:
            input_: The embedding predictions.
            target: The target segmentation.

        Returns:
            The discriminative loss value.
        """
        n_batches = input_.shape[0]
        assert target.dim() == input_.dim()
        assert target.shape[1] == 1
        assert n_batches == target.shape[0]
        assert input_.size()[2:] == target.size()[2:]

        ndim = input_.dim() - 2
        assert ndim in (2, 3)

        # iterate over the batches
        loss = 0.0
        for input_batch, target_batch in zip(input_, target):
            loss_batch = self._contrastive_impl(input_batch, target_batch, ndim)
            loss += loss_batch

        return loss.div(n_batches)
def check_consecutive(labels: torch.Tensor) -> bool:

Check that the input labels are consecutive and start at zero.

Arguments:
  • labels: The labels to check.
Returns:
  Whether the labels are consecutive.
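
For illustration, a minimal sketch of the check on hypothetical label tensors; the ids must start at zero and contain no gaps:

    import torch
    from torch_em.loss.contrastive import check_consecutive

    # Zero-based, gap-free instance ids pass the check.
    print(bool(check_consecutive(torch.tensor([0, 1, 2, 3]))))  # True
    # A gap in the ids (or a non-zero start) fails it.
    print(bool(check_consecutive(torch.tensor([0, 2, 3]))))     # False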

class ContrastiveLoss(torch.nn.modules.module.Module):

Implementation of a contrastive segmentation loss.

From "Semantic Instance Segmentation with a Discriminative Loss Function": https://arxiv.org/pdf/1708.02551.pdf

This class contains different implementations for the discriminative loss:

  • Based on pure pytorch, expanding the instance dimension, this is not memory efficient.
  • Based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter), this is memory efficient.
Arguments:
  • delta_var: The hinge distance for the variance term. The variance term corresponds to the attractive term of the loss function.
  • delta_dist: The hinge distance for the distance term. The distance term corresponds to the repulsive term of the loss function.
  • norm: The norm to use.
  • alpha: Weight for the variance term of the loss.
  • beta: Weight for the distance term of the loss.
  • gamma: Weight for the regularization term of the loss.
  • ignore_label: Ignore label to exclude from the loss computation.
  • impl: Implementation of the loss to use, either 'scatter' or 'expand'.
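
In the notation of the cited paper, the loss combines a variance (attractive) term hinged at delta_var = \delta_v, a distance (repulsive) term hinged at delta_dist = \delta_d, and a regularizer on the cluster means. With C instances, cluster means \mu_c, N_c pixel embeddings x_i per instance, and [x]_+ = \max(x, 0), a summary of that formulation (weighted by alpha, beta and gamma) is:

    L_{\mathrm{var}}  = \frac{1}{C} \sum_{c=1}^{C} \frac{1}{N_c} \sum_{i=1}^{N_c} \big[ \lVert \mu_c - x_i \rVert - \delta_v \big]_+^2
    L_{\mathrm{dist}} = \frac{1}{C(C-1)} \sum_{c_A \neq c_B} \big[ 2 \delta_d - \lVert \mu_{c_A} - \mu_{c_B} \rVert \big]_+^2
    L_{\mathrm{reg}}  = \frac{1}{C} \sum_{c=1}^{C} \lVert \mu_c \rVert
    L = \alpha \, L_{\mathrm{var}} + \beta \, L_{\mathrm{dist}} + \gamma \, L_{\mathrm{reg}}
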
ContrastiveLoss( delta_var: float, delta_dist: float, norm: str = 'fro', alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.001, ignore_label: Optional[int] = None, impl: Optional[str] = None)

implementations = (None, 'scatter', 'expand')
Instance attributes: delta_var, delta_dist, norm, alpha, beta, gamma, ignore_label, init_kwargs
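
A minimal construction sketch; the delta values are illustrative choices and impl="expand" is used only to avoid the pytorch_scatter dependency (impl=None picks the scatter implementation automatically when it is available):

    from torch_em.loss.contrastive import ContrastiveLoss

    loss_function = ContrastiveLoss(delta_var=0.75, delta_dist=2.0, impl="expand")

    # All torch_em losses store their init kwargs, so the construction call can be recreated.
    loss_copy = ContrastiveLoss(**loss_function.init_kwargs)
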
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

Compute the discriminative loss.

Arguments:
  • input_: The embedding predictions.
  • target: The target segmentation.
Returns:
  The discriminative loss value.
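
A minimal end-to-end sketch with hypothetical shapes: 8-dimensional embeddings for a batch of two 32x32 images, and a target with a single channel of consecutive, zero-based instance ids (as required by the expand implementation):

    import torch
    from torch_em.loss.contrastive import ContrastiveLoss

    # Embedding predictions with shape (batch, embedding_dim, height, width).
    pred = torch.randn(2, 8, 32, 32, requires_grad=True)

    # Target segmentation with shape (batch, 1, height, width); each image is split
    # into four quadrant "instances" with ids 0, 1, 2, 3.
    target = torch.zeros(2, 1, 32, 32, dtype=torch.int64)
    target[:, :, 16:, :16] = 1
    target[:, :, :16, 16:] = 2
    target[:, :, 16:, 16:] = 3

    loss_function = ContrastiveLoss(delta_var=0.75, delta_dist=2.0)
    loss = loss_function(pred, target)
    loss.backward()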