torch_em.self_training.loss
from typing import Optional

import torch
import torch_em
import torch.nn as nn
from torch_em.loss import DiceLoss


class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self training.

    This loss takes as input a model and its input, as well as (pseudo) labels and potentially
    a mask for the labels. It then runs prediction with the model and compares the outputs
    to the (pseudo) labels using an internal loss function. Typically, the labels are derived
    from the predictions of a teacher model, and the model passed is the student model.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the loss for self-training.

        Args:
            model: The model.
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: A mask to exclude from the loss computation.

        Returns:
            The loss value.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        return loss


class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self training.

    Similar to `DefaultSelfTrainingLoss`, but computes loss and metric value in one call
    to avoid running prediction with the model twice.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        metric: The internal metric function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        """Compute the loss and metric for self-training.

        Args:
            model: The model.
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: A mask to exclude from the loss computation.

        Returns:
            The loss and metric value.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        # Note: the metric is always evaluated without the label filter.
        metric = self.metric(prediction, labels)
        return loss, metric


# TODO: The probabilistic U-Net related code should be refactored to `torch_em.loss`
# and should be documented properly.


def l2_regularisation(m):
    """@private
    """
    # Sum of L2 norms over all parameters of the given module.
    l2_reg = None
    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg


class ProbabilisticUNetLoss(nn.Module):
    """@private
    """
    # """Loss function for Probabilistic UNet

    # Args:
    #     # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. (default: None)
    # """
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels, label_filter=None):
        # The forward pass populates the model's prior / posterior distributions.
        model.forward(input_, labels)

        if self.loss is None:
            elbo = model.elbo(labels, label_filter)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
        else:
            # BUGFIX: previously `loss` was left unbound on this path, raising a confusing
            # NameError at return. Fail explicitly until custom losses are implemented.
            raise NotImplementedError(
                "ProbabilisticUNetLoss only supports the default ELBO loss (loss=None)."
            )

        return loss


class ProbabilisticUNetLossAndMetric(nn.Module):
    """@private
    """
    # """Loss and metric function for Probabilistic UNet.

    # Args:
    #     # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. (default: None)

    #     metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
    #     activation [nn.Module, callable] - the activation function to be applied to the prediction
    #         before evaluating the average predictions. (default: None)
    # """
    def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        self.activation = activation
        self.metric = metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels, label_filter=None):
        # The forward pass populates the model's prior / posterior distributions.
        model.forward(input_, labels)

        if self.loss is None:
            elbo = model.elbo(labels, label_filter)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
        else:
            # BUGFIX: previously `loss` was left unbound on this path, raising a confusing
            # NameError at return. Fail explicitly until custom losses are implemented.
            raise NotImplementedError(
                "ProbabilisticUNetLossAndMetric only supports the default ELBO loss (loss=None)."
            )

        # Estimate the metric from the average over several samples drawn from the prior.
        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        avg_samples = torch.stack(samples_per_distribution, dim=0).sum(dim=0) / len(samples_per_distribution)
        metric = self.metric(avg_samples, labels)

        return loss, metric


class SelfTrainingLossWithInvertibleAugmentations(nn.Module):
    """Loss function for self-training with invertible augmentations.

    Variant of `DefaultSelfTrainingLoss` for use with `FixMatchTrainerWithInvertibleAugmentations`
    and `MeanTeacherTrainerWithInvertibleAugmentations`. Unlike `DefaultSelfTrainingLoss`, this loss
    receives pre-computed predictions directly rather than a model and input, because the trainer
    already applies the model and invertible augmentations before calling the loss.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the prediction before the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the self-training loss.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. Where provided, both prediction
                and labels are multiplied by this tensor before the loss is computed.

        Returns:
            The loss value.
        """
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        return loss


class SelfTrainingLossAndMetricWithInvertibleAugmentations(nn.Module):
    """Loss and metric function for self-training with invertible augmentations.

    Variant of `DefaultSelfTrainingLossAndMetric` for use with
    `FixMatchTrainerWithInvertibleAugmentations` and `MeanTeacherTrainerWithInvertibleAugmentations`.
    Computes both loss and metric in a single call from pre-computed predictions, avoiding a
    second forward pass. Used during validation where the trainer already holds the predictions.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the prediction before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ):
        """Compute the self-training loss and metric.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. Where provided, both prediction
                and labels are multiplied by this tensor before the loss is computed.

        Returns:
            The loss and metric value.
        """
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        # Note: the metric is always evaluated without the label filter.
        metric = self.metric(prediction, labels)
        return loss, metric


class UniMatchv2Loss(nn.Module):
    """Loss function for `UniMatchv2Trainer`.

    Extends `SelfTrainingLossWithInvertibleAugmentations` to support the two-student-view scheme
    of UniMatch v2. When `pred_dim=2`, `prediction` is expected to be a stacked tensor of two
    student predictions `[pred_s1_inv, pred_s2_inv]`, and the loss is averaged over both views.
    When `pred_dim=1`, it falls back to the standard single-prediction behaviour.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the predictions before the loss.
    """
    def __init__(self, loss: nn.Module = DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ) -> torch.Tensor:
        """Compute the UniMatch v2 self-training loss.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            The loss value.
        """
        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if self.activation is not None:
            prediction = self.activation(prediction)

        if pred_dim == 2:
            # Consistency with `UniMatchv2LossAndMetric`: guard against malformed inputs.
            assert len(prediction) == 2, "only implemented for list of len 2"
            if label_filter is None:
                loss = (self.loss(prediction[0], labels) + self.loss(prediction[1], labels)) / 2
            else:
                loss = (self.loss(
                    prediction[0] * label_filter, labels * label_filter
                ) + self.loss(prediction[1] * label_filter, labels * label_filter)) / 2
            return loss

        else:
            if label_filter is None:
                loss = self.loss(prediction, labels)
            else:
                loss = self.loss(prediction * label_filter, labels * label_filter)
            return loss


class UniMatchv2LossAndMetric(nn.Module):
    """Loss and metric function for `UniMatchv2Trainer`.

    Extends `SelfTrainingLossAndMetricWithInvertibleAugmentations` to support the two-student-view scheme
    of UniMatch v2. `pred_dim` depends on how many views the student model processes
    at the same time. Supports the same dual-view `pred_dim=2` convention: when two student
    predictions are stacked, loss and metric are each averaged over both views.
    When `pred_dim=1`, it falls back to the standard single-prediction behaviour.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the predictions before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = DiceLoss(),
        metric: nn.Module = DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ):
        """Compute the UniMatch v2 self-training loss and metric.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            The loss and metric value.
        """
        if self.activation is not None:
            prediction = self.activation(prediction)

        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if pred_dim == 2:
            assert len(prediction) == 2, "only implemented for list of len 2"
            if label_filter is None:
                loss = (self.loss(prediction[0], labels) + self.loss(prediction[1], labels)) / 2
            else:
                loss = (self.loss(
                    prediction[0] * label_filter, labels * label_filter
                ) + self.loss(prediction[1] * label_filter, labels * label_filter)) / 2
            metric = (self.metric(prediction[0], labels) + self.metric(prediction[1], labels)) / 2
            return loss, metric

        else:
            if label_filter is None:
                loss = self.loss(prediction, labels)
            else:
                loss = self.loss(prediction * label_filter, labels * label_filter)
            metric = self.metric(prediction, labels)
            return loss, metric
class DefaultSelfTrainingLoss(nn.Module):
    """Self-training loss that runs the student model itself.

    The call receives a model together with its input batch, (pseudo) labels and,
    optionally, a mask for the labels. It predicts with the model and scores the
    output against the labels via the wrapped loss function. The labels are usually
    produced by a teacher model, while the model handed in is the student.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.loss = loss
        self.activation = activation
        # TODO store the class names and kwargs here so the loss can be serialized
        self.init_kwargs = {}

    def __call__(
        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the loss for self-training.

        Args:
            model: The model.
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: A mask to exclude from the loss computation.

        Returns:
            The loss value.
        """
        pred = model(input_)
        if self.activation is not None:
            pred = self.activation(pred)
        if label_filter is not None:
            return self.loss(pred * label_filter, labels * label_filter)
        return self.loss(pred, labels)
Loss function for self training.
This loss takes as input a model and its input, as well as (pseudo) labels and potentially a mask for the labels. It then runs prediction with the model and compares the outputs to the (pseudo) labels using an internal loss function. Typically, the labels are derived from the predictions of a teacher model, and the model passed is the student model.
Arguments:
- loss: The internal loss function to use for comparing predictions of the teacher and student model.
- activation: The activation function to be applied to the prediction before passing it to the loss.
def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
    """Store the internal loss function and the optional activation on the module."""
    super().__init__()
    self.loss = loss
    self.activation = activation
    # TODO store the class names and kwargs here so the loss can be serialized
    self.init_kwargs = {}
Create the self-training loss with the given internal loss function and optional activation.
class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Joint loss and metric computation for self training.

    Behaves like `DefaultSelfTrainingLoss`, but evaluates the loss and the metric
    from a single forward pass, so the model does not have to predict twice.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        metric: The internal metric function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.loss = loss
        self.metric = metric
        self.activation = activation
        # TODO store the class names and dicts here so the loss can be serialized
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        pred = model(input_)
        if self.activation is not None:
            pred = self.activation(pred)
        if label_filter is not None:
            loss_value = self.loss(pred * label_filter, labels * label_filter)
        else:
            loss_value = self.loss(pred, labels)
        # The metric is always evaluated without the label filter.
        return loss_value, self.metric(pred, labels)
Loss and metric function for self training.
Similar to DefaultSelfTrainingLoss, but computes loss and metric value in one call
to avoid running prediction with the model twice.
Arguments:
- loss: The internal loss function to use for comparing predictions of the teacher and student model.
- metric: The internal metric function to use for comparing predictions of the teacher and student model.
- activation: The activation function to be applied to the prediction before passing it to the loss.
def __init__(
    self,
    loss: nn.Module = torch_em.loss.DiceLoss(),
    metric: nn.Module = torch_em.loss.DiceLoss(),
    activation: Optional[nn.Module] = None
):
    """Store the internal loss, metric, and optional activation on the module."""
    super().__init__()
    self.loss = loss
    self.metric = metric
    self.activation = activation
    # TODO store the class names and dicts here so the loss can be serialized
    self.init_kwargs = {}
Create the self-training loss-and-metric with the given internal loss, metric, and optional activation.
class SelfTrainingLossWithInvertibleAugmentations(nn.Module):
    """Self-training loss operating on pre-computed predictions.

    Counterpart of `DefaultSelfTrainingLoss` for `FixMatchTrainerWithInvertibleAugmentations`
    and `MeanTeacherTrainerWithInvertibleAugmentations`. The trainer has already applied the
    model and inverted the augmentations, so this loss is handed predictions directly
    instead of a model and its input.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the prediction before the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.loss = loss
        self.activation = activation
        # TODO store the class names and kwargs here so the loss can be serialized
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the self-training loss.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. Where provided, both prediction
                and labels are multiplied by this tensor before the loss is computed.

        Returns:
            The loss value.
        """
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is not None:
            return self.loss(prediction * label_filter, labels * label_filter)
        return self.loss(prediction, labels)
Loss function for self-training with invertible augmentations.
Variant of DefaultSelfTrainingLoss for use with FixMatchTrainerWithInvertibleAugmentations
and MeanTeacherTrainerWithInvertibleAugmentations. Unlike DefaultSelfTrainingLoss, this loss
receives pre-computed predictions directly rather than a model and input, because the trainer
already applies the model and invertible augmentations before calling the loss.
Arguments:
- loss: The internal loss function used to compare student predictions to pseudo-labels.
- activation: Optional activation applied to the prediction before the loss.
def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
    """Store the internal loss function and the optional activation on the module."""
    super().__init__()
    self.loss = loss
    self.activation = activation
    # TODO store the class names and kwargs here so the loss can be serialized
    self.init_kwargs = {}
Create the invertible-augmentation self-training loss with the given internal loss function and optional activation.
class SelfTrainingLossAndMetricWithInvertibleAugmentations(nn.Module):
    """Joint loss and metric for self-training with invertible augmentations.

    Counterpart of `DefaultSelfTrainingLossAndMetric` for
    `FixMatchTrainerWithInvertibleAugmentations` and `MeanTeacherTrainerWithInvertibleAugmentations`.
    Evaluates loss and metric in one call from predictions the trainer has already computed,
    so no second forward pass is needed. Intended for validation.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the prediction before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.loss = loss
        self.metric = metric
        self.activation = activation
        # TODO store the class names and dicts here so the loss can be serialized
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ):
        """Compute the self-training loss and metric.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. Where provided, both prediction
                and labels are multiplied by this tensor before the loss is computed.

        Returns:
            The loss and metric value.
        """
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is not None:
            loss_value = self.loss(prediction * label_filter, labels * label_filter)
        else:
            loss_value = self.loss(prediction, labels)
        # The metric is always evaluated without the label filter.
        return loss_value, self.metric(prediction, labels)
Loss and metric function for self-training with invertible augmentations.
Variant of DefaultSelfTrainingLossAndMetric for use with
FixMatchTrainerWithInvertibleAugmentations and MeanTeacherTrainerWithInvertibleAugmentations.
Computes both loss and metric in a single call from pre-computed predictions, avoiding a
second forward pass. Used during validation where the trainer already holds the predictions.
Arguments:
- loss: The internal loss function used to compare student predictions to pseudo-labels.
- metric: The internal metric function used to evaluate student predictions against pseudo-labels.
- activation: Optional activation applied to the prediction before the loss and metric.
def __init__(
    self,
    loss: nn.Module = torch_em.loss.DiceLoss(),
    metric: nn.Module = torch_em.loss.DiceLoss(),
    activation: Optional[nn.Module] = None
):
    """Store the internal loss, metric, and optional activation on the module."""
    super().__init__()
    self.loss = loss
    self.metric = metric
    self.activation = activation
    # TODO store the class names and dicts here so the loss can be serialized
    self.init_kwargs = {}
Create the invertible-augmentation loss-and-metric with the given internal loss, metric, and optional activation.
class UniMatchv2Loss(nn.Module):
    """Loss function for `UniMatchv2Trainer`.

    Builds on `SelfTrainingLossWithInvertibleAugmentations` and adds the dual-student-view
    scheme of UniMatch v2: with `pred_dim=2` the input is a stacked pair of student
    predictions `[pred_s1_inv, pred_s2_inv]` and the per-view losses are averaged;
    with `pred_dim=1` the plain single-prediction behaviour applies.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the predictions before the loss.
    """
    def __init__(self, loss: nn.Module = DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.loss = loss
        self.activation = activation
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ) -> torch.Tensor:
        """Compute the UniMatch v2 self-training loss.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            The loss value.
        """
        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if self.activation is not None:
            prediction = self.activation(prediction)

        if pred_dim == 2:
            first, second = prediction[0], prediction[1]
            if label_filter is None:
                return (self.loss(first, labels) + self.loss(second, labels)) / 2
            masked_labels = labels * label_filter
            return (
                self.loss(first * label_filter, masked_labels)
                + self.loss(second * label_filter, masked_labels)
            ) / 2

        if label_filter is None:
            return self.loss(prediction, labels)
        return self.loss(prediction * label_filter, labels * label_filter)
Loss function for UniMatchv2Trainer.
Extends SelfTrainingLossWithInvertibleAugmentations to support the two-student-view scheme
of UniMatch v2. When pred_dim=2, prediction is expected to be a stacked tensor of two
student predictions [pred_s1_inv, pred_s2_inv], and the loss is averaged over both views.
When pred_dim=1, it falls back to the standard single-prediction behaviour.
Arguments:
- loss: The internal loss function used to compare student predictions to pseudo-labels.
- activation: Optional activation applied to the predictions before the loss.
def __init__(self, loss: nn.Module = DiceLoss(), activation: Optional[nn.Module] = None):
    """Store the internal loss function and the optional activation on the module."""
    super().__init__()
    self.loss = loss
    self.activation = activation
    self.init_kwargs = {}
Create the UniMatch v2 loss with the given internal loss function and optional activation.
class UniMatchv2LossAndMetric(nn.Module):
    """Joint loss and metric for `UniMatchv2Trainer`.

    Builds on `SelfTrainingLossAndMetricWithInvertibleAugmentations` and adds the
    dual-student-view scheme of UniMatch v2. `pred_dim` reflects how many views the
    student model processes at once: with `pred_dim=2` the stacked pair of student
    predictions yields per-view losses and metrics that are averaged; with `pred_dim=1`
    the plain single-prediction behaviour applies.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the predictions before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = DiceLoss(),
        metric: nn.Module = DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.loss = loss
        self.metric = metric
        self.activation = activation
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ):
        """Compute the UniMatch v2 self-training loss and metric.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            The loss and metric value.
        """
        if self.activation is not None:
            prediction = self.activation(prediction)

        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if pred_dim == 2:
            assert len(prediction) == 2, "only implemented for list of len 2"
            first, second = prediction[0], prediction[1]
            if label_filter is None:
                loss_value = (self.loss(first, labels) + self.loss(second, labels)) / 2
            else:
                masked_labels = labels * label_filter
                loss_value = (
                    self.loss(first * label_filter, masked_labels)
                    + self.loss(second * label_filter, masked_labels)
                ) / 2
            metric_value = (self.metric(first, labels) + self.metric(second, labels)) / 2
            return loss_value, metric_value

        if label_filter is None:
            loss_value = self.loss(prediction, labels)
        else:
            loss_value = self.loss(prediction * label_filter, labels * label_filter)
        return loss_value, self.metric(prediction, labels)
Loss and metric function for UniMatchv2Trainer.
Extends SelfTrainingLossAndMetricWithInvertibleAugmentations to support the two-student-view scheme
of UniMatch v2. pred_dim depends on how many views the student model processes
at the same time. Supports the same dual-view pred_dim=2 convention: when two student
predictions are stacked, loss and metric are each averaged over both views.
When pred_dim=1, it falls back to the standard single-prediction behaviour.
Arguments:
- loss: The internal loss function used to compare student predictions to pseudo-labels.
- metric: The internal metric function used to evaluate student predictions against pseudo-labels.
- activation: Optional activation applied to the predictions before the loss and metric.
def __init__(
    self,
    loss: nn.Module = DiceLoss(),
    metric: nn.Module = DiceLoss(),
    activation: Optional[nn.Module] = None
):
    """Store the internal loss, metric, and optional activation on the module."""
    super().__init__()
    self.loss = loss
    self.metric = metric
    self.activation = activation
    self.init_kwargs = {}
Create the UniMatch v2 loss-and-metric with the given internal loss, metric, and optional activation.