torch_em.loss.dice

import torch
import torch.nn as nn


# TODO refactor
def flatten_samples(input_):
    """
    Flattens a tensor or a variable such that the channel axis is first and the sample axis
    is second. The shapes are transformed as follows:
        (N, C, H, W) --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C) --> (C, N)
    The input must be at least 2d.
    """
    # Get number of channels
    num_channels = input_.size(1)
    # Permute the channel axis to first
    permute_axes = list(range(input_.dim()))
    permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0]
    # For input shape (say) NCHW, this should have the shape CNHW
    permuted = input_.permute(*permute_axes).contiguous()
    # Now flatten out all but the first axis and return
    flattened = permuted.view(num_channels, -1)
    return flattened


def dice_score(input_, target, invert=False, channelwise=True, reduce_channel="sum", eps=1e-7):
    if input_.shape != target.shape:
        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

    if channelwise:
        # Flatten input and target to have the shape (C, N),
        # where N is the number of samples
        input_ = flatten_samples(input_)
        target = flatten_samples(target)
        # Compute numerator and denominator (by summing over samples and
        # leaving the channels intact)
        numerator = (input_ * target).sum(-1)
        denominator = (input_ * input_).sum(-1) + (target * target).sum(-1)
        channelwise_score = 2 * (numerator / denominator.clamp(min=eps))
        if invert:
            channelwise_score = 1. - channelwise_score

        # Reduce the dice score over the channels to compute the overall dice score.
        # (default is to use the sum)
        if reduce_channel is None:
            score = channelwise_score
        elif reduce_channel == "sum":
            score = channelwise_score.sum()
        elif reduce_channel == "mean":
            score = channelwise_score.mean()
        elif reduce_channel == "max":
            score = channelwise_score.max()
        elif reduce_channel == "min":
            score = channelwise_score.min()
        else:
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

    else:
        numerator = (input_ * target).sum()
        denominator = (input_ * input_).sum() + (target * target).sum()
        score = 2. * (numerator / denominator.clamp(min=eps))
        if invert:
            score = 1. - score

    return score


class DiceLoss(nn.Module):
    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
        if reduce_channel not in ("sum", "mean", "max", "min", None):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

    def forward(self, input_, target):
        return dice_score(input_, target,
                          invert=True, channelwise=self.channelwise,
                          eps=self.eps, reduce_channel=self.reduce_channel)


class DiceLossWithLogits(nn.Module):
    def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

    def forward(self, input_, target):
        return dice_score(
            nn.functional.sigmoid(input_),
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )


class BCEDiceLoss(nn.Module):

    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}

    def forward(self, input_, target):
        loss_dice = dice_score(
            input_,
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps
        )
        loss_bce = nn.functional.binary_cross_entropy(
            input_, target
        )
        return self.alpha * loss_dice + self.beta * loss_bce


# TODO think about how to handle combined losses like this for mixed precision training
class BCEDiceLossWithLogits(nn.Module):

    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}

    def forward(self, input_, target):
        loss_dice = dice_score(
            torch.sigmoid(input_),
            target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps
        )
        loss_bce = nn.functional.binary_cross_entropy_with_logits(
            input_, target
        )
        return self.alpha * loss_dice + self.beta * loss_bce


def flatten_samples(input_):

Flattens a tensor or a variable such that the channel axis is first and the sample axis is second. The shapes are transformed as follows:

    (N, C, H, W) --> (C, N * H * W)
    (N, C, D, H, W) --> (C, N * D * H * W)
    (N, C) --> (C, N)

The input must be at least 2d.
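
For illustration, here is a minimal sketch of the resulting shape, using an arbitrary random 4D example tensor:

import torch
from torch_em.loss.dice import flatten_samples

x = torch.rand(2, 3, 8, 8)   # (N, C, H, W)
flat = flatten_samples(x)
print(flat.shape)            # torch.Size([3, 128]), i.e. (C, N * H * W)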

def dice_score(input_, target, invert=False, channelwise=True, reduce_channel='sum', eps=1e-07):
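
dice_score computes the Dice score between input_ and target, which must have the same shape. With channelwise=True the score is computed per channel and then combined according to reduce_channel; with invert=True the per-channel scores become 1 - dice before the reduction, which is how the loss classes below use it. A minimal sketch with hand-picked tensors (values chosen so the result is easy to verify):

import torch
from torch_em.loss.dice import dice_score

pred = torch.zeros(1, 2, 4, 4)
target = torch.zeros(1, 2, 4, 4)
pred[:, :, :2] = 1.0    # predict the top half as foreground in both channels
target[:, :, :2] = 1.0  # the ground truth agrees exactly

# Perfect agreement gives a dice score of 1 per channel; the default
# reduce_channel="sum" adds the two channels, so this prints tensor(2.).
print(dice_score(pred, target))
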
class DiceLoss(torch.nn.modules.module.Module):

DiceLoss(channelwise=True, eps=1e-07, reduce_channel='sum')
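
DiceLoss wraps dice_score with invert=True, so the forward pass returns one minus the Dice score, computed per channel and reduced according to reduce_channel. It does not normalize its inputs, so it is typically fed predictions that already lie in [0, 1]; see DiceLossWithLogits below for the variant that applies the sigmoid itself. A minimal usage sketch with arbitrary example tensors:

import torch
from torch_em.loss.dice import DiceLoss

loss = DiceLoss(channelwise=True, reduce_channel="mean")
pred = torch.rand(4, 1, 32, 32)                    # e.g. sigmoid outputs in [0, 1]
target = (torch.rand(4, 1, 32, 32) > 0.5).float()  # binary ground truth
print(loss(pred, target))                          # scalar tensor: 1 - dice, averaged over channels
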
class DiceLossWithLogits(torch.nn.modules.module.Module):

DiceLossWithLogits(channelwise=True, eps=1e-07, reduce_channel='sum')
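
DiceLossWithLogits behaves like DiceLoss, but the forward pass applies a sigmoid to input_ before computing the score, so raw network outputs (logits) can be passed directly. A minimal usage sketch with arbitrary example tensors:

import torch
from torch_em.loss.dice import DiceLossWithLogits

loss = DiceLossWithLogits()
logits = torch.randn(4, 1, 32, 32)                 # raw, unnormalized network outputs
target = (torch.rand(4, 1, 32, 32) > 0.5).float()
print(loss(logits, target))
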
class BCEDiceLoss(torch.nn.modules.module.Module):

BCEDiceLoss(alpha=1.0, beta=1.0, channelwise=True, eps=1e-07)
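
BCEDiceLoss returns the weighted sum alpha * dice loss + beta * binary cross entropy. Since it uses nn.functional.binary_cross_entropy rather than the logits variant, the predictions must already be probabilities in [0, 1]. A minimal usage sketch with arbitrary example tensors:

import torch
from torch_em.loss.dice import BCEDiceLoss

loss = BCEDiceLoss(alpha=1.0, beta=1.0)
pred = torch.rand(4, 1, 32, 32)                    # probabilities in [0, 1]
target = (torch.rand(4, 1, 32, 32) > 0.5).float()
print(loss(pred, target))                          # alpha * dice loss + beta * BCE
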
class BCEDiceLossWithLogits(torch.nn.modules.module.Module):

BCEDiceLossWithLogits(alpha=1.0, beta=1.0, channelwise=True, eps=1e-07)
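
BCEDiceLossWithLogits is the logits counterpart of BCEDiceLoss: it applies torch.sigmoid for the dice term and binary_cross_entropy_with_logits for the BCE term, so raw network outputs can be passed directly; the with_logits form is also generally the numerically safer choice for mixed-precision training, which the TODO comment in the source alludes to. A minimal usage sketch with arbitrary example tensors:

import torch
from torch_em.loss.dice import BCEDiceLossWithLogits

loss = BCEDiceLossWithLogits(alpha=1.0, beta=1.0)
logits = torch.randn(4, 1, 32, 32)                 # raw, unnormalized network outputs
target = (torch.rand(4, 1, 32, 32) > 0.5).float()
print(loss(logits, target))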