torch_em.self_training.loss

  1from typing import Optional
  2
  3import torch
  4import torch_em
  5import torch.nn as nn
  6from torch_em.loss import DiceLoss
  7
  8
  9class DefaultSelfTrainingLoss(nn.Module):
 10    """Loss function for self training.
 11
 12    This loss takes as input a model and its input, as well as (pseudo) labels and potentially
 13    a mask for the labels. It then runs prediction with the model and compares the outputs
 14    to the (pseudo) labels using an internal loss function. Typically, the labels are derived
 15    from the predictions of a teacher model, and the model passed is the student model.
 16
 17    Args:
 18        loss: The internal loss function to use for comparing predictions of the teacher and student model.
 19        activation: The activation function to be applied to the prediction before passing it to the loss.
 20    """
 21    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
 22        super().__init__()
 23        self.activation = activation
 24        self.loss = loss
 25        # TODO serialize the class names and kwargs instead
 26        self.init_kwargs = {}
 27
 28    def __call__(
 29        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
 30    ) -> torch.Tensor:
 31        """Compute the loss for self-training.
 32
 33        Args:
 34            model: The model.
 35            input_: The model inputs for this batch.
 36            labels: The (pseudo) labels for this batch.
 37            label_filter: A mask to exclude from the loss computation.
 38
 39        Returns:
 40            The loss value.
 41        """
 42        prediction = model(input_)
 43        if self.activation is not None:
 44            prediction = self.activation(prediction)
 45        if label_filter is None:
 46            loss = self.loss(prediction, labels)
 47        else:
 48            loss = self.loss(prediction * label_filter, labels * label_filter)
 49        return loss
 50
 51
 52class DefaultSelfTrainingLossAndMetric(nn.Module):
 53    """Loss and metric function for self training.
 54
 55    Similar to `DefaultSelfTrainingLoss`, but computes loss and metric value in one call
 56    to avoid running prediction with the model twice.
 57
 58    Args:
 59        loss: The internal loss function to use for comparing predictions of the teacher and student model.
 60        metric: The internal metric function to use for comparing predictions of the teacher and student model.
 61        activation: The activation function to be applied to the prediction before passing it to the loss.
 62    """
 63    def __init__(
 64        self,
 65        loss: nn.Module = torch_em.loss.DiceLoss(),
 66        metric: nn.Module = torch_em.loss.DiceLoss(),
 67        activation: Optional[nn.Module] = None
 68    ):
 69        super().__init__()
 70        self.activation = activation
 71        self.loss = loss
 72        self.metric = metric
 73        # TODO serialize the class names and dicts instead
 74        self.init_kwargs = {}
 75
 76    def __call__(self, model, input_, labels, label_filter=None):
 77        prediction = model(input_)
 78        if self.activation is not None:
 79            prediction = self.activation(prediction)
 80        if label_filter is None:
 81            loss = self.loss(prediction, labels)
 82        else:
 83            loss = self.loss(prediction * label_filter, labels * label_filter)
 84        metric = self.metric(prediction, labels)
 85        return loss, metric
 86
 87
 88# TODO: The probabilistic U-Net related code should be refactored to `torch_em.loss`
 89# and should be documented properly.
 90
 91
 92def l2_regularisation(m):
 93    """@private
 94    """
 95    l2_reg = None
 96    for W in m.parameters():
 97        if l2_reg is None:
 98            l2_reg = W.norm(2)
 99        else:
100            l2_reg = l2_reg + W.norm(2)
101    return l2_reg
102
103
class ProbabilisticUNetLoss(nn.Module):
    """@private
    """
    # Loss function for Probabilistic UNet.
    #
    # Args:
    #     # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. (default: None)
    #         Only loss=None (negative ELBO + 1e-5 * L2 regularisation of posterior, prior and
    #         fcomb.layers) is implemented; any other value raises NotImplementedError when called.
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels, label_filter=None):
        # Only the default objective exists. A custom loss previously fell through without
        # assigning `loss` and crashed with an UnboundLocalError after the forward pass.
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLoss currently only supports the default ELBO loss (loss=None)."
            )

        # NOTE(review): the forward pass with labels presumably sets up the state read by
        # `model.elbo` below - confirm against the model implementation.
        model.forward(input_, labels)

        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        return loss
127
128
class ProbabilisticUNetLossAndMetric(nn.Module):
    """@private
    """
    # Loss and metric function for Probabilistic UNet.
    #
    # Args:
    #     # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. (default: None)
    #         Only loss=None (negative ELBO + 1e-5 * L2 regularisation) is implemented;
    #         any other value raises NotImplementedError when called.
    #     metric [nn.Module] - the metric function to be used. (default: None -> a new torch_em.loss.DiceLoss)
    #     activation [nn.Module, callable] - the activation function applied to each sample
    #         before averaging the samples. (default: torch.nn.Sigmoid())
    #     prior_samples [int] - number of samples averaged for the metric; must be >= 1. (default: 16)
    def __init__(self, loss=None, metric=None, activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        # With zero samples, torch.stack([]) would fail later during the call.
        if prior_samples < 1:
            raise ValueError(f"prior_samples must be at least 1, got {prior_samples}.")
        self.activation = activation
        # Build the default metric per instance instead of sharing one module created at class definition.
        self.metric = DiceLoss() if metric is None else metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels, label_filter=None):
        # Only the default objective exists. A custom loss previously fell through without
        # assigning `loss` and crashed with an UnboundLocalError at the return statement.
        if self.loss is not None:
            raise NotImplementedError(
                "ProbabilisticUNetLossAndMetric currently only supports the default ELBO loss (loss=None)."
            )

        # NOTE(review): the forward pass with labels presumably sets up the state read by
        # `model.elbo` / `model.sample` below - confirm against the model implementation.
        model.forward(input_, labels)

        elbo = model.elbo(labels, label_filter)
        reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
            l2_regularisation(model.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss

        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        # The metric compares the average over all drawn (activated) samples with the labels.
        avg_samples = torch.stack(samples_per_distribution, dim=0).mean(dim=0)
        metric = self.metric(avg_samples, labels)

        return loss, metric
class DefaultSelfTrainingLoss(torch.nn.modules.module.Module):
class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self training.

    This loss takes as input a model and its input, as well as (pseudo) labels and potentially
    a mask for the labels. It then runs prediction with the model and compares the outputs
    to the (pseudo) labels using an internal loss function. Typically, the labels are derived
    from the predictions of a teacher model, and the model passed is the student model.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
            If None (the default), a new `DiceLoss` is created for this instance.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(self, loss: Optional[nn.Module] = None, activation: Optional[nn.Module] = None):
        """Initialize the self-training loss.

        See the class docstring for a description of the arguments.
        """
        super().__init__()
        self.activation = activation
        # Build the default loss per instance; a `DiceLoss()` default argument would be created
        # once at class definition and shared by every instance of this class.
        self.loss = DiceLoss() if loss is None else loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the loss for self-training.

        Args:
            model: The model.
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: Optional mask. If given, both the prediction and the labels are multiplied
                by it before computing the loss, which zeroes out the masked-out positions.

        Returns:
            The loss value.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is not None:
            prediction = prediction * label_filter
            labels = labels * label_filter
        return self.loss(prediction, labels)

Loss function for self training.

This loss takes as input a model and its input, as well as (pseudo) labels and potentially a mask for the labels. It then runs prediction with the model and compares the outputs to the (pseudo) labels using an internal loss function. Typically, the labels are derived from the predictions of a teacher model, and the model passed is the student model.

Arguments:
  • loss: The internal loss function to use for comparing predictions of the teacher and student model.
  • activation: The activation function to be applied to the prediction before passing it to the loss.
DefaultSelfTrainingLoss( loss: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
22    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
23        super().__init__()
24        self.activation = activation
25        self.loss = loss
26        # TODO serialize the class names and kwargs instead
27        self.init_kwargs = {}

Initialize the loss module with the given internal loss function and optional activation (see the arguments above).

activation
loss
init_kwargs
class DefaultSelfTrainingLossAndMetric(torch.nn.modules.module.Module):
class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self training.

    Similar to `DefaultSelfTrainingLoss`, but computes loss and metric value in one call
    to avoid running prediction with the model twice.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
            If None (the default), a new `DiceLoss` is created for this instance.
        metric: The internal metric function to use for comparing predictions of the teacher and student model.
            If None (the default), a new `DiceLoss` is created for this instance.
        activation: The activation function to be applied to the prediction before passing it
            to the loss and the metric.
    """
    def __init__(
        self,
        loss: Optional[nn.Module] = None,
        metric: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None
    ):
        """Initialize the self-training loss and metric.

        See the class docstring for a description of the arguments.
        """
        super().__init__()
        self.activation = activation
        # Build the defaults per instance; `DiceLoss()` default arguments would be created
        # once at class definition and shared by every instance of this class.
        self.loss = DiceLoss() if loss is None else loss
        self.metric = DiceLoss() if metric is None else metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(
        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
    ):
        """Compute loss and metric for self-training with a single model prediction.

        Args:
            model: The model.
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: Optional mask. If given, both the prediction and the labels are multiplied
                by it before computing the loss. It is not applied for the metric.

        Returns:
            A tuple `(loss, metric)`.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        # NOTE(review): the metric uses the unmasked prediction and labels, i.e. label_filter
        # only affects the loss. Confirm this is intended.
        metric = self.metric(prediction, labels)
        return loss, metric

Loss and metric function for self training.

Similar to DefaultSelfTrainingLoss, but computes loss and metric value in one call to avoid running prediction with the model twice.

Arguments:
  • loss: The internal loss function to use for comparing predictions of the teacher and student model.
  • metric: The internal metric function to use for comparing predictions of the teacher and student model.
  • activation: The activation function to be applied to the prediction before passing it to the loss and the metric.
DefaultSelfTrainingLossAndMetric( loss: torch.nn.modules.module.Module = DiceLoss(), metric: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
64    def __init__(
65        self,
66        loss: nn.Module = torch_em.loss.DiceLoss(),
67        metric: nn.Module = torch_em.loss.DiceLoss(),
68        activation: Optional[nn.Module] = None
69    ):
70        super().__init__()
71        self.activation = activation
72        self.loss = loss
73        self.metric = metric
74        # TODO serialize the class names and dicts instead
75        self.init_kwargs = {}

Initialize the module with the given internal loss function, metric function and optional activation (see the arguments above).

activation
loss
metric
init_kwargs