torch_em.loss.dice

  1from typing import Optional
  2
  3import torch
  4import torch.nn as nn
  5
  6
  7def flatten_samples(input_: torch.Tensor) -> torch.Tensor:
  8    """Flattens a tensor or a variable such that the channel axis is first and the sample (batch) axis is second.
  9
 10    The shapes are transformed as follows:
 11        (N, C, H, W) --> (C, N * H * W)
 12        (N, C, D, H, W) --> (C, N * D * H * W)
 13        (N, C) --> (C, N)
 14    The input must be atleast 2d.
 15
 16    Args:
 17        The input tensor.
 18
 19    Returns:
 20        The transformed input tensor.
 21    """
 22    # Get number of channels
 23    num_channels = input_.size(1)
 24    # Permute the channel axis to first
 25    permute_axes = list(range(input_.dim()))
 26    permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0]
 27    # For input shape (say) NCHW, this should have the shape CNHW
 28    permuted = input_.permute(*permute_axes).contiguous()
 29    # Now flatten out all but the first axis and return
 30    flattened = permuted.view(num_channels, -1)
 31    return flattened
 32
 33
 34def dice_score(
 35    input_: torch.Tensor,
 36    target: torch.Tensor,
 37    invert: bool = False,
 38    channelwise: bool = True,
 39    reduce_channel: Optional[str] = "sum",
 40    eps: float = 1e-7,
 41) -> torch.Tensor:
 42    """Compute the dice score between input and target.
 43
 44    Args:
 45        input_: The input tensor.
 46        target: The target tensor.
 47        invert: Whether to invert the returned dice score to obtain the dice error instead of the dice score.
 48        channelwise: Whether to return the dice score independently per channel.
 49        reduce_channel: How to return the dice score over the channel axis.
 50        eps: The epsilon value added to the denominator for numerical stability.
 51
 52    Returns:
 53        The dice score.
 54    """
 55    if input_.shape != target.shape:
 56        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
 57
 58    if channelwise:
 59        # Flatten input and target to have the shape (C, N),
 60        # where N is the number of samples
 61        input_ = flatten_samples(input_)
 62        target = flatten_samples(target)
 63        # Compute numerator and denominator (by summing over samples and
 64        # leaving the channels intact)
 65        numerator = (input_ * target).sum(-1)
 66        denominator = (input_ * input_).sum(-1) + (target * target).sum(-1)
 67        channelwise_score = 2 * (numerator / denominator.clamp(min=eps))
 68        if invert:
 69            channelwise_score = 1. - channelwise_score
 70
 71        # Reduce the dice score over the channels to compute the overall dice score.
 72        # (default is to use the sum)
 73        if reduce_channel is None:
 74            score = channelwise_score
 75        elif reduce_channel == "sum":
 76            score = channelwise_score.sum()
 77        elif reduce_channel == "mean":
 78            score = channelwise_score.mean()
 79        elif reduce_channel == "max":
 80            score = channelwise_score.max()
 81        elif reduce_channel == "min":
 82            score = channelwise_score.min()
 83        else:
 84            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
 85
 86    else:
 87        numerator = (input_ * target).sum()
 88        denominator = (input_ * input_).sum() + (target * target).sum()
 89        score = 2. * (numerator / denominator.clamp(min=eps))
 90        if invert:
 91            score = 1. - score
 92
 93    return score
 94
 95
 96class DiceLoss(nn.Module):
 97    """Loss computed based on the dice error between a binary input and binary target.
 98
 99    Args:
100        channelwise: Whether to return the dice score independently per channel.
101        eps: The epsilon value added to the denominator for numerical stability.
102        reduce_channel: How to return the dice score over the channel axis.
103    """
104    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
105        if reduce_channel not in ("sum", "mean", "max", "min", None):
106            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
107
108        super().__init__()
109        self.channelwise = channelwise
110        self.eps = eps
111        self.reduce_channel = reduce_channel
112
113        # all torch_em classes should store init kwargs to easily recreate the init call
114        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
115
116    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
117        """Compute the loss.
118
119        Args:
120            input_: The binary input.
121            target: The binary target.
122
123        Returns:
124            The dice loss.
125        """
126        return dice_score(
127            input_=input_,
128            target=target,
129            invert=True,
130            channelwise=self.channelwise,
131            eps=self.eps,
132            reduce_channel=self.reduce_channel
133        )
134
135
class DiceLossWithLogits(nn.Module):
    """Loss computed based on the dice error between logits and binary target.

    The logits are mapped to the range (0, 1) with a sigmoid before the dice error is computed.

    Args:
        channelwise: Whether to compute the dice score independently per channel.
        eps: The lower bound the dice denominator is clamped to for numerical stability.
        reduce_channel: How to reduce the dice score over the channel axis: "sum", "mean", "max", "min",
            or None to return one value per channel.

    Raises:
        ValueError: If `reduce_channel` is not one of the supported reductions.
    """
    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
        if reduce_channel not in ("sum", "mean", "max", "min", None):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

        super().__init__()
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # all torch_em classes should store init kwargs to easily recreate the init call
        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The logits.
            target: The binary target.

        Returns:
            The dice loss.
        """
        return dice_score(
            # torch.sigmoid is the canonical op; nn.functional.sigmoid is a legacy alias
            # that emits a deprecation warning on older PyTorch releases.
            input_=torch.sigmoid(input_),
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
174
175
class BCEDiceLoss(nn.Module):
    """Loss computed based on the binary cross entropy and the dice error between binary inputs and binary target.

    The loss is `alpha * dice_error + beta * bce`.

    Args:
        alpha: The weight of the dice error term.
        beta: The weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score independently per channel.
        eps: The lower bound the dice denominator is clamped to for numerical stability.
        reduce_channel: How to reduce the dice score over the channel axis: "sum", "mean", "max", "min",
            or None to keep one dice value per channel. Only used if `channelwise` is True.

    Raises:
        ValueError: If `reduce_channel` is not one of the supported reductions.
    """
    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 1.0,
        channelwise: bool = True,
        eps: float = 1e-7,
        reduce_channel: Optional[str] = "sum",
    ):
        if reduce_channel not in ("sum", "mean", "max", "min", None):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {
            "alpha": alpha,
            "beta": beta,
            "channelwise": channelwise,
            "eps": self.eps,
            "reduce_channel": self.reduce_channel,
        }

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The binary input.
            target: The binary target.

        Returns:
            The combined BCE and dice loss.
        """
        loss_dice = dice_score(
            input_=input_,
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )
        loss_bce = nn.functional.binary_cross_entropy(input_, target)
        return self.alpha * loss_dice + self.beta * loss_bce
214
215
216# TODO think about how to handle combined losses like this for mixed precision training
class BCEDiceLossWithLogits(nn.Module):
    """Loss computed based on the binary cross entropy and the dice error between logits and binary target.

    The loss is `alpha * dice_error + beta * bce`. The dice error is computed on the sigmoid
    of the logits, the binary cross entropy directly on the logits.

    Args:
        alpha: The weight of the dice error term.
        beta: The weight of the binary cross entropy term.
        channelwise: Whether to compute the dice score independently per channel.
        eps: The lower bound the dice denominator is clamped to for numerical stability.
        reduce_channel: How to reduce the dice score over the channel axis: "sum", "mean", "max", "min",
            or None to keep one dice value per channel. Only used if `channelwise` is True.

    Raises:
        ValueError: If `reduce_channel` is not one of the supported reductions.
    """
    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 1.0,
        channelwise: bool = True,
        eps: float = 1e-7,
        reduce_channel: Optional[str] = "sum",
    ):
        if reduce_channel not in ("sum", "mean", "max", "min", None):
            raise ValueError(f"Unsupported channel reduction {reduce_channel}")

        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.channelwise = channelwise
        self.eps = eps
        self.reduce_channel = reduce_channel

        # All torch_em classes should store init kwargs to easily recreate the init call.
        self.init_kwargs = {
            "alpha": alpha,
            "beta": beta,
            "channelwise": channelwise,
            "eps": self.eps,
            "reduce_channel": self.reduce_channel,
        }

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            input_: The logits.
            target: The binary target.

        Returns:
            The combined BCE and dice loss.
        """
        loss_dice = dice_score(
            # torch.sigmoid is the canonical op; nn.functional.sigmoid is a legacy alias
            # that emits a deprecation warning on older PyTorch releases.
            input_=torch.sigmoid(input_),
            target=target,
            invert=True,
            channelwise=self.channelwise,
            eps=self.eps,
            reduce_channel=self.reduce_channel,
        )

        # The BCE term takes the raw logits; the sigmoid is applied inside the function.
        loss_bce = nn.functional.binary_cross_entropy_with_logits(input_, target)

        return self.alpha * loss_dice + self.beta * loss_bce
def flatten_samples(input_: torch.Tensor) -> torch.Tensor:
 8def flatten_samples(input_: torch.Tensor) -> torch.Tensor:
 9    """Flattens a tensor or a variable such that the channel axis is first and the sample (batch) axis is second.
10
11    The shapes are transformed as follows:
12        (N, C, H, W) --> (C, N * H * W)
13        (N, C, D, H, W) --> (C, N * D * H * W)
14        (N, C) --> (C, N)
15    The input must be atleast 2d.
16
17    Args:
18        The input tensor.
19
20    Returns:
21        The transformed input tensor.
22    """
23    # Get number of channels
24    num_channels = input_.size(1)
25    # Permute the channel axis to first
26    permute_axes = list(range(input_.dim()))
27    permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0]
28    # For input shape (say) NCHW, this should have the shape CNHW
29    permuted = input_.permute(*permute_axes).contiguous()
30    # Now flatten out all but the first axis and return
31    flattened = permuted.view(num_channels, -1)
32    return flattened

Flattens a tensor or a variable such that the channel axis is first and the sample (batch) axis is second.

The shapes are transformed as follows:

  • (N, C, H, W) --> (C, N * H * W)
  • (N, C, D, H, W) --> (C, N * D * H * W)
  • (N, C) --> (C, N)

The input must be at least 2D.

Arguments:
  • input_: The input tensor.
Returns:

The transformed input tensor.

def dice_score( input_: torch.Tensor, target: torch.Tensor, invert: bool = False, channelwise: bool = True, reduce_channel: Optional[str] = 'sum', eps: float = 1e-07) -> torch.Tensor:
35def dice_score(
36    input_: torch.Tensor,
37    target: torch.Tensor,
38    invert: bool = False,
39    channelwise: bool = True,
40    reduce_channel: Optional[str] = "sum",
41    eps: float = 1e-7,
42) -> torch.Tensor:
43    """Compute the dice score between input and target.
44
45    Args:
46        input_: The input tensor.
47        target: The target tensor.
48        invert: Whether to invert the returned dice score to obtain the dice error instead of the dice score.
49        channelwise: Whether to return the dice score independently per channel.
50        reduce_channel: How to return the dice score over the channel axis.
51        eps: The epsilon value added to the denominator for numerical stability.
52
53    Returns:
54        The dice score.
55    """
56    if input_.shape != target.shape:
57        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
58
59    if channelwise:
60        # Flatten input and target to have the shape (C, N),
61        # where N is the number of samples
62        input_ = flatten_samples(input_)
63        target = flatten_samples(target)
64        # Compute numerator and denominator (by summing over samples and
65        # leaving the channels intact)
66        numerator = (input_ * target).sum(-1)
67        denominator = (input_ * input_).sum(-1) + (target * target).sum(-1)
68        channelwise_score = 2 * (numerator / denominator.clamp(min=eps))
69        if invert:
70            channelwise_score = 1. - channelwise_score
71
72        # Reduce the dice score over the channels to compute the overall dice score.
73        # (default is to use the sum)
74        if reduce_channel is None:
75            score = channelwise_score
76        elif reduce_channel == "sum":
77            score = channelwise_score.sum()
78        elif reduce_channel == "mean":
79            score = channelwise_score.mean()
80        elif reduce_channel == "max":
81            score = channelwise_score.max()
82        elif reduce_channel == "min":
83            score = channelwise_score.min()
84        else:
85            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
86
87    else:
88        numerator = (input_ * target).sum()
89        denominator = (input_ * input_).sum() + (target * target).sum()
90        score = 2. * (numerator / denominator.clamp(min=eps))
91        if invert:
92            score = 1. - score
93
94    return score

Compute the dice score between input and target.

Arguments:
  • input_: The input tensor.
  • target: The target tensor.
  • invert: Whether to invert the returned dice score to obtain the dice error instead of the dice score.
  • channelwise: Whether to return the dice score independently per channel.
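  • reduce_channel: How to reduce the dice score over the channel axis: "sum", "mean", "max", "min", or None to return one value per channel. Only used if channelwise is True.
  • eps: The lower bound the denominator is clamped to for numerical stability.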
Returns:

The dice score.

class DiceLoss(torch.nn.modules.module.Module):
 97class DiceLoss(nn.Module):
 98    """Loss computed based on the dice error between a binary input and binary target.
 99
100    Args:
101        channelwise: Whether to return the dice score independently per channel.
102        eps: The epsilon value added to the denominator for numerical stability.
103        reduce_channel: How to return the dice score over the channel axis.
104    """
105    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
106        if reduce_channel not in ("sum", "mean", "max", "min", None):
107            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
108
109        super().__init__()
110        self.channelwise = channelwise
111        self.eps = eps
112        self.reduce_channel = reduce_channel
113
114        # all torch_em classes should store init kwargs to easily recreate the init call
115        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
116
117    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
118        """Compute the loss.
119
120        Args:
121            input_: The binary input.
122            target: The binary target.
123
124        Returns:
125            The dice loss.
126        """
127        return dice_score(
128            input_=input_,
129            target=target,
130            invert=True,
131            channelwise=self.channelwise,
132            eps=self.eps,
133            reduce_channel=self.reduce_channel
134        )

Loss computed based on the dice error between a binary input and binary target.

Arguments:
  • channelwise: Whether to return the dice score independently per channel.
  • eps: The lower bound the dice denominator is clamped to for numerical stability.
  • reduce_channel: How to return the dice score over the channel axis.
DiceLoss( channelwise: bool = True, eps: float = 1e-07, reduce_channel: Optional[str] = 'sum')
105    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
106        if reduce_channel not in ("sum", "mean", "max", "min", None):
107            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
108
109        super().__init__()
110        self.channelwise = channelwise
111        self.eps = eps
112        self.reduce_channel = reduce_channel
113
114        # all torch_em classes should store init kwargs to easily recreate the init call
115        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

Initialize internal Module state, shared by both nn.Module and ScriptModule.

channelwise
eps
reduce_channel
init_kwargs
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
117    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
118        """Compute the loss.
119
120        Args:
121            input_: The binary input.
122            target: The binary target.
123
124        Returns:
125            The dice loss.
126        """
127        return dice_score(
128            input_=input_,
129            target=target,
130            invert=True,
131            channelwise=self.channelwise,
132            eps=self.eps,
133            reduce_channel=self.reduce_channel
134        )

Compute the loss.

Arguments:
  • input_: The binary input.
  • target: The binary target.
Returns:

The dice loss.

class DiceLossWithLogits(torch.nn.modules.module.Module):
137class DiceLossWithLogits(nn.Module):
138    """Loss computed based on the dice error between logits and binary target.
139
140    Args:
141        channelwise: Whether to return the dice score independently per channel.
142        eps: The epsilon value added to the denominator for numerical stability.
143        reduce_channel: How to return the dice score over the channel axis.
144    """
145    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
146        if reduce_channel not in ("sum", "mean", "max", "min", None):
147            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
148
149        super().__init__()
150        self.channelwise = channelwise
151        self.eps = eps
152        self.reduce_channel = reduce_channel
153
154        # all torch_em classes should store init kwargs to easily recreate the init call
155        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
156
157    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
158        """Compute the loss.
159
160        Args:
161            input_: The logits.
162            target: The binary target.
163
164        Returns:
165            The dice loss.
166        """
167        return dice_score(
168            input_=nn.functional.sigmoid(input_),
169            target=target,
170            invert=True,
171            channelwise=self.channelwise,
172            eps=self.eps,
173            reduce_channel=self.reduce_channel,
174        )

Loss computed based on the dice error between logits and binary target.

Arguments:
  • channelwise: Whether to return the dice score independently per channel.
  • eps: The lower bound the dice denominator is clamped to for numerical stability.
  • reduce_channel: How to return the dice score over the channel axis.
DiceLossWithLogits( channelwise: bool = True, eps: float = 1e-07, reduce_channel: Optional[str] = 'sum')
145    def __init__(self, channelwise: bool = True, eps: float = 1e-7, reduce_channel: Optional[str] = "sum"):
146        if reduce_channel not in ("sum", "mean", "max", "min", None):
147            raise ValueError(f"Unsupported channel reduction {reduce_channel}")
148
149        super().__init__()
150        self.channelwise = channelwise
151        self.eps = eps
152        self.reduce_channel = reduce_channel
153
154        # all torch_em classes should store init kwargs to easily recreate the init call
155        self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}

Initialize internal Module state, shared by both nn.Module and ScriptModule.

channelwise
eps
reduce_channel
init_kwargs
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
157    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
158        """Compute the loss.
159
160        Args:
161            input_: The logits.
162            target: The binary target.
163
164        Returns:
165            The dice loss.
166        """
167        return dice_score(
168            input_=nn.functional.sigmoid(input_),
169            target=target,
170            invert=True,
171            channelwise=self.channelwise,
172            eps=self.eps,
173            reduce_channel=self.reduce_channel,
174        )

Compute the loss.

Arguments:
  • input_: The logits.
  • target: The binary target.
Returns:

The dice loss.

class BCEDiceLoss(torch.nn.modules.module.Module):
177class BCEDiceLoss(nn.Module):
178    """Loss computed based on the binary cross entropy and the dice error between binary inputs and binary target.
179
180    Args:
181        alpha: The weight for combining the BCE and dice loss.
182        channelwise: Whether to return the dice score independently per channel.
183        eps: The epsilon value added to the denominator for numerical stability.
184        reduce_channel: How to return the dice score over the channel axis.
185    """
186    def __init__(self, alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-7):
187        super().__init__()
188        self.alpha = alpha
189        self.beta = beta
190        self.channelwise = channelwise
191        self.eps = eps
192
193        # All torch_em classes should store init kwargs to easily recreate the init call.
194        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}
195
196    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
197        """Compute the loss.
198
199        Args:
200            input_: The binary input.
201            target: The binary target.
202
203        Returns:
204            The combined BCE and dice loss.
205        """
206        loss_dice = dice_score(
207            input_=input_,
208            target=target,
209            invert=True,
210            channelwise=self.channelwise,
211            eps=self.eps
212        )
213        loss_bce = nn.functional.binary_cross_entropy(input_, target)
214        return self.alpha * loss_dice + self.beta * loss_bce

Loss computed based on the binary cross entropy and the dice error between binary inputs and binary target.

Arguments:
  • alpha: The weight of the dice error term in the combined loss.
  • beta: The weight of the binary cross entropy term in the combined loss.
  • channelwise: Whether to return the dice score independently per channel.
  • eps: The lower bound the dice denominator is clamped to for numerical stability.
  • reduce_channel: How to return the dice score over the channel axis.
BCEDiceLoss( alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-07)
186    def __init__(self, alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-7):
187        super().__init__()
188        self.alpha = alpha
189        self.beta = beta
190        self.channelwise = channelwise
191        self.eps = eps
192
193        # All torch_em classes should store init kwargs to easily recreate the init call.
194        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}

Initialize internal Module state, shared by both nn.Module and ScriptModule.

alpha
beta
channelwise
eps
init_kwargs
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
196    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
197        """Compute the loss.
198
199        Args:
200            input_: The binary input.
201            target: The binary target.
202
203        Returns:
204            The combined BCE and dice loss.
205        """
206        loss_dice = dice_score(
207            input_=input_,
208            target=target,
209            invert=True,
210            channelwise=self.channelwise,
211            eps=self.eps
212        )
213        loss_bce = nn.functional.binary_cross_entropy(input_, target)
214        return self.alpha * loss_dice + self.beta * loss_bce

Compute the loss.

Arguments:
  • input_: The binary input.
  • target: The binary target.
Returns:

The combined BCE and dice loss.

class BCEDiceLossWithLogits(torch.nn.modules.module.Module):
218class BCEDiceLossWithLogits(nn.Module):
219    """Loss computed based on the binary cross entropy and the dice error between logits and binary target.
220
221    Args:
222        alpha: The weight for combining the BCE and dice loss.
223        channelwise: Whether to return the dice score independently per channel.
224        eps: The epsilon value added to the denominator for numerical stability.
225        reduce_channel: How to return the dice score over the channel axis.
226    """
227    def __init__(self, alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-7):
228        super().__init__()
229        self.alpha = alpha
230        self.beta = beta
231        self.channelwise = channelwise
232        self.eps = eps
233
234        # All torch_em classes should store init kwargs to easily recreate the init call.
235        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}
236
237    def forward(self, input_, target):
238        """Compute the loss.
239
240        Args:
241            input_: The logits.
242            target: The binary target.
243
244        Returns:
245            The combined BCE and dice loss.
246        """
247        loss_dice = dice_score(
248            input_=nn.functional.sigmoid(input_),
249            target=target,
250            invert=True,
251            channelwise=self.channelwise,
252            eps=self.eps
253        )
254
255        loss_bce = nn.functional.binary_cross_entropy_with_logits(input_, target)
256
257        return self.alpha * loss_dice + self.beta * loss_bce

Loss computed based on the binary cross entropy and the dice error between logits and binary target.

Arguments:
  • alpha: The weight of the dice error term in the combined loss.
  • beta: The weight of the binary cross entropy term in the combined loss.
  • channelwise: Whether to return the dice score independently per channel.
  • eps: The lower bound the dice denominator is clamped to for numerical stability.
  • reduce_channel: How to return the dice score over the channel axis.
BCEDiceLossWithLogits( alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-07)
227    def __init__(self, alpha: float = 1.0, beta: float = 1.0, channelwise: bool = True, eps: float = 1e-7):
228        super().__init__()
229        self.alpha = alpha
230        self.beta = beta
231        self.channelwise = channelwise
232        self.eps = eps
233
234        # All torch_em classes should store init kwargs to easily recreate the init call.
235        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}

Initialize internal Module state, shared by both nn.Module and ScriptModule.

alpha
beta
channelwise
eps
init_kwargs
def forward(self, input_, target):
237    def forward(self, input_, target):
238        """Compute the loss.
239
240        Args:
241            input_: The logits.
242            target: The binary target.
243
244        Returns:
245            The combined BCE and dice loss.
246        """
247        loss_dice = dice_score(
248            input_=nn.functional.sigmoid(input_),
249            target=target,
250            invert=True,
251            channelwise=self.channelwise,
252            eps=self.eps
253        )
254
255        loss_bce = nn.functional.binary_cross_entropy_with_logits(input_, target)
256
257        return self.alpha * loss_dice + self.beta * loss_bce

Compute the loss.

Arguments:
  • input_: The logits.
  • target: The binary target.
Returns:

The combined BCE and dice loss.