torch_em.loss.contrastive_impl
import torch
try:
    from torch_scatter import scatter_mean
except Exception:
    scatter_mean = None

#
# implementation using torch_scatter
#


def _compute_cluster_means_scatter(input_, target, ndim, n_lbl=None):
    assert scatter_mean is not None
    assert ndim in (2, 3)
    if ndim == 2:
        feat = input_.permute(1, 0, 2, 3).flatten(1)
    else:
        feat = input_.permute(1, 0, 2, 3, 4).flatten(1)
    lbl = target.flatten()
    if n_lbl is None:
        n_lbl = torch.unique(target).size(0)
    mean_embeddings = scatter_mean(feat, lbl, dim=1, dim_size=n_lbl)
    return mean_embeddings.transpose(1, 0)


def _compute_distance_term_scatter(cluster_means, norm, delta_dist, ignore_labels=None):
    C = cluster_means.shape[0]
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(0)
    shape = list(cluster_means.size())
    shape[0] = C

    # CxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute pair-wise distances
    cm_matrix2 = cm_matrix1.permute(1, 0, 2)
    # compute pair-wise distances (CxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=2)

    C_norm = C
    if ignore_labels is not None:
        # TODO implement arbitrary ignore labels
        assert ignore_labels == [0], "Only zero ignore label supported so far"
        if C == 2:
            # just two cluster instances, including one which is ignored,
            # i.e. the distance term does not contribute to the loss
            return 0.0
        # set the distance to ignore-labels to be greater than 2*delta_dist,
        # so that it does not contribute to the loss because of the hinge at 2*delta_dist

        # find the minimum distance
        d_min = torch.min(dist_matrix[dist_matrix > 0]).item()
        # dist_multiplier = 2 * delta_dist / d_min + epsilon
        dist_multiplier = 2 * delta_dist / d_min + 1e-3
        # create the distance mask
        dist_mask = torch.ones_like(dist_matrix)
        dist_mask[0, 1:] = dist_multiplier
        dist_mask[1:, 0] = dist_multiplier

        # mask the dist_matrix
        dist_matrix = dist_matrix * dist_mask
        # decrease the number of instances
        C_norm -= 1

    # create the matrix for the repulsion distance
    # (i.e. cluster centers further apart than 2 * delta_dist are no longer repulsed)
    # CxC
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # zero out distances greater than 2*delta_dist (CxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum up all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(0, 1))
    # normalize by the number of pairs and return
    return hinged_dist / (C_norm * (C_norm - 1))


# NOTE: it would be better to not expand the instance sizes spatially, but instead expand the
# instance dimension once we have the variance summed up and then divide by the instance sizes
# (both for performance and numerical stability)
def _compute_variance_term_scatter(
    cluster_means, embeddings, target, norm, delta_var, instance_sizes, ignore_labels=None
):
    assert cluster_means.shape[1] == embeddings.shape[1]
    ndim = embeddings.ndim - 2
    assert ndim in (2, 3), f"{ndim}"
    n_instances = cluster_means.shape[0]

    # compute the spatial mean and instance fields by scattering with the target tensor
    cluster_means_spatial = cluster_means[target]
    instance_sizes_spatial = instance_sizes[target]

    # permute the embedding dimension to axis 1
    if ndim == 2:
        cluster_means_spatial = cluster_means_spatial.permute(0, 3, 1, 2)
        dim_arg = (1, 2)
    else:
        cluster_means_spatial = cluster_means_spatial.permute(0, 4, 1, 2, 3)
        dim_arg = (1, 2, 3)
    assert embeddings.shape == cluster_means_spatial.shape

    # compute the variance
    variance = torch.norm(embeddings - cluster_means_spatial, norm, dim=1)

    # apply the ignore labels (if given)
    if ignore_labels is not None:
        assert isinstance(ignore_labels, list)
        # mask out the ignore labels
        mask = torch.ones_like(variance)
        mask[torch.isin(target, torch.tensor(ignore_labels).to(mask.device))] = 0
        variance *= mask
        # decrease the number of instances
        n_instances -= len(ignore_labels)
        # if there are only ignore labels in the target return 0
        if n_instances == 0:
            return 0.0

    # hinge the variance
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    assert variance.shape == instance_sizes_spatial.shape

    # normalize the variance by the instance sizes and the number of instances, and sum it up
    variance = torch.sum(variance / instance_sizes_spatial, dim=dim_arg) / n_instances
    return variance


#
# pure torch implementation
#


def expand_as_one_hot(input_, C, ignore_label=None):
    """
    Converts an NxSPATIAL label image to NxCxSPATIAL, where each label is converted to its corresponding one-hot vector.
    NOTE: make sure that the input_ contains consecutive numbers starting from 0, otherwise the scatter_ function
    won't work.

    SPATIAL = DxHxW in case of 3D or SPATIAL = HxW in case of 2D

    :param input_: 3D or 4D label image (NxSPATIAL)
    :param C: number of channels/labels
    :param ignore_label: ignore index to be kept during the expansion
    :return: 4D or 5D output image (NxCxSPATIAL)
    """
    assert input_.dim() in (3, 4), f"Unsupported input shape {input_.shape}"

    # expand the input_ tensor to Nx1xSPATIAL before scattering
    input_ = input_.unsqueeze(1)
    # create the result tensor shape (NxCxSPATIAL)
    output_shape = list(input_.size())
    output_shape[1] = C

    if ignore_label is not None:
        # create the ignore_label mask for the result
        mask = input_.expand(output_shape) == ignore_label
        # clone the src tensor and zero out the ignore_label in the input_
        input_ = input_.clone()
        input_[input_ == ignore_label] = 0
        # scatter to get the one-hot tensor
        result = torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)
        # bring back the ignore_label in the result
        result[mask] = ignore_label
        return result
    else:
        # scatter to get the one-hot tensor
        return torch.zeros(output_shape, device=input_.device).scatter_(1, input_, 1)


def _compute_cluster_means(input_, target, ndim):

    dim_arg = (3, 4) if ndim == 2 else (3, 4, 5)

    embedding_dims = input_.size()[1]

    # expand target: NxCxSPATIAL -> NxCx1xSPATIAL
    target = target.unsqueeze(2)

    # NOTE we could try to reuse this in '_compute_variance_term',
    # but it has another dimensionality, so we would need to drop one axis
    # get the number of voxels in each cluster; output: NxCx1(SPATIAL)
    num_voxels_per_instance = torch.sum(target, dim=dim_arg, keepdim=True)

    # expand target: NxCx1xSPATIAL -> NxCxExSPATIAL
    shape = list(target.size())
    shape[2] = embedding_dims
    target = target.expand(shape)

    # expand input_: NxExSPATIAL -> Nx1xExSPATIAL
    input_ = input_.unsqueeze(1)

    # sum the embeddings in each instance (multiply first via broadcasting); output: NxCxEx1(SPATIAL)
    embeddings_per_instance = input_ * target
    num = torch.sum(embeddings_per_instance, dim=dim_arg, keepdim=True)

    # compute the mean embeddings per instance NxCxEx1(SPATIAL)
    mean_embeddings = num / num_voxels_per_instance

    # return the mean embeddings and additional tensors needed for further computations
    return mean_embeddings, embeddings_per_instance


def _compute_variance_term(cluster_means, embeddings, target, ndim, norm, delta_var):
    dim_arg = (2, 3) if ndim == 2 else (2, 3, 4)

    # compute the distance to the cluster means, result: (NxCxSPATIAL)
    variance = torch.norm(embeddings - cluster_means, norm, dim=2)

    # get the per-instance distances (apply the instance mask)
    assert variance.shape == target.shape
    variance = variance * target

    # zero out distances less than delta_var and sum to get the variance (NxC)
    variance = torch.clamp(variance - delta_var, min=0) ** 2
    variance = torch.sum(variance, dim=dim_arg)

    # get the number of voxels per instance (NxC)
    num_voxels_per_instance = torch.sum(target, dim=dim_arg)

    # normalize the variance term
    C = target.size()[1]
    variance = torch.sum(variance / num_voxels_per_instance, dim=1) / C
    return variance


def _compute_distance_term(cluster_means, C, ndim, norm, delta_dist):
    if C == 1:
        # just one cluster in the batch, so the distance term does not contribute to the loss
        return 0.

    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    # expand the cluster_means tensor in order to compute the pair-wise distances between cluster means
    cluster_means = cluster_means.unsqueeze(1)
    shape = list(cluster_means.size())
    shape[1] = C

    # NxCxCxE
    cm_matrix1 = cluster_means.expand(shape)
    # transpose the cluster_means matrix in order to compute pair-wise distances
    cm_matrix2 = cm_matrix1.permute(0, 2, 1, 3)
    # compute pair-wise distances (NxCxC)
    dist_matrix = torch.norm(cm_matrix1 - cm_matrix2, p=norm, dim=3)

    # create the matrix for the repulsion distance
    # (i.e. cluster centers further apart than 2 * delta_dist are no longer repulsed)
    repulsion_dist = 2 * delta_dist * (1 - torch.eye(C, device=cluster_means.device))
    # 1xCxC
    repulsion_dist = repulsion_dist.unsqueeze(0)
    # zero out distances greater than 2*delta_dist (NxCxC)
    hinged_dist = torch.clamp(repulsion_dist - dist_matrix, min=0) ** 2
    # sum up all of the hinged pair-wise distances
    hinged_dist = torch.sum(hinged_dist, dim=(1, 2))
    # normalize by the number of pairs and return
    return hinged_dist / (C * (C - 1))


def _compute_regularizer_term(cluster_means, C, ndim, norm):
    # squeeze the spatial dims
    for _ in range(ndim):
        cluster_means = cluster_means.squeeze(-1)
    norms = torch.norm(cluster_means, p=norm, dim=2)
    assert norms.size()[1] == C
    # return the average norm per batch
    return torch.sum(norms, dim=1).div(C)
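The private helpers above implement the variance (pull), distance (push), and regularizer terms of a discriminative instance-embedding loss. As a rough orientation, the following sketch wires the pure-torch path together on a small 2D example; the tensor shapes, the margins delta_var / delta_dist, and the term weights are illustrative assumptions and not values prescribed by this module.

import torch
from torch_em.loss.contrastive_impl import (
    expand_as_one_hot, _compute_cluster_means, _compute_variance_term,
    _compute_distance_term, _compute_regularizer_term,
)

# illustrative shapes: batch of 1, 3-dimensional embeddings, a 32x32 image, 4 instances
N, E, H, W, C = 1, 3, 32, 32, 4
embeddings = torch.randn(N, E, H, W)

# a toy label image with one instance per quadrant (labels must be consecutive from 0)
target = torch.zeros(N, H, W, dtype=torch.long)
target[:, :16, 16:] = 1
target[:, 16:, :16] = 2
target[:, 16:, 16:] = 3

# one-hot encode the labels: NxSPATIAL -> NxCxSPATIAL
target_one_hot = expand_as_one_hot(target, C)

# per-instance mean embeddings (NxCxEx1x1) and instance-masked embeddings (NxCxExHxW)
cluster_means, embeddings_per_instance = _compute_cluster_means(embeddings, target_one_hot, ndim=2)

# pull term: penalize embeddings that are further than delta_var from their instance mean
variance_term = _compute_variance_term(
    cluster_means, embeddings_per_instance, target_one_hot, ndim=2, norm=2, delta_var=0.5
)
# push term: penalize instance means that are closer than 2 * delta_dist to each other
distance_term = _compute_distance_term(cluster_means, C, ndim=2, norm=2, delta_dist=2.0)
# regularizer term: keep the embedding norms small
regularizer_term = _compute_regularizer_term(cluster_means, C, ndim=2, norm=2)

# a discriminative loss combines the three terms; the weights below are only an example
loss = (variance_term + distance_term + 0.001 * regularizer_term).mean()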
def expand_as_one_hot(input_, C, ignore_label=None):
Converts an NxSPATIAL label image to NxCxSPATIAL, where each label is converted to its corresponding one-hot vector.
NOTE: make sure that the input_ contains consecutive numbers starting from 0, otherwise the scatter_ function won't work.
SPATIAL = DxHxW in case of 3D or SPATIAL = HxW in case of 2D
Parameters
- input_: 3D or 4D label image (NxSPATIAL)
- C: number of channels/labels
- ignore_label: ignore index to be kept during the expansion
Returns
4D or 5D output image (NxCxSPATIAL)
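A small illustrative example of the one-hot expansion; the label values and shapes below are assumptions chosen for the sketch, not part of the module.

import torch
from torch_em.loss.contrastive_impl import expand_as_one_hot

# NxHxW = 1x2x2 label image with consecutive labels starting at 0
labels = torch.tensor([[[0, 1],
                        [2, 1]]])

one_hot = expand_as_one_hot(labels, C=3)      # NxCxHxW = 1x3x2x2
print(one_hot[0, :, 0, 1])                    # tensor([0., 1., 0.]) -> pixel (0, 1) carries label 1

# with ignore_label, the ignored index is written back into every channel of the result
labels_ign = torch.tensor([[[0, 1],
                            [2, 2]]])
one_hot_ign = expand_as_one_hot(labels_ign, C=3, ignore_label=2)
print(one_hot_ign[0, :, 1, 0])                # tensor([2., 2., 2.]) at the ignored pixel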