torch_em.loss.contrastive_impl
from typing import Optional

import torch

try:
    from torch_scatter import scatter_mean
except Exception:
    scatter_mean = None

#
# implementation using torch_scatter
#


def _compute_cluster_means_scatter(input_, target, ndim, n_lbl=None):
    assert scatter_mean is not None, "The scatter based implementation requires torch_scatter"
    assert ndim in (2, 3)
    # move the embedding axis to the front and flatten the batch and spatial axes: E x (N * SPATIAL)
    if ndim == 2:
        feat = input_.permute(1, 0, 2, 3).flatten(1)
    else:
        feat = input_.permute(1, 0, 2, 3, 4).flatten(1)
    lbl = target.flatten()
    if n_lbl is None:
        n_lbl = torch.unique(target).size(0)
    mean_embeddings = scatter_mean(feat, lbl, dim=1, dim_size=n_lbl)
    return mean_embeddings.transpose(1, 0)


def _compute_distance_term_scatter(cluster_means, norm, delta_dist, ignore_labels=None):
    C = cluster_means.shape[0]
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(0)
    shape = list(cluster_means.size())
    shape[0] = C

    # CxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute the pair-wise distances
    cm_matrix2 = cm_matrix1.permute(1, 0, 2)
    # compute the pair-wise distances (CxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=2)

    C_norm = C
    if ignore_labels is not None:
        # TODO implement arbitrary ignore labels
        assert ignore_labels == [0], "Only the zero ignore label is supported so far"
        if C == 2:
            # just two cluster instances, one of which is ignored,
            # i.e. the distance term does not contribute to the loss
            return 0.0
        # set the distance to the ignore label to be greater than 2*delta_dist,
        # so that it does not contribute to the loss because of the hinge at 2*delta_dist

        # find the minimal distance
        d_min = torch.min(dist_matrix[dist_matrix > 0]).item()
        # dist_multiplier = 2 * delta_dist / d_min + epsilon
        dist_multiplier = 2 * delta_dist / d_min + 1e-3
        # create the distance mask
        dist_mask = torch.ones_like(dist_matrix)
        dist_mask[0, 1:] = dist_multiplier
        dist_mask[1:, 0] = dist_multiplier

        # mask the dist_matrix
        dist_matrix = dist_matrix * dist_mask
        # decrease the number of instances
        C_norm -= 1

    # create the matrix for the repulsion distance
    # (i.e. cluster centers further apart than 2 * delta_dist are no longer repulsed)
    # CxC
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # zero out distances greater than 2*delta_dist (CxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum up all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(0, 1))
    # normalize by the number of pairs and return
    return hinged_dist / (C_norm * (C_norm - 1))


# NOTE: it would be better to not expand the instance sizes spatially, but instead expand the
# instance dimension once we have the variance summed up and then divide by the instance sizes
# (both for performance and numerical stability)
def _compute_variance_term_scatter(
    cluster_means, embeddings, target, norm, delta_var, instance_sizes, ignore_labels=None
):
    assert cluster_means.shape[1] == embeddings.shape[1]
    ndim = embeddings.ndim - 2
    assert ndim in (2, 3), f"{ndim}"
    n_instances = cluster_means.shape[0]

    # compute the spatial mean and instance fields by scattering with the target tensor
    cluster_means_spatial = cluster_means[target]
    instance_sizes_spatial = instance_sizes[target]

    # permute the embedding dimension to axis 1
    if ndim == 2:
        cluster_means_spatial = cluster_means_spatial.permute(0, 3, 1, 2)
        dim_arg = (1, 2)
    else:
        cluster_means_spatial = cluster_means_spatial.permute(0, 4, 1, 2, 3)
        dim_arg = (1, 2, 3)
    assert embeddings.shape == cluster_means_spatial.shape

    # compute the variance
    variance = torch.norm(embeddings - cluster_means_spatial, norm, dim=1)

    # apply the ignore labels (if given)
    if ignore_labels is not None:
        assert isinstance(ignore_labels, list)
        # zero out the positions that carry an ignore label
        mask = torch.ones_like(variance)
        mask[torch.isin(target, torch.tensor(ignore_labels, device=mask.device))] = 0
        variance *= mask
        # decrease the number of instances
        n_instances -= len(ignore_labels)
        # if there are only ignore labels in the target return 0
        if n_instances == 0:
            return 0.0

    # hinge the variance
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    assert variance.shape == instance_sizes_spatial.shape

    # normalize the variance by the instance sizes and the number of instances and sum it up
    variance = torch.sum(variance / instance_sizes_spatial, dim=dim_arg) / n_instances
    return variance
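# ------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): how the
# scatter based helpers fit together for a single 2D batch. This assumes
# that torch_scatter is installed and that the label image contains
# consecutive ids starting at 0; all shapes and delta values are made up
# for illustration.
# ------------------------------------------------------------------
def _demo_scatter_terms():
    n, e, h, w = 1, 8, 32, 32
    n_lbl = 4
    embeddings = torch.randn(n, e, h, w)           # NxExSPATIAL embedding field
    labels = torch.randint(0, n_lbl, (n, h, w))    # NxSPATIAL label image

    # per instance mean embeddings, shape: n_lbl x E
    means = _compute_cluster_means_scatter(embeddings, labels, ndim=2, n_lbl=n_lbl)
    # pixels per instance, needed to normalize the variance term
    sizes = torch.bincount(labels.flatten(), minlength=n_lbl)

    dist_term = _compute_distance_term_scatter(means, norm=2, delta_dist=2.0)
    var_term = _compute_variance_term_scatter(
        means, embeddings, labels, norm=2, delta_var=0.5, instance_sizes=sizes
    )
    return dist_term, var_term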
#
# pure torch implementation
#


def expand_as_one_hot(input_: torch.Tensor, C: int, ignore_label: Optional[int] = None) -> torch.Tensor:
    """Expand labels to a one-hot representation.

    Converts a NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to
    its corresponding one-hot vector. Here, SPATIAL = DxHxW for 3D data and SPATIAL = HxW for 2D data.
    Make sure that input_ contains consecutive numbers starting from 0, otherwise the scatter_ function won't work.

    Args:
        input_: A 3D or 4D label image (NxSPATIAL).
        C: The number of channels/labels.
        ignore_label: The ignore index to be discarded during the expansion.

    Returns:
        A 4D or 5D output (NxCxSPATIAL) with one-hot expanded labels.
    """
    assert input_.dim() in (3, 4), f"Unsupported input shape {input_.shape}"

    # expand the input_ tensor to Nx1xSPATIAL before scattering
    input_ = input_.unsqueeze(1)
    # create the result tensor shape (NxCxSPATIAL)
    output_shape = list(input_.size())
    output_shape[1] = C

    if ignore_label is not None:
        # create the ignore_label mask for the result
        mask = input_.expand(output_shape) == ignore_label
        # clone the src tensor and zero out the ignore_label in the input_
        input_ = input_.clone()
        input_[input_ == ignore_label] = 0
        # scatter to get the one-hot tensor
        result = torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)
        # bring back the ignore_label in the result
        result[mask] = ignore_label
        return result
    else:
        # scatter to get the one-hot tensor
        return torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)


def _compute_cluster_means(input_, target, ndim):

    dim_arg = (3, 4) if ndim == 2 else (3, 4, 5)

    embedding_dims = input_.size()[1]

    # expand target: NxCxSPATIAL -> NxCx1xSPATIAL
    target = target.unsqueeze(2)

    # NOTE we could try to reuse this in '_compute_variance_term',
    # but it has a different dimensionality, so we would need to drop one axis
    # get the number of voxels in each cluster, output: NxCx1(SPATIAL)
    num_voxels_per_instance = torch.sum(target, dim=dim_arg, keepdim=True)

    # expand target: NxCx1xSPATIAL -> NxCxExSPATIAL
    shape = list(target.size())
    shape[2] = embedding_dims
    target = target.expand(shape)

    # expand input_: NxExSPATIAL -> Nx1xExSPATIAL
    input_ = input_.unsqueeze(1)

    # sum up the embeddings in each instance (multiply first via broadcasting), output: NxCxEx1(SPATIAL)
    embeddings_per_instance = input_ * target
    num = torch.sum(embeddings_per_instance, dim=dim_arg, keepdim=True)

    # compute the mean embeddings per instance, NxCxEx1(SPATIAL)
    mean_embeddings = num / num_voxels_per_instance

    # return the mean embeddings and additional tensors needed for further computations
    return mean_embeddings, embeddings_per_instance


def _compute_variance_term(cluster_means, embeddings, target, ndim, norm, delta_var):
    dim_arg = (2, 3) if ndim == 2 else (2, 3, 4)

    # compute the distance to the cluster means, result: NxCxSPATIAL
    variance = torch.norm(embeddings - cluster_means, norm, dim=2)

    # get the per instance distances (apply the instance mask)
    assert variance.shape == target.shape
    variance = variance * target

    # zero out distances less than delta_var and sum up to get the variance (NxC)
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    variance = torch.sum(variance, dim=dim_arg)

    # get the number of voxels per instance (NxC)
    num_voxels_per_instance = torch.sum(target, dim=dim_arg)

    # normalize the variance term
    C = target.size()[1]
    variance = torch.sum(variance / num_voxels_per_instance, dim=1) / C
    return variance
def _compute_distance_term(cluster_means, C, ndim, norm, delta_dist):
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(1)
    shape = list(cluster_means.size())
    shape[1] = C

    # NxCxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute the pair-wise distances
    cm_matrix2 = cm_matrix1.permute(0, 2, 1, 3)
    # compute the pair-wise distances (NxCxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=3)

    # create the matrix for the repulsion distance
    # (i.e. cluster centers further apart than 2 * delta_dist are no longer repulsed)
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # 1xCxC
    repulsion_dist = repulsion_dist.unsqueeze(0)
    # zero out distances greater than 2*delta_dist (NxCxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum up all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(1, 2))
    # normalize by the number of pairs and return
    return hinged_dist / (C * (C - 1))


def _compute_regularizer_term(cluster_means, C, ndim, norm):
    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    norms = torch.norm(cluster_means, p=norm, dim=2)
    assert norms.size()[1] == C
    # return the average norm per batch
    return torch.sum(norms, dim=1).div(C)
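# ------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): combining
# the pure torch helpers into the full discriminative loss (variance +
# distance + regularizer). The term weights 1, 1 and 0.001 mirror the
# commonly used defaults for this loss; shapes and delta values are made
# up for illustration, and every instance id is assumed to occur in each
# sample.
# ------------------------------------------------------------------
def _demo_pure_torch_loss():
    n, e, h, w = 2, 8, 32, 32
    n_instances = 4                                    # ids 0..n_instances-1
    embeddings = torch.randn(n, e, h, w)               # NxExSPATIAL network output
    labels = torch.randint(0, n_instances, (n, h, w))  # NxSPATIAL instance labels

    one_hot = expand_as_one_hot(labels, n_instances)   # NxCxSPATIAL
    cluster_means, per_instance = _compute_cluster_means(embeddings, one_hot, ndim=2)
    variance_term = _compute_variance_term(cluster_means, per_instance, one_hot, ndim=2, norm=2, delta_var=0.5)
    distance_term = _compute_distance_term(cluster_means, n_instances, ndim=2, norm=2, delta_dist=2.0)
    regularizer_term = _compute_regularizer_term(cluster_means, n_instances, ndim=2, norm=2)
    return (variance_term + distance_term + 0.001 * regularizer_term).mean()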
def expand_as_one_hot(input_: torch.Tensor, C: int, ignore_label: Optional[int] = None) -> torch.Tensor:
Expand labels to a one-hot representation.
Converts a NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector. Here, SPATIAL = DxHxW for 3D data and SPATIAL = HxW for 2D data. Make sure that input_ contains consecutive numbers starting from 0, otherwise the scatter_ function won't work.
Arguments:
- input_: A 3D or 4D label image (NxSPATIAL).
- C: The number of channels/labels.
- ignore_label: The ignore index to be discarded during the expansion.
Returns:
A 4D or 5D output (NxCxSPATIAL) with one-hot expanded labels.
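A minimal usage sketch on a toy 2D label image (the values are made up for illustration):

    import torch
    labels = torch.tensor([[[0, 1], [2, 1]]])   # 1x2x2 label image with ids 0..2
    one_hot = expand_as_one_hot(labels, C=3)    # 1x3x2x2 one-hot representation
    assert one_hot.shape == (1, 3, 2, 2)
    assert torch.equal(one_hot.argmax(dim=1), labels)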