torch_em.self_training.loss

import torch
import torch_em
import torch.nn as nn
from torch_em.loss import DiceLoss


class DefaultSelfTrainingLoss(nn.Module):
    """Loss function for self training.

    Parameters:
        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the prediction
            before passing it to the loss. (default: None)
    """
    def __init__(self, loss=torch_em.loss.DiceLoss(), activation=None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        # TODO serialize the class names and kwargs instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        return loss


class DefaultSelfTrainingLossAndMetric(nn.Module):
    """Loss and metric function for self training.

    Parameters:
        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the prediction
            before passing it to the loss. (default: None)
    """
    def __init__(self, loss=torch_em.loss.DiceLoss(), metric=torch_em.loss.DiceLoss(), activation=None):
        super().__init__()
        self.activation = activation
        self.loss = loss
        self.metric = metric
        # TODO serialize the class names and dicts instead
        self.init_kwargs = {}

    def __call__(self, model, input_, labels, label_filter=None):
        prediction = model(input_)
        if self.activation is not None:
            prediction = self.activation(prediction)
        if label_filter is None:
            loss = self.loss(prediction, labels)
        else:
            loss = self.loss(prediction * label_filter, labels * label_filter)
        metric = self.metric(prediction, labels)
        return loss, metric


def l2_regularisation(m):
    """Compute the sum of the L2 norms of all parameters of the module.

    Returns None if the module has no parameters.
    """
    l2_reg = None
    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg


class ProbabilisticUNetLoss(nn.Module):
    """Loss function for the Probabilistic UNet.

    Parameters:
        # TODO: Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
        loss [nn.Module] - the loss function to be used. If None, the negative ELBO
            with a small L2 regularization term is used. (default: None)
    """
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels, label_filter=None):
        # The forward pass computes and stores the distributions needed for the ELBO.
        model.forward(input_, labels)

        if self.loss is None:
            elbo = model.elbo(labels, label_filter)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
        else:
            # Custom loss schemes (e.g. GECO) are not implemented yet, see the TODO above.
            raise NotImplementedError("Only the default ELBO loss is currently supported.")

        return loss


class ProbabilisticUNetLossAndMetric(nn.Module):
    """Loss and metric function for the Probabilistic UNet.

    Parameters:
        # TODO: Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
        loss [nn.Module] - the loss function to be used. If None, the negative ELBO
            with a small L2 regularization term is used. (default: None)
        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the samples
            before averaging them for the metric. (default: torch.nn.Sigmoid)
        prior_samples [int] - the number of samples to draw from the prior
            for computing the metric. (default: 16)
    """
    def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        self.activation = activation
        self.metric = metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels, label_filter=None):
        # The forward pass computes and stores the distributions needed for the ELBO.
        model.forward(input_, labels)

        if self.loss is None:
            elbo = model.elbo(labels, label_filter)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
        else:
            # Custom loss schemes (e.g. GECO) are not implemented yet, see the TODO above.
            raise NotImplementedError("Only the default ELBO loss is currently supported.")

        # Draw several samples from the prior and average them for the metric.
        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        avg_samples = torch.stack(samples_per_distribution, dim=0).mean(dim=0)
        metric = self.metric(avg_samples, labels)

        return loss, metric

class DefaultSelfTrainingLoss(torch.nn.modules.module.Module):

Loss function for self training.

Arguments:
  • loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
  • activation [nn.Module, callable] - the activation function to be applied to the prediction before passing it to the loss. (default: None)
DefaultSelfTrainingLoss(loss=DiceLoss(), activation=None)
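
Below is a minimal usage sketch for this loss. The toy model and tensor shapes are illustrative, not part of torch_em's API; the key point is that, unlike a plain loss, the callable receives the model and runs the forward pass itself before applying the optional activation.

import torch
import torch.nn as nn
from torch_em.self_training.loss import DefaultSelfTrainingLoss

# Illustrative stand-in for a segmentation network.
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1)
)
loss_fn = DefaultSelfTrainingLoss(activation=nn.Sigmoid())

x = torch.rand(4, 1, 64, 64)              # input batch
pseudo_labels = torch.rand(4, 1, 64, 64)  # e.g. teacher predictions in [0, 1]

loss = loss_fn(model, x, pseudo_labels)   # forward pass happens inside the loss
loss.backward()
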
class DefaultSelfTrainingLossAndMetric(torch.nn.modules.module.Module):

Loss and metric function for self training.

Arguments:
  • loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
  • metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
  • activation [nn.Module, callable] - the activation function to be applied to the prediction before passing it to the loss. (default: None)
DefaultSelfTrainingLossAndMetric(loss=DiceLoss(), metric=DiceLoss(), activation=None)
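
A sketch of the label-filter mechanism follows; the mask construction is illustrative (any 0/1 tensor broadcastable to the labels works). Note that, per the implementation above, the filter masks only the loss, while the metric is computed on the unfiltered prediction.

import torch
import torch.nn as nn
from torch_em.self_training.loss import DefaultSelfTrainingLossAndMetric

model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1)
)
loss_fn = DefaultSelfTrainingLossAndMetric(activation=nn.Sigmoid())

x = torch.rand(4, 1, 64, 64)
pseudo_labels = torch.rand(4, 1, 64, 64)
# Keep only confident pseudo-labels, i.e. values far from 0.5.
label_filter = ((pseudo_labels < 0.1) | (pseudo_labels > 0.9)).float()

loss, metric = loss_fn(model, x, pseudo_labels, label_filter=label_filter)
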
def l2_regularisation(m):
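
The helper returns the sum of the L2 norms of all parameters of a module, i.e. reg(m) = sum over W of ||W||_2, and is used above to regularize the prior, posterior and fcomb networks of the Probabilistic UNet. A quick, illustrative check on a small module:

import torch.nn as nn
from torch_em.self_training.loss import l2_regularisation

layer = nn.Linear(4, 2)
reg = l2_regularisation(layer)  # ||weight||_2 + ||bias||_2
reg.backward()                  # differentiable, so it can be added to a loss
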
class ProbabilisticUNetLoss(torch.nn.modules.module.Module):

Loss function for the Probabilistic UNet.

Arguments:
  • loss [nn.Module] - the loss function to be used. If None, the negative ELBO with a small L2 regularization term is used. (default: None)

ProbabilisticUNetLoss(loss=None)
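
The loss expects a model that implements the Probabilistic UNet interface used above: forward(input_, labels) to condition the latent distributions, elbo(labels, label_filter), and prior, posterior and fcomb.layers sub-modules for the regularization. The stub below is a hypothetical, minimal illustration of that interface, not a real Probabilistic UNet.

import torch
import torch.nn as nn
from torch_em.self_training.loss import ProbabilisticUNetLoss

class PUNetStub(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-ins for the latent encoders and the combination network.
        self.prior = nn.Linear(4, 4)
        self.posterior = nn.Linear(4, 4)
        self.fcomb = nn.Module()
        self.fcomb.layers = nn.Linear(4, 4)

    def forward(self, input_, labels):
        # A real model would encode the prior/posterior distributions here.
        pass

    def elbo(self, labels, label_filter=None):
        # Dummy scalar; a real model returns the evidence lower bound.
        return torch.tensor(1.0, requires_grad=True)

loss_fn = ProbabilisticUNetLoss()
x = torch.rand(2, 1, 8, 8)
y = torch.randint(0, 2, (2, 1, 8, 8)).float()
loss = loss_fn(PUNetStub(), x, y)  # = -elbo + 1e-5 * (L2 over prior/posterior/fcomb)
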
class ProbabilisticUNetLossAndMetric(torch.nn.modules.module.Module):

Loss and metric function for the Probabilistic UNet.

Arguments:
  • loss [nn.Module] - the loss function to be used. If None, the negative ELBO with a small L2 regularization term is used. (default: None)
  • metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
  • activation [nn.Module, callable] - the activation function to be applied to the samples before averaging them for the metric. (default: torch.nn.Sigmoid)
  • prior_samples [int] - the number of samples to draw from the prior for computing the metric. (default: 16)
ProbabilisticUNetLossAndMetric(loss=None, metric=DiceLoss(), activation=Sigmoid(), prior_samples=16)
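
The metric computation mirrors the loop in the implementation above: draw prior_samples segmentation samples, apply the activation, average, and score the average against the labels. A standalone re-statement of that logic (names are illustrative):

import torch
import torch.nn as nn
from torch_em.loss import DiceLoss

def averaged_sample_metric(model, labels, n_samples=16,
                           activation=nn.Sigmoid(), metric=DiceLoss()):
    # Average several activated samples drawn from the (conditioned) prior ...
    samples = [activation(model.sample(testing=False)) for _ in range(n_samples)]
    avg = torch.stack(samples, dim=0).mean(dim=0)
    # ... and evaluate the metric on the averaged prediction.
    return metric(avg, labels)
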