torch_em.loss.contrastive_impl

from typing import Optional

import torch
try:
    from torch_scatter import scatter_mean
except Exception:
    scatter_mean = None

#
# implementation using torch_scatter
#


def _compute_cluster_means_scatter(input_, target, ndim, n_lbl=None):
    assert scatter_mean is not None
    assert ndim in (2, 3)
    # move the embedding dimension to the front and flatten batch and spatial axes: ExN
    if ndim == 2:
        feat = input_.permute(1, 0, 2, 3).flatten(1)
    else:
        feat = input_.permute(1, 0, 2, 3, 4).flatten(1)
    lbl = target.flatten()
    if n_lbl is None:
        n_lbl = torch.unique(target).size(0)
    # average the embeddings per instance id (ExC) and transpose to CxE
    mean_embeddings = scatter_mean(feat, lbl, dim=1, dim_size=n_lbl)
    return mean_embeddings.transpose(1, 0)


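# Usage sketch (illustrative, not part of the module). The shapes and the number of
# instances are assumptions chosen for the example: a single 2D image with E = 8
# embedding channels and consecutive instance ids 0..4.
#
#   embeddings = torch.randn(1, 8, 32, 32)        # N x E x H x W
#   labels = torch.randint(0, 5, (1, 32, 32))     # N x H x W, consecutive ids
#   means = _compute_cluster_means_scatter(embeddings, labels, ndim=2, n_lbl=5)
#   assert means.shape == (5, 8)                  # C x E: one mean embedding per instance
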
def _compute_distance_term_scatter(cluster_means, norm, delta_dist, ignore_labels=None):
    C = cluster_means.shape[0]
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(0)
    shape = list(cluster_means.size())
    shape[0] = C

    # CxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute pair-wise distances
    cm_matrix2 = cm_matrix1.permute(1, 0, 2)
    # compute pair-wise distances (CxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=2)

    C_norm = C
    if ignore_labels is not None:
        # TODO implement arbitrary ignore labels
        assert ignore_labels == [0], "Only zero ignore label supported so far"
        if C == 2:
            # just two cluster instances, including one which is ignored,
            # i.e. the distance term does not contribute to the loss
            return 0.0
        # set the distance to the ignore label to be greater than 2*delta_dist,
        # so that it does not contribute to the loss because of the hinge at 2*delta_dist

        # find the minimum distance
        d_min = torch.min(dist_matrix[dist_matrix > 0]).item()
        # dist_multiplier = 2 * delta_dist / d_min + epsilon
        dist_multiplier = 2 * delta_dist / d_min + 1e-3
        # create the distance mask
        dist_mask = torch.ones_like(dist_matrix)
        dist_mask[0, 1:] = dist_multiplier
        dist_mask[1:, 0] = dist_multiplier

        # mask the dist_matrix
        dist_matrix = dist_matrix * dist_mask
        # decrease the number of instances
        C_norm -= 1

    # create the matrix for the repulsion distance (i.e. cluster centers further apart
    # than 2 * delta_dist are no longer repulsed)
    # CxC
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # zero out distances greater than 2*delta_dist (CxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum up all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(0, 1))
    # normalize by the number of pairs and return
    return hinged_dist / (C_norm * (C_norm - 1))


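# Usage sketch (illustrative, not part of the module). With delta_dist = 1.5, any pair
# of cluster means less than 2 * delta_dist = 3 apart is penalized; the values below
# are assumptions chosen so that exactly one pair falls inside the hinge:
#
#   means = torch.tensor([[0.0, 0.0], [0.5, 0.0], [10.0, 0.0]])  # C x E
#   loss = _compute_distance_term_scatter(means, norm=2, delta_dist=1.5)
#   # only the pair (0, 1) lies within the hinge; (0, 2) and (1, 2) are > 3 apart
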
# NOTE: it would be better to not expand the instance sizes spatially, but instead expand the
# instance dimension once we have the variance summed up and then divide by the instance sizes
# (both for performance and numerical stability)
def _compute_variance_term_scatter(
    cluster_means, embeddings, target, norm, delta_var, instance_sizes, ignore_labels=None
):
    assert cluster_means.shape[1] == embeddings.shape[1]
    ndim = embeddings.ndim - 2
    assert ndim in (2, 3), f"{ndim}"
    n_instances = cluster_means.shape[0]

    # compute the spatial mean and instance size fields by indexing with the target tensor
    cluster_means_spatial = cluster_means[target]
    instance_sizes_spatial = instance_sizes[target]

    # permute the embedding dimension to axis 1
    if ndim == 2:
        cluster_means_spatial = cluster_means_spatial.permute(0, 3, 1, 2)
        dim_arg = (1, 2)
    else:
        cluster_means_spatial = cluster_means_spatial.permute(0, 4, 1, 2, 3)
        dim_arg = (1, 2, 3)
    assert embeddings.shape == cluster_means_spatial.shape

    # compute the variance
    variance = torch.norm(embeddings - cluster_means_spatial, norm, dim=1)

    # apply the ignore labels (if given)
    if ignore_labels is not None:
        assert isinstance(ignore_labels, list)
        # mask out the positions that carry an ignore label
        mask = torch.ones_like(variance)
        mask[torch.isin(target, torch.tensor(ignore_labels, device=mask.device))] = 0
        variance *= mask
        # decrease the number of instances
        n_instances -= len(ignore_labels)
        # if there are only ignore labels in the target return 0
        if n_instances == 0:
            return 0.0

    # hinge the variance
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    assert variance.shape == instance_sizes_spatial.shape

    # normalize the variance by the instance sizes and the number of instances, and sum it up
    variance = torch.sum(variance / instance_sizes_spatial, dim=dim_arg) / n_instances
    return variance


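# Usage sketch (illustrative, not part of the module). The instance sizes are the
# pixel counts per id, obtainable e.g. via torch.bincount on the flattened target;
# shapes and delta_var are assumptions for the example:
#
#   embeddings = torch.randn(1, 8, 32, 32)         # N x E x H x W
#   target = torch.randint(0, 5, (1, 32, 32))      # N x H x W
#   sizes = torch.bincount(target.flatten(), minlength=5)
#   means = _compute_cluster_means_scatter(embeddings, target, ndim=2, n_lbl=5)
#   var_term = _compute_variance_term_scatter(
#       means, embeddings, target, norm=2, delta_var=0.5, instance_sizes=sizes
#   )
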
#
# pure torch implementation
#


def expand_as_one_hot(input_: torch.Tensor, C: int, ignore_label: Optional[int] = None) -> torch.Tensor:
    """Expand labels to a one-hot representation.

    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
    Here, SPATIAL = DxHxW in case of 3D data or SPATIAL = HxW in case of 2D data.
    Make sure that the input_ contains consecutive numbers starting from 0, otherwise the scatter_ function won't work.

    Args:
        input_: A 3D or 4D label image (NxSPATIAL).
        C: The number of channels/labels.
        ignore_label: The ignore index to be discarded during the expansion.

    Returns:
        A 4D or 5D output (NxCxSPATIAL) with one-hot expanded labels.
    """
    assert input_.dim() in (3, 4), f"Unsupported input shape {input_.shape}"

    # expand the input_ tensor to Nx1xSPATIAL before scattering
    input_ = input_.unsqueeze(1)
    # create the result tensor shape (NxCxSPATIAL)
    output_shape = list(input_.size())
    output_shape[1] = C

    if ignore_label is not None:
        # create the ignore_label mask for the result
        mask = input_.expand(output_shape) == ignore_label
        # clone the src tensor and zero out ignore_label in the input_
        input_ = input_.clone()
        input_[input_ == ignore_label] = 0
        # scatter to get the one-hot tensor
        result = torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)
        # bring back the ignore_label in the result
        result[mask] = ignore_label
        return result
    else:
        # scatter to get the one-hot tensor
        return torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)


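# Usage sketch (illustrative, not part of the module). The ids must be consecutive
# and the tensor must have an integer dtype, since scatter_ requires an index tensor:
#
#   labels = torch.tensor([[[0, 1], [2, 1]]])     # N x H x W = 1 x 2 x 2
#   one_hot = expand_as_one_hot(labels, C=3)      # N x C x H x W = 1 x 3 x 2 x 2
#   assert one_hot[0, 1, 0, 1] == 1               # pixel (0, 1) carries label 1
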
def _compute_cluster_means(input_, target, ndim):

    dim_arg = (3, 4) if ndim == 2 else (3, 4, 5)

    embedding_dims = input_.size()[1]

    # expand target: NxCxSPATIAL -> NxCx1xSPATIAL
    target = target.unsqueeze(2)

    # NOTE we could try to reuse this in '_compute_variance_term',
    # but it has another dimensionality, so we would need to drop one axis
    # get the number of voxels in each cluster; output: NxCx1 with singleton spatial dims
    num_voxels_per_instance = torch.sum(target, dim=dim_arg, keepdim=True)

    # expand target: NxCx1xSPATIAL -> NxCxExSPATIAL
    shape = list(target.size())
    shape[2] = embedding_dims
    target = target.expand(shape)

    # expand input_: NxExSPATIAL -> Nx1xExSPATIAL
    input_ = input_.unsqueeze(1)

    # sum the embeddings in each instance (multiply first via broadcasting); output: NxCxE with singleton spatial dims
    embeddings_per_instance = input_ * target
    num = torch.sum(embeddings_per_instance, dim=dim_arg, keepdim=True)

    # compute the mean embeddings per instance: NxCxE with singleton spatial dims
    mean_embeddings = num / num_voxels_per_instance

    # return the mean embeddings and additional tensors needed for further computations
    return mean_embeddings, embeddings_per_instance


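# Usage sketch (illustrative, not part of the module). Here the target is already
# one-hot expanded, e.g. via expand_as_one_hot above; shapes are assumptions:
#
#   embeddings = torch.randn(1, 8, 32, 32)       # N x E x H x W
#   labels = torch.randint(0, 5, (1, 32, 32))    # N x H x W
#   one_hot = expand_as_one_hot(labels, C=5)     # N x C x H x W
#   means, per_instance = _compute_cluster_means(embeddings, one_hot, ndim=2)
#   assert means.shape == (1, 5, 8, 1, 1)        # N x C x E with singleton spatial dims
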
def _compute_variance_term(cluster_means, embeddings, target, ndim, norm, delta_var):
    dim_arg = (2, 3) if ndim == 2 else (2, 3, 4)

    # compute the distance to cluster means, result: NxCxSPATIAL
    variance = torch.norm(embeddings - cluster_means, norm, dim=2)

    # get per instance distances (apply instance mask)
    assert variance.shape == target.shape
    variance = variance * target

    # zero out distances less than delta_var and sum to get the variance (NxC)
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    variance = torch.sum(variance, dim=dim_arg)

    # get the number of voxels per instance (NxC)
    num_voxels_per_instance = torch.sum(target, dim=dim_arg)

    # normalize the variance term
    C = target.size()[1]
    variance = torch.sum(variance / num_voxels_per_instance, dim=1) / C
    return variance


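# Usage sketch (illustrative, not part of the module), continuing from the outputs
# of the _compute_cluster_means example above; delta_var is an assumed value:
#
#   means, per_instance = _compute_cluster_means(embeddings, one_hot, ndim=2)
#   var_term = _compute_variance_term(means, per_instance, one_hot, ndim=2, norm=2, delta_var=0.5)
#   # var_term has one entry per batch sample
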
def _compute_distance_term(cluster_means, C, ndim, norm, delta_dist):
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(1)
    shape = list(cluster_means.size())
    shape[1] = C

    # NxCxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute pair-wise distances
    cm_matrix2 = cm_matrix1.permute(0, 2, 1, 3)
    # compute pair-wise distances (NxCxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=3)

    # create the matrix for the repulsion distance (i.e. cluster centers further apart
    # than 2 * delta_dist are no longer repulsed)
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # 1xCxC
    repulsion_dist = repulsion_dist.unsqueeze(0)
    # zero out distances greater than 2*delta_dist (NxCxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum up all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(1, 2))
    # normalize by the number of pairs and return
    return hinged_dist / (C * (C - 1))


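# Usage sketch (illustrative, not part of the module). The cluster means keep their
# singleton spatial dimensions here; the function squeezes them away itself:
#
#   means, _ = _compute_cluster_means(embeddings, one_hot, ndim=2)  # N x C x E x 1 x 1
#   dist_term = _compute_distance_term(means, C=5, ndim=2, norm=2, delta_dist=1.5)
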
def _compute_regularizer_term(cluster_means, C, ndim, norm):
    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    norms = torch.norm(cluster_means, p=norm, dim=2)
    assert norms.size()[1] == C
    # return the average norm per batch
    return torch.sum(norms, dim=1).div(C)
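
# Usage sketch (illustrative, not part of the module). The regularizer pulls all
# cluster means towards the origin by penalizing their average norm:
#
#   means, _ = _compute_cluster_means(embeddings, one_hot, ndim=2)
#   reg_term = _compute_regularizer_term(means, C=5, ndim=2, norm=2)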