torch_em.loss.contrastive

from warnings import warn

import torch
import torch.nn as nn
from . import contrastive_impl as impl


def check_consecutive(labels):
    """Check that the input labels are consecutive and start at zero.
    """
    diff = labels[1:] - labels[:-1]
    return (labels[0] == 0) and (diff == 1).all()


# TODO support more sophisticated ignore labels:
# - ignore_dist: ignored in distance term
# - ignore_var: ignored in variance term
class ContrastiveLoss(nn.Module):
 19    """Implementation of contrastive loss defined in https://arxiv.org/pdf/1708.02551.pdf
 20    Semantic Instance Segmentation with a Discriminative Loss Function
 21
 22    This class contians different implementations for the discrimnative loss:
 23    - based on pure pytorch, expanding the instance dimension, this is not memory efficient
 24    - based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter), this is memory efficient
 25
 26    Arguments:
 27        delta_var [float] -
 28        delta_dist [float] -
 29        norm [str] -
 30        aplpha [float] -
 31        beta [float] -
 32        gamma [float] -
 33        ignore_label [int] -
 34        impl [str] -
 35    """
    implementations = (None, "scatter", "expand")

    def __init__(self, delta_var, delta_dist, norm="fro",
                 alpha=1.0, beta=1.0, gamma=0.001,
                 ignore_label=None, impl=None):
        assert ignore_label is None, "Not implemented"  # TODO
        super().__init__()
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.ignore_label = ignore_label

        assert impl in self.implementations
        has_torch_scatter = self.has_torch_scatter()
        if impl is None:
            if not has_torch_scatter:
                pt_scatter = "https://github.com/rusty1s/pytorch_scatter"
                warn(f"ContrastiveLoss: using pure pytorch implementation. Install {pt_scatter} for memory efficiency.")
            self._contrastive_impl = self._scatter_impl_batch if has_torch_scatter else self._expand_impl_batch
        elif impl == "scatter":
            assert has_torch_scatter
            self._contrastive_impl = self._scatter_impl_batch
        elif impl == "expand":
            self._contrastive_impl = self._expand_impl_batch

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"delta_var": delta_var, "delta_dist": delta_dist, "norm": norm,
                            "alpha": alpha, "beta": beta, "gamma": gamma, "ignore_label": ignore_label,
                            "impl": impl}

    @staticmethod
    def has_torch_scatter():
        try:
            import torch_scatter
        except ImportError:
            torch_scatter = None
        return torch_scatter is not None

    # This implementation expands all tensors to match the instance dimensions.
    # Hence it's fast, but has high memory consumption.
    # The implementation does not support masking any instance labels in the loss.
    def _expand_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        # get number of instances in the batch
        instances = torch.unique(target_batch)
        assert check_consecutive(instances), f"{instances}"
        n_instances = instances.size()[0]

        # SPATIAL = D X H X W in 3d case, H X W in 2d case
        # expand each label as a one-hot vector: N x SPATIAL -> N x C x SPATIAL
        target_batch = impl.expand_as_one_hot(target_batch, n_instances)

        cluster_means, embeddings_per_instance = impl._compute_cluster_means(input_batch,
                                                                             target_batch, ndim)
        variance_term = impl._compute_variance_term(cluster_means, embeddings_per_instance,
                                                    target_batch, ndim, self.norm, self.delta_var)
        distance_term = impl._compute_distance_term(cluster_means, n_instances,
                                                    ndim, self.norm, self.delta_dist)
        regularization_term = impl._compute_regularizer_term(cluster_means, n_instances,
                                                             ndim, self.norm)
        # compute total loss
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def _scatter_impl_batch(self, input_batch, target_batch, ndim):
        # add singleton batch dimension required for further computation
        input_batch = input_batch.unsqueeze(0)

        instance_ids, instance_sizes = torch.unique(target_batch, return_counts=True)
        n_instances = len(instance_ids)
        cluster_means = impl._compute_cluster_means_scatter(input_batch, target_batch, ndim, n_lbl=n_instances)

        variance_term = impl._compute_variance_term_scatter(cluster_means, input_batch, target_batch, self.norm,
                                                            self.delta_var, instance_sizes)
        distance_term = impl._compute_distance_term_scatter(cluster_means, self.norm, self.delta_dist)

        regularization_term = torch.sum(
            torch.norm(cluster_means, p=self.norm, dim=1)
        ).div(n_instances)

        # compute total loss
        return self.alpha * variance_term + self.beta * distance_term + self.gamma * regularization_term

    def forward(self, input_, target):
        n_batches = input_.shape[0]
        assert target.dim() == input_.dim()
        assert target.shape[1] == 1
        assert n_batches == target.shape[0]
        assert input_.size()[2:] == target.size()[2:]

        ndim = input_.dim() - 2
        assert ndim in (2, 3)

        # iterate over the batches
        loss = 0.0
        for input_batch, target_batch in zip(input_, target):
            loss_batch = self._contrastive_impl(input_batch, target_batch, ndim)
            loss += loss_batch

        return loss.div(n_batches)
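To make the expected interface concrete, here is a minimal usage sketch. The shapes follow the asserts in forward (embeddings of shape batch x embedding_dim x spatial, integer instance labels of shape batch x 1 x spatial); the hyperparameter values are illustrative only:

import torch
from torch_em.loss.contrastive import ContrastiveLoss

# delta_var < delta_dist, as in the paper; the concrete values here are illustrative
loss_fn = ContrastiveLoss(delta_var=0.5, delta_dist=2.0)

# pixel-wise embeddings: batch of 1, 8 embedding channels, 32 x 32 image
embeddings = torch.randn(1, 8, 32, 32, requires_grad=True)

# instance labels: background = 0 and a single foreground instance = 1
# (labels are consecutive and start at zero, as required by the expand implementation)
target = torch.zeros(1, 1, 32, 32, dtype=torch.int64)
target[:, :, 8:24, 8:24] = 1

loss = loss_fn(embeddings, target)
loss.backward()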
def check_consecutive(labels):

Check that the input labels are consecutive and start at zero.
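A quick illustration of the expected behavior (the result is a zero-dimensional boolean tensor):

import torch
from torch_em.loss.contrastive import check_consecutive

check_consecutive(torch.tensor([0, 1, 2]))  # tensor(True): starts at zero, consecutive
check_consecutive(torch.tensor([0, 2, 3]))  # tensor(False): gap between 0 and 2
check_consecutive(torch.tensor([1, 2, 3]))  # tensor(False): does not start at zero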

class ContrastiveLoss(torch.nn.modules.module.Module):

Implementation of the contrastive loss defined in https://arxiv.org/pdf/1708.02551.pdf (Semantic Instance Segmentation with a Discriminative Loss Function).

This class contains different implementations of the discriminative loss:

  • based on pure pytorch, expanding the instance dimension; this is not memory efficient
  • based on pytorch_scatter (https://github.com/rusty1s/pytorch_scatter); this is memory efficient
Arguments:
  • delta_var [float] - hinge margin for the variance (intra-cluster pull) term
  • delta_dist [float] - hinge margin for the distance (inter-cluster push) term
  • norm [str] - norm used to compute distances (default: "fro")
  • alpha [float] - weight of the variance term (default: 1.0)
  • beta [float] - weight of the distance term (default: 1.0)
  • gamma [float] - weight of the regularization term (default: 0.001); see the loss formulation sketched after this list
  • ignore_label [int] - label to ignore in the loss; not implemented yet (default: None)
  • impl [str] - implementation to use: "scatter", "expand" or None to choose automatically (default: None)
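The weighted sum computed in the code corresponds to the variance, distance and regularization terms of the discriminative loss. A sketch of the formulation, following the notation of the paper linked above (C clusters with means \mu_c, embeddings x_i, \delta_v = delta_var, \delta_d = delta_dist, [\,\cdot\,]_+ = \max(\cdot, 0)):

\mathcal{L}_{\mathrm{var}} = \frac{1}{C} \sum_{c=1}^{C} \frac{1}{N_c} \sum_{i=1}^{N_c} \big[ \lVert \mu_c - x_i \rVert - \delta_v \big]_+^2

\mathcal{L}_{\mathrm{dist}} = \frac{1}{C(C-1)} \sum_{c_A \neq c_B} \big[ 2\delta_d - \lVert \mu_{c_A} - \mu_{c_B} \rVert \big]_+^2

\mathcal{L}_{\mathrm{reg}} = \frac{1}{C} \sum_{c=1}^{C} \lVert \mu_c \rVert

\mathcal{L} = \alpha\,\mathcal{L}_{\mathrm{var}} + \beta\,\mathcal{L}_{\mathrm{dist}} + \gamma\,\mathcal{L}_{\mathrm{reg}}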
ContrastiveLoss(delta_var, delta_dist, norm='fro', alpha=1.0, beta=1.0, gamma=0.001, ignore_label=None, impl=None)

Initializes internal Module state, shared by both nn.Module and ScriptModule.
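A short sketch of forcing one of the two implementations explicitly (the hyperparameter values are illustrative; requesting "scatter" asserts that pytorch_scatter is installed):

from torch_em.loss.contrastive import ContrastiveLoss

scatter_loss = ContrastiveLoss(delta_var=0.5, delta_dist=2.0, impl="scatter")  # memory efficient, needs pytorch_scatter
expand_loss = ContrastiveLoss(delta_var=0.5, delta_dist=2.0, impl="expand")    # pure pytorch fallback
auto_loss = ContrastiveLoss(delta_var=0.5, delta_dist=2.0)                     # impl=None: uses scatter if available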

implementations = (None, 'scatter', 'expand')
delta_var
delta_dist
norm
alpha
beta
gamma
ignore_label
init_kwargs
@staticmethod
def has_torch_scatter():
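The availability check is a plain static method, so it can also be called up front, e.g. to warn about the fallback before constructing the loss:

from torch_em.loss.contrastive import ContrastiveLoss

if not ContrastiveLoss.has_torch_scatter():
    print("pytorch_scatter not found, ContrastiveLoss will fall back to the expand implementation")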
def forward(self, input_, target):

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
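The loss also accepts 3d inputs (ndim == 3). A minimal sketch with assumed shapes, where the spatial dimensions are depth, height and width:

import torch
from torch_em.loss.contrastive import ContrastiveLoss

loss_fn = ContrastiveLoss(delta_var=0.5, delta_dist=2.0)

embeddings = torch.randn(2, 8, 16, 32, 32)                 # batch x embedding_dim x depth x height x width
target = torch.zeros(2, 1, 16, 32, 32, dtype=torch.int64)  # integer instance labels
target[:, :, 4:12, 8:24, 8:24] = 1                         # one foreground instance per volume

loss = loss_fn(embeddings, target)                         # averaged over the batch dimension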
