torch_em.self_training.loss

  1from typing import Optional
  2
  3import torch
  4import torch_em
  5import torch.nn as nn
  6from torch_em.loss import DiceLoss
  7
  8
  9class DefaultSelfTrainingLoss(nn.Module):
 10    """Loss function for self training.
 11
 12    This loss takes as input a model and its input, as well as (pseudo) labels and potentially
 13    a mask for the labels. It then runs prediction with the model and compares the outputs
 14    to the (pseudo) labels using an internal loss function. Typically, the labels are derived
 15    from the predictions of a teacher model, and the model passed is the student model.
 16
 17    Args:
 18        loss: The internal loss function to use for comparing predictions of the teacher and student model.
 19        activation: The activation function to be applied to the prediction before passing it to the loss.
 20    """
 21    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
 22        super().__init__()
 23        self.activation = activation
 24        self.loss = loss
 25        # TODO serialize the class names and kwargs instead
 26        self.init_kwargs = {}
 27
 28    def __call__(
 29        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
 30    ) -> torch.Tensor:
 31        """Compute the loss for self-training.
 32
 33        Args:
 34            model: The model.
 35            input_: The model inputs for this batch.
 36            labels: The (pseudo) labels for this batch.
 37            label_filter: A mask to exclude from the loss computation.
 38
 39        Returns:
 40            The loss value.
 41        """
 42        prediction = model(input_)
 43        if self.activation is not None:
 44            prediction = self.activation(prediction)
 45        if label_filter is None:
 46            loss = self.loss(prediction, labels)
 47        else:
 48            loss = self.loss(prediction * label_filter, labels * label_filter)
 49        return loss
 50
 51
 52class DefaultSelfTrainingLossAndMetric(nn.Module):
 53    """Loss and metric function for self training.
 54
 55    Similar to `DefaultSelfTrainingLoss`, but computes loss and metric value in one call
 56    to avoid running prediction with the model twice.
 57
 58    Args:
 59        loss: The internal loss function to use for comparing predictions of the teacher and student model.
 60        metric: The internal metric function to use for comparing predictions of the teacher and student model.
 61        activation: The activation function to be applied to the prediction before passing it to the loss.
 62    """
 63    def __init__(
 64        self,
 65        loss: nn.Module = torch_em.loss.DiceLoss(),
 66        metric: nn.Module = torch_em.loss.DiceLoss(),
 67        activation: Optional[nn.Module] = None
 68    ):
 69        super().__init__()
 70        self.activation = activation
 71        self.loss = loss
 72        self.metric = metric
 73        # TODO serialize the class names and dicts instead
 74        self.init_kwargs = {}
 75
 76    def __call__(self, model, input_, labels, label_filter=None):
 77        prediction = model(input_)
 78        if self.activation is not None:
 79            prediction = self.activation(prediction)
 80        if label_filter is None:
 81            loss = self.loss(prediction, labels)
 82        else:
 83            loss = self.loss(prediction * label_filter, labels * label_filter)
 84        metric = self.metric(prediction, labels)
 85        return loss, metric
 86
 87
 88# TODO: The probabilistic U-Net related code should be refactored to `torch_em.loss`
 89# and should be documented properly.
 90
 91
 92def l2_regularisation(m):
 93    """@private
 94    """
 95    l2_reg = None
 96    for W in m.parameters():
 97        if l2_reg is None:
 98            l2_reg = W.norm(2)
 99        else:
100            l2_reg = l2_reg + W.norm(2)
101    return l2_reg
102
103
class ProbabilisticUNetLoss(nn.Module):
    """@private
    """
    # """Loss function for Probabilistic UNet

    # Args:
    #     # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. Only `None` (the default ELBO objective
    #         with L2 regularisation) is currently supported. (default: None)
    # """
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels, label_filter=None):
        # Only the default ELBO objective is implemented. Fail early with an explicit error
        # for custom losses instead of reaching the return with `loss` unbound.
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLoss only supports the default ELBO objective (loss=None); "
                "custom loss functions are not implemented."
            )

        # Forward pass with labels; `model.elbo` is called afterwards and presumably relies on
        # state set up here (TODO confirm against the ProbabilisticUNet implementation).
        model.forward(input_, labels)

        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        return loss
127
128
class ProbabilisticUNetLossAndMetric(nn.Module):
    """@private
    """
    # """Loss and metric function for Probabilistic UNet.

    # Args:
    #     # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. Only `None` (the default ELBO objective
    #         with L2 regularisation) is currently supported. (default: None)

    #     metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
    #     activation [nn.Module, callable] - the activation function to be applied to the prediction
    #         before evaluating the average predictions. (default: None)
    #     prior_samples [int] - number of samples drawn from the model whose (activated) average
    #         is compared to the labels by the metric. Must be at least 1. (default: 16)
    # """
    def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        # An empty sample list would otherwise only fail later inside torch.stack.
        if prior_samples < 1:
            raise ValueError(f"prior_samples must be at least 1, got {prior_samples}")
        self.activation = activation
        self.metric = metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels, label_filter=None):
        # Only the default ELBO objective is implemented. Fail early with an explicit error
        # for custom losses instead of reaching the return with `loss` unbound.
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLossAndMetric only supports the default ELBO objective (loss=None); "
                "custom loss functions are not implemented."
            )

        model.forward(input_, labels)

        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss

        # Draw several samples, activate them and compare their average to the labels.
        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        avg_samples = torch.stack(samples_per_distribution, dim=0).mean(dim=0)
        metric = self.metric(avg_samples, labels)

        return loss, metric
169
170
class SelfTrainingLossWithInvertibleAugmentations(nn.Module):
    """Loss function for self-training with invertible augmentations.

    Variant of `DefaultSelfTrainingLoss` for use with `FixMatchTrainerWithInvertibleAugmentations`
    and `MeanTeacherTrainerWithInvertibleAugmentations`. Unlike `DefaultSelfTrainingLoss`, this loss
    receives pre-computed predictions directly rather than a model and input, because the trainer
    already applies the model and invertible augmentations before calling the loss.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the prediction before the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compare already-computed student predictions to the (pseudo) labels.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. If given, prediction and labels
                are both multiplied by it before the loss is evaluated.

        Returns:
            The loss value.
        """
        activated = prediction if self.activation is None else self.activation(prediction)
        if label_filter is None:
            return self.loss(activated, labels)
        return self.loss(activated * label_filter, labels * label_filter)
216
217
class SelfTrainingLossAndMetricWithInvertibleAugmentations(nn.Module):
    """Loss and metric function for self-training with invertible augmentations.

    Variant of `DefaultSelfTrainingLossAndMetric` for use with
    `FixMatchTrainerWithInvertibleAugmentations` and `MeanTeacherTrainerWithInvertibleAugmentations`.
    Computes both loss and metric in a single call from pre-computed predictions, avoiding a
    second forward pass. Used during validation where the trainer already holds the predictions.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the prediction before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ):
        """Compute the self-training loss and the metric.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. If given, prediction and labels are
                multiplied by it for the loss only; the metric uses the unmasked tensors.

        Returns:
            A tuple `(loss, metric)`.
        """
        activated = prediction if self.activation is None else self.activation(prediction)
        if label_filter is None:
            loss_inputs = (activated, labels)
        else:
            loss_inputs = (activated * label_filter, labels * label_filter)
        loss_value = self.loss(*loss_inputs)
        metric_value = self.metric(activated, labels)
        return loss_value, metric_value
270
271
class UniMatchv2Loss(nn.Module):
    """Loss function for `UniMatchv2Trainer`.

    Variant of `SelfTrainingLossWithInvertibleAugmentations` that supports the two-student-view scheme
    of UniMatch v2. When `pred_dim=2`, `prediction` is expected to be a stacked tensor of two
    student predictions `[pred_s1_inv, pred_s2_inv]`, and the loss is averaged over both views.
    When `pred_dim=1`, it falls back to the standard single-prediction behaviour.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the predictions before the loss.
    """
    def __init__(self, loss: nn.Module = DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ) -> torch.Tensor:
        """Compute the UniMatch v2 self-training loss.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            The loss value (averaged over both views when `pred_dim=2`).
        """
        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if self.activation is not None:
            prediction = self.activation(prediction)

        if pred_dim == 1:
            return self._masked_loss(prediction, labels, label_filter)

        # Same check as in `UniMatchv2LossAndMetric`: exactly two views are expected; any
        # additional entries along the first axis would otherwise be silently ignored.
        assert len(prediction) == 2, "only implemented for list of len 2"
        return (
            self._masked_loss(prediction[0], labels, label_filter)
            + self._masked_loss(prediction[1], labels, label_filter)
        ) / 2

    def _masked_loss(self, prediction, labels, label_filter):
        """Apply the internal loss to one prediction, masking prediction and labels if a filter is given."""
        if label_filter is None:
            return self.loss(prediction, labels)
        return self.loss(prediction * label_filter, labels * label_filter)
333
334
class UniMatchv2LossAndMetric(nn.Module):
    """Loss and metric function for `UniMatchv2Trainer`.

    Variant of `SelfTrainingLossAndMetricWithInvertibleAugmentations` that supports the two-student-view
    scheme of UniMatch v2. `pred_dim` depends on how many views the student model processes
    at the same time. Supports the same dual-view `pred_dim=2` convention: when two student
    predictions are stacked, loss and metric are each averaged over both views.
    When `pred_dim=1`, it falls back to the standard single-prediction behaviour.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the predictions before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = DiceLoss(),
        metric: nn.Module = DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ):
        """Compute the UniMatch v2 self-training loss and metric.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed. The metric uses the unmasked tensors.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            A tuple `(loss, metric)`; both are averaged over the two views when `pred_dim=2`.
        """
        # Validate before doing any work, consistent with `UniMatchv2Loss`.
        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if self.activation is not None:
            prediction = self.activation(prediction)

        if pred_dim == 1:
            loss = self._masked_loss(prediction, labels, label_filter)
            metric = self.metric(prediction, labels)
            return loss, metric

        assert len(prediction) == 2, "only implemented for list of len 2"
        loss = (
            self._masked_loss(prediction[0], labels, label_filter)
            + self._masked_loss(prediction[1], labels, label_filter)
        ) / 2
        metric = (self.metric(prediction[0], labels) + self.metric(prediction[1], labels)) / 2
        return loss, metric

    def _masked_loss(self, prediction, labels, label_filter):
        """Apply the internal loss to one prediction, masking prediction and labels if a filter is given."""
        if label_filter is None:
            return self.loss(prediction, labels)
        return self.loss(prediction * label_filter, labels * label_filter)
class DefaultSelfTrainingLoss(torch.nn.modules.module.Module):
class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self training.

    This loss takes as input a model and its input, as well as (pseudo) labels and potentially
    a mask for the labels. It then runs prediction with the model and compares the outputs
    to the (pseudo) labels using an internal loss function. Typically, the labels are derived
    from the predictions of a teacher model, and the model passed is the student model.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the model on the input and compare its output to the (pseudo) labels.

        Args:
            model: The model to run (typically the student).
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: Optional mask. If given, both the prediction and the labels are
                multiplied by it before the loss is computed, so masked-out (zero) entries
                are zero in both tensors.

        Returns:
            The loss value.
        """
        output = model(input_)
        if self.activation is not None:
            output = self.activation(output)
        target = labels
        if label_filter is not None:
            output, target = output * label_filter, labels * label_filter
        return self.loss(output, target)

Loss function for self training.

This loss takes as input a model and its input, as well as (pseudo) labels and potentially a mask for the labels. It then runs prediction with the model and compares the outputs to the (pseudo) labels using an internal loss function. Typically, the labels are derived from the predictions of a teacher model, and the model passed is the student model.

Arguments:
  • loss: The internal loss function to use for comparing predictions of the teacher and student model.
  • activation: The activation function to be applied to the prediction before passing it to the loss.
DefaultSelfTrainingLoss( loss: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
    """Store the internal loss and the optional activation.

    Args:
        loss: Internal loss comparing predictions to (pseudo) labels.
        activation: Optional activation applied to the prediction before the loss.
    """
    super().__init__()
    self.activation = activation
    self.loss = loss
    # TODO serialize the class names and kwargs instead
    self.init_kwargs = dict()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

activation
loss
init_kwargs
class DefaultSelfTrainingLossAndMetric(torch.nn.modules.module.Module):
class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self training.

    Similar to `DefaultSelfTrainingLoss`, but computes loss and metric value in one call
    to avoid running prediction with the model twice.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        metric: The internal metric function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it
            to the loss and the metric.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        """Run the model once and compute both the loss and the metric.

        Args:
            model: The model to run (typically the student).
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: Optional mask multiplied with prediction and labels for the loss only.

        Returns:
            A tuple `(loss, metric)`.
        """
        output = model(input_)
        if self.activation is not None:
            output = self.activation(output)

        if label_filter is None:
            loss_inputs = (output, labels)
        else:
            loss_inputs = (output * label_filter, labels * label_filter)
        loss_value = self.loss(*loss_inputs)

        # The metric is computed on the unmasked prediction and labels (label_filter is not applied).
        metric_value = self.metric(output, labels)
        return loss_value, metric_value

Loss and metric function for self training.

Similar to DefaultSelfTrainingLoss, but computes loss and metric value in one call to avoid running prediction with the model twice.

Arguments:
  • loss: The internal loss function to use for comparing predictions of the teacher and student model.
  • metric: The internal metric function to use for comparing predictions of the teacher and student model.
  • activation: The activation function to be applied to the prediction before passing it to the loss.
DefaultSelfTrainingLossAndMetric( loss: torch.nn.modules.module.Module = DiceLoss(), metric: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
def __init__(
    self,
    loss: nn.Module = torch_em.loss.DiceLoss(),
    metric: nn.Module = torch_em.loss.DiceLoss(),
    activation: Optional[nn.Module] = None
):
    """Store the internal loss, the metric and the optional activation.

    Args:
        loss: Internal loss comparing predictions to (pseudo) labels.
        metric: Internal metric comparing predictions to (pseudo) labels.
        activation: Optional activation applied to the prediction before loss and metric.
    """
    super().__init__()
    self.activation = activation
    self.loss = loss
    self.metric = metric
    # TODO serialize the class names and dicts instead
    self.init_kwargs = dict()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

activation
loss
metric
init_kwargs
class SelfTrainingLossWithInvertibleAugmentations(torch.nn.modules.module.Module):
class SelfTrainingLossWithInvertibleAugmentations(nn.Module):
    """Loss function for self-training with invertible augmentations.

    Variant of `DefaultSelfTrainingLoss` for use with `FixMatchTrainerWithInvertibleAugmentations`
    and `MeanTeacherTrainerWithInvertibleAugmentations`. Unlike `DefaultSelfTrainingLoss`, this loss
    receives pre-computed predictions directly rather than a model and input, because the trainer
    already applies the model and invertible augmentations before calling the loss.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the prediction before the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compare already-computed student predictions to the (pseudo) labels.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. If given, prediction and labels
                are both multiplied by it before the loss is evaluated.

        Returns:
            The loss value.
        """
        activated = prediction if self.activation is None else self.activation(prediction)
        if label_filter is None:
            return self.loss(activated, labels)
        return self.loss(activated * label_filter, labels * label_filter)

Loss function for self-training with invertible augmentations.

Variant of DefaultSelfTrainingLoss for use with FixMatchTrainerWithInvertibleAugmentations and MeanTeacherTrainerWithInvertibleAugmentations. Unlike DefaultSelfTrainingLoss, this loss receives pre-computed predictions directly rather than a model and input, because the trainer already applies the model and invertible augmentations before calling the loss.

Arguments:
  • loss: The internal loss function used to compare student predictions to pseudo-labels.
  • activation: Optional activation applied to the prediction before the loss.
SelfTrainingLossWithInvertibleAugmentations( loss: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
    """Store the internal loss and the optional activation.

    Args:
        loss: Internal loss comparing student predictions to pseudo-labels.
        activation: Optional activation applied to the prediction before the loss.
    """
    super().__init__()
    self.activation = activation
    self.loss = loss
    # TODO serialize the class names and kwargs instead
    self.init_kwargs = dict()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

activation
loss
init_kwargs
class SelfTrainingLossAndMetricWithInvertibleAugmentations(torch.nn.modules.module.Module):
class SelfTrainingLossAndMetricWithInvertibleAugmentations(nn.Module):
    """Loss and metric function for self-training with invertible augmentations.

    Variant of `DefaultSelfTrainingLossAndMetric` for use with
    `FixMatchTrainerWithInvertibleAugmentations` and `MeanTeacherTrainerWithInvertibleAugmentations`.
    Computes both loss and metric in a single call from pre-computed predictions, avoiding a
    second forward pass. Used during validation where the trainer already holds the predictions.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the prediction before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
    ):
        """Compute the self-training loss and the metric.

        Args:
            prediction: Student model predictions, already mapped to the reference frame
                via the inverse augmentation transform.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor. If given, prediction and labels are
                multiplied by it for the loss only; the metric uses the unmasked tensors.

        Returns:
            A tuple `(loss, metric)`.
        """
        activated = prediction if self.activation is None else self.activation(prediction)
        if label_filter is None:
            loss_inputs = (activated, labels)
        else:
            loss_inputs = (activated * label_filter, labels * label_filter)
        loss_value = self.loss(*loss_inputs)
        metric_value = self.metric(activated, labels)
        return loss_value, metric_value

Loss and metric function for self-training with invertible augmentations.

Variant of DefaultSelfTrainingLossAndMetric for use with FixMatchTrainerWithInvertibleAugmentations and MeanTeacherTrainerWithInvertibleAugmentations. Computes both loss and metric in a single call from pre-computed predictions, avoiding a second forward pass. Used during validation where the trainer already holds the predictions.

Arguments:
  • loss: The internal loss function used to compare student predictions to pseudo-labels.
  • metric: The internal metric function used to evaluate student predictions against pseudo-labels.
  • activation: Optional activation applied to the prediction before the loss and metric.
SelfTrainingLossAndMetricWithInvertibleAugmentations( loss: torch.nn.modules.module.Module = DiceLoss(), metric: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
def __init__(
    self,
    loss: nn.Module = torch_em.loss.DiceLoss(),
    metric: nn.Module = torch_em.loss.DiceLoss(),
    activation: Optional[nn.Module] = None
):
    """Store the internal loss, the metric and the optional activation.

    Args:
        loss: Internal loss comparing student predictions to pseudo-labels.
        metric: Internal metric evaluating student predictions against pseudo-labels.
        activation: Optional activation applied to the prediction before loss and metric.
    """
    super().__init__()
    self.activation = activation
    self.loss = loss
    self.metric = metric
    # TODO serialize the class names and dicts instead
    self.init_kwargs = dict()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

activation
loss
metric
init_kwargs
class UniMatchv2Loss(torch.nn.modules.module.Module):
class UniMatchv2Loss(nn.Module):
    """Loss function for `UniMatchv2Trainer`.

    Counterpart of `SelfTrainingLossWithInvertibleAugmentations` (implemented as a standalone
    `nn.Module`, not as a subclass) that supports the two-student-view scheme of UniMatch v2.
    When `pred_dim=2`, `prediction` is expected to be a stacked tensor of two student predictions
    `[pred_s1_inv, pred_s2_inv]`, and the loss is averaged over both views.
    When `pred_dim=1`, it falls back to the standard single-prediction behaviour.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        activation: Optional activation applied to the predictions before the loss.
    """
    def __init__(self, loss: nn.Module = DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ) -> torch.Tensor:
        """Compute the UniMatch v2 self-training loss.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            The loss value. For `pred_dim=2` it is the mean of the per-view losses.
        """
        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if self.activation is not None:
            prediction = self.activation(prediction)

        if pred_dim == 1:
            return self._masked_loss(prediction, labels, label_filter)

        # Dual-view: both strong-view predictions are compared against the same pseudo-labels.
        # Guard against silently ignoring extra stacked views (mirrors UniMatchv2LossAndMetric).
        assert len(prediction) == 2, "only implemented for list of len 2"
        return (
            self._masked_loss(prediction[0], labels, label_filter)
            + self._masked_loss(prediction[1], labels, label_filter)
        ) / 2

    def _masked_loss(
        self, prediction: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Apply the internal loss, multiplying prediction and labels by `label_filter` if given."""
        if label_filter is None:
            return self.loss(prediction, labels)
        return self.loss(prediction * label_filter, labels * label_filter)

Loss function for UniMatchv2Trainer.

Counterpart of SelfTrainingLossWithInvertibleAugmentations (implemented as a standalone nn.Module rather than a subclass) that supports the two-student-view scheme of UniMatch v2. When pred_dim=2, prediction is expected to be a stacked tensor of two student predictions [pred_s1_inv, pred_s2_inv], and the loss is averaged over both views. When pred_dim=1, it falls back to the standard single-prediction behaviour.

Arguments:
  • loss: The internal loss function used to compare student predictions to pseudo-labels.
  • activation: Optional activation applied to the predictions before the loss.
UniMatchv2Loss( loss: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
def __init__(
    self,
    loss: nn.Module = DiceLoss(),
    activation: Optional[nn.Module] = None,
):
    """Store the internal loss and the optional activation.

    Args:
        loss: Module comparing student predictions with the pseudo-labels.
        activation: Optional module applied to the predictions before the loss.
    """
    super().__init__()
    # Assignment order matches submodule registration order (activation, then loss).
    self.activation = activation
    self.loss = loss
    # Empty dict; the constructor arguments are not recorded here.
    self.init_kwargs = {}

Store the loss function and the optional activation on the module, and initialize init_kwargs to an empty dict.

activation
loss
init_kwargs
class UniMatchv2LossAndMetric(torch.nn.modules.module.Module):
class UniMatchv2LossAndMetric(nn.Module):
    """Loss and metric function for `UniMatchv2Trainer`.

    Counterpart of `SelfTrainingLossAndMetricWithInvertibleAugmentations` (implemented as a
    standalone `nn.Module`, not as a subclass) that supports the two-student-view scheme of
    UniMatch v2. `pred_dim` is the number of student views stacked in `prediction`: with
    `pred_dim=2`, loss and metric are each averaged over both views; with `pred_dim=1`, it falls
    back to the standard single-prediction behaviour.

    Args:
        loss: The internal loss function used to compare student predictions to pseudo-labels.
        metric: The internal metric function used to evaluate student predictions against pseudo-labels.
        activation: Optional activation applied to the predictions before the loss and metric.
    """
    def __init__(
        self,
        loss: nn.Module = DiceLoss(),
        metric: nn.Module = DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        self.init_kwargs = {}

    def __call__(
        self,
        prediction: torch.Tensor,
        labels: torch.Tensor,
        label_filter: Optional[torch.Tensor] = None,
        pred_dim: int = 1,
    ):
        """Compute the UniMatch v2 self-training loss and metric.

        Args:
            prediction: Student predictions mapped to the reference frame. When `pred_dim=2`,
                a stacked tensor of shape `(2, B, C, ...)` holding the two strong-view predictions.
                When `pred_dim=1`, a standard `(B, C, ...)` prediction tensor.
            labels: The (pseudo) labels, mapped to the same reference frame.
            label_filter: Optional mask or weight tensor applied to both prediction and labels
                before the loss is computed. The metric is always computed on the unfiltered
                prediction and labels.
            pred_dim: Number of student views. Use `2` for the standard UniMatch v2 dual-view
                training and `1` for single-view inference or validation.

        Returns:
            A `(loss, metric)` tuple. For `pred_dim=2`, each value is the mean over both views.
        """
        # Validate the argument before applying the activation.
        assert pred_dim in (1, 2), "pred_dim must be either 1 or 2"

        if self.activation is not None:
            prediction = self.activation(prediction)

        if pred_dim == 1:
            loss = self._masked_loss(prediction, labels, label_filter)
            metric = self.metric(prediction, labels)
            return loss, metric

        assert len(prediction) == 2, "only implemented for list of len 2"
        loss = (
            self._masked_loss(prediction[0], labels, label_filter)
            + self._masked_loss(prediction[1], labels, label_filter)
        ) / 2
        # The metric deliberately ignores label_filter, as in the single-view branch.
        metric = (self.metric(prediction[0], labels) + self.metric(prediction[1], labels)) / 2
        return loss, metric

    def _masked_loss(
        self, prediction: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Apply the internal loss, multiplying prediction and labels by `label_filter` if given."""
        if label_filter is None:
            return self.loss(prediction, labels)
        return self.loss(prediction * label_filter, labels * label_filter)

Loss and metric function for UniMatchv2Trainer.

Counterpart of SelfTrainingLossAndMetricWithInvertibleAugmentations (implemented as a standalone nn.Module rather than a subclass) that supports the two-student-view scheme of UniMatch v2. pred_dim is the number of student views stacked in the prediction: with pred_dim=2, loss and metric are each averaged over both views; with pred_dim=1, it falls back to the standard single-prediction behaviour.

Arguments:
  • loss: The internal loss function used to compare student predictions to pseudo-labels.
  • metric: The internal metric function used to evaluate student predictions against pseudo-labels.
  • activation: Optional activation applied to the predictions before the loss and metric.
UniMatchv2LossAndMetric( loss: torch.nn.modules.module.Module = DiceLoss(), metric: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
def __init__(
    self,
    loss: nn.Module = DiceLoss(),
    metric: nn.Module = DiceLoss(),
    activation: Optional[nn.Module] = None,
):
    """Store the internal loss, the internal metric and the optional activation.

    Args:
        loss: Module comparing student predictions with the pseudo-labels.
        metric: Module evaluating student predictions against the pseudo-labels.
        activation: Optional module applied to the predictions before loss and metric.
    """
    super().__init__()
    # Assignment order matches submodule registration order (activation, loss, metric).
    self.activation = activation
    self.loss = loss
    self.metric = metric
    # Empty dict; the constructor arguments are not recorded here.
    self.init_kwargs = {}

Store the loss function, the metric function and the optional activation on the module, and initialize init_kwargs to an empty dict.

activation
loss
metric
init_kwargs