torch_em.self_training.loss
```python
from typing import Optional

import torch
import torch_em
import torch.nn as nn
from torch_em.loss import DiceLoss


class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self-training.

    This loss takes as input a model and its input, as well as (pseudo) labels and potentially
    a mask for the labels. It then runs prediction with the model and compares the outputs
    to the (pseudo) labels using an internal loss function. Typically, the labels are derived
    from the predictions of a teacher model, and the model passed is the student model.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the loss for self-training.

        Args:
            model: The model.
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: A binary mask; regions where it is zero are excluded from the loss computation.

        Returns:
            The loss value.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        return loss


class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self-training.

    Similar to `DefaultSelfTrainingLoss`, but computes the loss and metric value in one call
    to avoid running prediction with the model twice.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        metric: The internal metric function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        # Note: the metric is computed on the unfiltered prediction.
        metric = self.metric(prediction, labels)
        return loss, metric


# TODO: The probabilistic U-Net related code should be refactored to `torch_em.loss`
# and should be documented properly.


def l2_regularisation(m):
    """@private
    """
    # Sum the L2 norms of all parameters of the module.
    l2_reg = None
    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg


class ProbabilisticUNetLoss(nn.Module):
    """@private
    """
    # """Loss function for the Probabilistic U-Net.

    # Args:
    #     # TODO: Implement a generic utility function for all Probabilistic U-Net schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. (default: None)
    # """
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels, label_filter=None):
        # Run the forward pass, which updates the model's internal prior and posterior.
        model.forward(input_, labels)

        if self.loss is None:
            # Default scheme: maximize the ELBO, with a small L2 regularization on
            # the prior, the posterior and the combination layers.
            elbo = model.elbo(labels, label_filter)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
        else:
            # Custom loss schemes are not implemented yet (see the TODO above).
            raise NotImplementedError("Only the default ELBO loss is currently supported.")

        return loss


class ProbabilisticUNetLossAndMetric(nn.Module):
    """@private
    """
    # """Loss and metric function for the Probabilistic U-Net.

    # Args:
    #     # TODO: Implement a generic utility function for all Probabilistic U-Net schemes (ELBO, GECO, etc.)
    #     loss [nn.Module] - the loss function to be used. (default: None)
    #     metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
    #     activation [nn.Module, callable] - the activation function to be applied to the prediction
    #         before evaluating the average predictions. (default: None)
    # """
    def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        self.activation = activation
        self.metric = metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels, label_filter=None):
        # Run the forward pass, which updates the model's internal prior and posterior.
        model.forward(input_, labels)

        if self.loss is None:
            elbo = model.elbo(labels, label_filter)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
        else:
            # Custom loss schemes are not implemented yet (see the TODO above).
            raise NotImplementedError("Only the default ELBO loss is currently supported.")

        # Evaluate the metric on the average over multiple samples drawn from the prior.
        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        avg_samples = torch.stack(samples_per_distribution, dim=0).mean(dim=0)
        metric = self.metric(avg_samples, labels)

        return loss, metric
```
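The masking logic shared by both default loss classes is easy to verify in isolation. Below is a minimal sketch; the random tensors and their shapes are made up for illustration, but the masked call mirrors the `label_filter` branch of `DefaultSelfTrainingLoss.__call__`: regions where the filter is zero are zeroed out in both the prediction and the labels before they reach the internal loss.

```python
import torch
from torch_em.loss import DiceLoss

# Made-up tensors for illustration; shapes follow the usual (batch, channel, y, x) layout.
prediction = torch.rand(1, 1, 8, 8)                     # e.g. sigmoid outputs of a student model
pseudo_labels = (torch.rand(1, 1, 8, 8) > 0.5).float()  # binarized teacher predictions
label_filter = (torch.rand(1, 1, 8, 8) > 0.25).float()  # 1 = keep the pixel, 0 = exclude it

loss_fn = DiceLoss()
# The masked variant only compares predictions and labels where the filter is one.
masked_loss = loss_fn(prediction * label_filter, pseudo_labels * label_filter)
unmasked_loss = loss_fn(prediction, pseudo_labels)
print(masked_loss.item(), unmasked_loss.item())
```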
class DefaultSelfTrainingLoss(torch.nn.modules.module.Module):
```python
class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self-training.

    This loss takes as input a model and its input, as well as (pseudo) labels and potentially
    a mask for the labels. It then runs prediction with the model and compares the outputs
    to the (pseudo) labels using an internal loss function. Typically, the labels are derived
    from the predictions of a teacher model, and the model passed is the student model.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(
        self, model: nn.Module, input_: torch.Tensor, labels: torch.Tensor, label_filter: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the loss for self-training.

        Args:
            model: The model.
            input_: The model inputs for this batch.
            labels: The (pseudo) labels for this batch.
            label_filter: A binary mask; regions where it is zero are excluded from the loss computation.

        Returns:
            The loss value.
        """
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        return loss
```
Loss function for self-training.
This loss takes as input a model and its input, as well as (pseudo) labels and potentially a mask for the labels. It then runs prediction with the model and compares the outputs to the (pseudo) labels using an internal loss function. Typically, the labels are derived from the predictions of a teacher model, and the model passed is the student model.
Arguments:
- loss: The internal loss function to use for comparing predictions of the teacher and student model.
- activation: The activation function to be applied to the prediction before passing it to the loss.
DefaultSelfTrainingLoss(loss: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
```python
def __init__(self, loss: nn.Module = torch_em.loss.DiceLoss(), activation: Optional[nn.Module] = None):
    super().__init__()
    self.activation = activation
    self.loss = loss
    # TODO serialize the class names and kwargs instead
    self.init_kwargs = {}
```
Creates the loss with the given internal loss function and optional activation.
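A minimal usage sketch for a mean-teacher-style setup. The tiny convolutional student and teacher are placeholders; any pair of models with matching input and output shapes (e.g. two U-Nets) works the same way.

```python
import torch
import torch.nn as nn
from torch_em.self_training.loss import DefaultSelfTrainingLoss

# Placeholder student/teacher pair for illustration only.
student = nn.Conv2d(1, 1, kernel_size=3, padding=1)
teacher = nn.Conv2d(1, 1, kernel_size=3, padding=1)

x = torch.rand(4, 1, 64, 64)
with torch.no_grad():
    pseudo_labels = torch.sigmoid(teacher(x))  # teacher predictions serve as pseudo labels

# The activation is applied to the student prediction before the internal DiceLoss.
loss_fn = DefaultSelfTrainingLoss(activation=nn.Sigmoid())
loss = loss_fn(student, x, pseudo_labels)
loss.backward()
```

Passing a `label_filter` as the fourth argument restricts the loss to the unmasked regions, e.g. to pixels where the teacher is confident.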
class DefaultSelfTrainingLossAndMetric(torch.nn.modules.module.Module):
```python
class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self-training.

    Similar to `DefaultSelfTrainingLoss`, but computes the loss and metric value in one call
    to avoid running prediction with the model twice.

    Args:
        loss: The internal loss function to use for comparing predictions of the teacher and student model.
        metric: The internal metric function to use for comparing predictions of the teacher and student model.
        activation: The activation function to be applied to the prediction before passing it to the loss.
    """
    def __init__(
        self,
        loss: nn.Module = torch_em.loss.DiceLoss(),
        metric: nn.Module = torch_em.loss.DiceLoss(),
        activation: Optional[nn.Module] = None
    ):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        # Note: the metric is computed on the unfiltered prediction.
        metric = self.metric(prediction, labels)
        return loss, metric
```
Loss and metric function for self-training.

Similar to `DefaultSelfTrainingLoss`, but computes the loss and metric value in one call to avoid running prediction with the model twice.
Arguments:
- loss: The internal loss function to use for comparing predictions of the teacher and student model.
- metric: The internal metric function to use for comparing predictions of the teacher and student model.
- activation: The activation function to be applied to the prediction before passing it to the loss.
DefaultSelfTrainingLossAndMetric(loss: torch.nn.modules.module.Module = DiceLoss(), metric: torch.nn.modules.module.Module = DiceLoss(), activation: Optional[torch.nn.modules.module.Module] = None)
```python
def __init__(
    self,
    loss: nn.Module = torch_em.loss.DiceLoss(),
    metric: nn.Module = torch_em.loss.DiceLoss(),
    activation: Optional[nn.Module] = None
):
    super().__init__()
    self.activation = activation
    self.loss = loss
    self.metric = metric
    # TODO serialize the class names and dicts instead
    self.init_kwargs = {}
```
Creates the combined loss and metric with the given internal loss function, metric function, and optional activation.
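A minimal sketch of combined use, again with a placeholder model. As the source above shows, the loss respects `label_filter` while the metric is evaluated on the unfiltered prediction.

```python
import torch
import torch.nn as nn
from torch_em.self_training.loss import DefaultSelfTrainingLossAndMetric

student = nn.Conv2d(1, 1, kernel_size=3, padding=1)      # placeholder student model
x = torch.rand(4, 1, 64, 64)
pseudo_labels = (torch.rand(4, 1, 64, 64) > 0.5).float()
label_filter = (torch.rand(4, 1, 64, 64) > 0.1).float()  # e.g. a confidence mask

loss_and_metric = DefaultSelfTrainingLossAndMetric(activation=nn.Sigmoid())
# A single forward pass of the student yields both values.
loss, metric = loss_and_metric(student, x, pseudo_labels, label_filter)
loss.backward()
print(loss.item(), metric.item())
```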