torch_em.loss.contrastive_impl

import torch
try:
    from torch_scatter import scatter_mean
except Exception:
    scatter_mean = None

#
# implementation using torch_scatter
#


def _compute_cluster_means_scatter(input_, target, ndim, n_lbl=None):
    assert scatter_mean is not None, "torch_scatter is required for this implementation"
    assert ndim in (2, 3)
    # move the embedding dimension to the front and flatten the batch and spatial axes: E x (N * SPATIAL)
    if ndim == 2:
        feat = input_.permute(1, 0, 2, 3).flatten(1)
    else:
        feat = input_.permute(1, 0, 2, 3, 4).flatten(1)
    lbl = target.flatten()
    if n_lbl is None:
        n_lbl = torch.unique(target).size(0)
    # average the embeddings per label id to get the mean embedding of each instance
    mean_embeddings = scatter_mean(feat, lbl, dim=1, dim_size=n_lbl)
    return mean_embeddings.transpose(1, 0)

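# Usage sketch for _compute_cluster_means_scatter (hypothetical values, assuming
# torch_scatter is installed): scatter_mean groups the flattened embeddings by
# label id and averages each group.
#
#   feat = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # E=1 embedding dim, 4 pixels
#   lbl = torch.tensor([0, 0, 1, 1])             # two instances
#   scatter_mean(feat, lbl, dim=1, dim_size=2)   # -> tensor([[1.5000, 3.5000]])
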
def _compute_distance_term_scatter(cluster_means, norm, delta_dist, ignore_labels=None):
    C = cluster_means.shape[0]
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(0)
    shape = list(cluster_means.size())
    shape[0] = C

    # CxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute pair-wise distances
    cm_matrix2 = cm_matrix1.permute(1, 0, 2)
    # compute pair-wise distances (CxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=2)

    C_norm = C
    if ignore_labels is not None:
        # TODO implement arbitrary ignore labels
        assert ignore_labels == [0], "Only zero ignore label supported so far"
        if C == 2:
            # just two cluster instances, including one which is ignored,
            # i.e. the distance term does not contribute to the loss
            return 0.0
        # set the distance to ignore-labels to be greater than 2*delta_dist,
        # so that it does not contribute to the loss because of the hinge at 2*delta_dist

        # find the minimum distance
        d_min = torch.min(dist_matrix[dist_matrix > 0]).item()
        # dist_multiplier = 2 * delta_dist / d_min + epsilon
        dist_multiplier = 2 * delta_dist / d_min + 1e-3
        # create the distance mask
        dist_mask = torch.ones_like(dist_matrix)
        dist_mask[0, 1:] = dist_multiplier
        dist_mask[1:, 0] = dist_multiplier

        # mask the dist_matrix
        dist_matrix = dist_matrix * dist_mask
        # decrease the number of instances
        C_norm -= 1

    # create the matrix of repulsion distances (i.e. cluster centers further apart
    # than 2 * delta_dist are no longer repulsed)
    # CxC
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # zero out distances greater than 2*delta_dist (CxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(0, 1))
    # normalize by the number of pairs and return
    return hinged_dist / (C_norm * (C_norm - 1))

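# Worked example for the hinge above (illustrative numbers only): with
# delta_dist = 1.5 the hinge threshold is 2 * delta_dist = 3.0, so a cluster
# pair at distance 2.0 contributes (3.0 - 2.0) ** 2 = 1.0 to the sum, while a
# pair at distance 3.5 is clamped to 0 and no longer repulsed.
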
# NOTE: it would be better not to expand the instance sizes spatially, but instead to expand the
# instance dimension once the variance has been summed up and then divide by the instance sizes
# (both for performance and numerical stability)
def _compute_variance_term_scatter(
    cluster_means, embeddings, target, norm, delta_var, instance_sizes, ignore_labels=None
):
    assert cluster_means.shape[1] == embeddings.shape[1]
    ndim = embeddings.ndim - 2
    assert ndim in (2, 3), f"{ndim}"
    n_instances = cluster_means.shape[0]

    # compute the spatial mean and instance fields by scattering with the target tensor
    cluster_means_spatial = cluster_means[target]
    instance_sizes_spatial = instance_sizes[target]

    # permute the embedding dimension to axis 1
    if ndim == 2:
        cluster_means_spatial = cluster_means_spatial.permute(0, 3, 1, 2)
        dim_arg = (1, 2)
    else:
        cluster_means_spatial = cluster_means_spatial.permute(0, 4, 1, 2, 3)
        dim_arg = (1, 2, 3)
    assert embeddings.shape == cluster_means_spatial.shape

    # compute the variance
    variance = torch.norm(embeddings - cluster_means_spatial, norm, dim=1)

    # apply the ignore labels (if given)
    if ignore_labels is not None:
        assert isinstance(ignore_labels, list)
        # zero out the variance at all positions that belong to an ignore label
        mask = torch.ones_like(variance)
        mask[torch.isin(target, torch.tensor(ignore_labels, device=target.device))] = 0
        variance *= mask
        # decrease the number of instances
        n_instances -= len(ignore_labels)
        # if there are only ignore labels in the target return 0
        if n_instances == 0:
            return 0.0

    # hinge the variance
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    assert variance.shape == instance_sizes_spatial.shape

    # normalize the variance by the instance sizes and the number of instances and sum it up
    variance = torch.sum(variance / instance_sizes_spatial, dim=dim_arg) / n_instances
    return variance

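# Indexing sketch for cluster_means[target] above (hypothetical values): for a
# target of shape N x SPATIAL and cluster_means of shape C x E, advanced indexing
# yields a tensor of shape N x SPATIAL x E, i.e. every pixel receives the mean
# embedding of the instance it belongs to.
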
#
# pure torch implementation
#


def expand_as_one_hot(input_, C, ignore_label=None):
    """
    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
    NOTE: make sure that input_ contains consecutive numbers starting from 0, otherwise the scatter_ function
    won't work.

    SPATIAL = DxHxW in case of 3D or SPATIAL = HxW in case of 2D
    :param input_: 3D or 4D label image (NxSPATIAL)
    :param C: number of channels/labels
    :param ignore_label: ignore index to be kept during the expansion
    :return: 4D or 5D output image (NxCxSPATIAL)
    """
    assert input_.dim() in (3, 4), f"Unsupported input shape {input_.shape}"

    # expand the input_ tensor to Nx1xSPATIAL before scattering
    input_ = input_.unsqueeze(1)
    # create the result tensor shape (NxCxSPATIAL)
    output_shape = list(input_.size())
    output_shape[1] = C

    if ignore_label is not None:
        # create the ignore_label mask for the result
        mask = input_.expand(output_shape) == ignore_label
        # clone the src tensor and zero out ignore_label in the input_
        input_ = input_.clone()
        input_[input_ == ignore_label] = 0
        # scatter to get the one-hot tensor
        result = torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)
        # bring back the ignore_label in the result
        result[mask] = ignore_label
        return result
    else:
        # scatter to get the one-hot tensor
        return torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)

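# Usage sketch for expand_as_one_hot (hypothetical values):
#
#   target = torch.tensor([[[0, 1], [2, 0]]])  # N=1, 2x2 labels in {0, 1, 2}
#   one_hot = expand_as_one_hot(target, C=3)   # shape (1, 3, 2, 2)
#   one_hot[0, 1]                              # -> tensor([[0., 1.], [0., 0.]])
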
def _compute_cluster_means(input_, target, ndim):

    dim_arg = (3, 4) if ndim == 2 else (3, 4, 5)

    embedding_dims = input_.size()[1]

    # expand target: NxCxSPATIAL -> NxCx1xSPATIAL
    target = target.unsqueeze(2)

    # NOTE we could try to reuse this in '_compute_variance_term',
    # but it has another dimensionality, so we would need to drop one axis
    # get the number of voxels in each cluster; output: NxCx1(SPATIAL)
    num_voxels_per_instance = torch.sum(target, dim=dim_arg, keepdim=True)

    # expand target: NxCx1xSPATIAL -> NxCxExSPATIAL
    shape = list(target.size())
    shape[2] = embedding_dims
    target = target.expand(shape)

    # expand input_: NxExSPATIAL -> Nx1xExSPATIAL
    input_ = input_.unsqueeze(1)

    # mask the embeddings per instance via broadcasting and sum them up; output: NxCxEx1(SPATIAL)
    embeddings_per_instance = input_ * target
    num = torch.sum(embeddings_per_instance, dim=dim_arg, keepdim=True)

    # compute the mean embeddings per instance NxCxEx1(SPATIAL)
    mean_embeddings = num / num_voxels_per_instance

    # return the mean embeddings and additional tensors needed for further computations
    return mean_embeddings, embeddings_per_instance

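# Shape sketch for _compute_cluster_means (E = embedding dim, C = number of instances):
#   input_ : N x E x SPATIAL  -> unsqueezed to N x 1 x E x SPATIAL
#   target : N x C x SPATIAL  -> expanded to  N x C x E x SPATIAL (one-hot)
#   means  : N x C x E x 1(SPATIAL) after summing over the spatial dims
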
def _compute_variance_term(cluster_means, embeddings, target, ndim, norm, delta_var):
    dim_arg = (2, 3) if ndim == 2 else (2, 3, 4)

    # compute the distance to cluster means, result: (NxCxSPATIAL)
    variance = torch.norm(embeddings - cluster_means, norm, dim=2)

    # get per instance distances (apply instance mask)
    assert variance.shape == target.shape
    variance = variance * target

    # zero out distances less than delta_var and sum to get the variance (NxC)
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    variance = torch.sum(variance, dim=dim_arg)

    # get the number of voxels per instance (NxC)
    num_voxels_per_instance = torch.sum(target, dim=dim_arg)

    # normalize the variance term
    C = target.size()[1]
    variance = torch.sum(variance / num_voxels_per_instance, dim=1) / C
    return variance

def _compute_distance_term(cluster_means, C, ndim, norm, delta_dist):
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(1)
    shape = list(cluster_means.size())
    shape[1] = C

    # NxCxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute pair-wise distances
    cm_matrix2 = cm_matrix1.permute(0, 2, 1, 3)
    # compute pair-wise distances (NxCxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=3)

    # create the matrix of repulsion distances (i.e. cluster centers further apart
    # than 2 * delta_dist are no longer repulsed)
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # 1xCxC
    repulsion_dist = repulsion_dist.unsqueeze(0)
    # zero out distances greater than 2*delta_dist (NxCxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(1, 2))
    # normalize by the number of pairs and return
    return hinged_dist / (C * (C - 1))

def _compute_regularizer_term(cluster_means, C, ndim, norm):
    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    norms = torch.norm(cluster_means, p=norm, dim=2)
    assert norms.size()[1] == C
    # return the average norm per batch
    return torch.sum(norms, dim=1).div(C)
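
# In the full contrastive (discriminative) loss these three terms are typically
# combined as a weighted sum; a sketch of the standard formulation, not taken
# from this file (alpha, beta, gamma are hypothetical weight names):
#
#   loss = alpha * variance_term + beta * distance_term + gamma * regularizer_term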