torch_em.loss.cldice

  1import torch
  2import torch.nn as nn
  3import torch.nn.functional as F
  4
  5from .dice import dice_score
  6
  7# From "clDice -- A Novel Topology-Preserving Loss Function for Tubular Structure Segmentation":
  8# https://arxiv.org/abs/2003.07311
  9
 10
 11class SoftSkeletonize(torch.nn.Module):
 12    """`SoftSkeletonize` is a differentiable approximation for skeletonization,
 13        which applies iterative min- and max-pooling as a proxy for
 14        morphological erosion and dilation.
 15
 16    Args:
 17        num_iter: Number of iterations for soft-skeletonization.
 18            Should be greater or equal to than the maximum observed radius.
 19    """
 20    def __init__(self, num_iter: int = 5):
 21
 22        super(SoftSkeletonize, self).__init__()
 23        self.num_iter = num_iter
 24
 25    def soft_erode(self, input_: torch.Tensor):
 26
 27        if len(input_.shape) == 4:
 28            p1 = -F.max_pool2d(-input_, (3, 1), (1, 1), (1, 0))
 29            p2 = -F.max_pool2d(-input_, (1, 3), (1, 1), (0, 1))
 30            return torch.min(p1, p2)
 31        elif len(input_.shape) == 5:
 32            p1 = -F.max_pool3d(-input_, (3, 1, 1), (1, 1, 1), (1, 0, 0))
 33            p2 = -F.max_pool3d(-input_, (1, 3, 1), (1, 1, 1), (0, 1, 0))
 34            p3 = -F.max_pool3d(-input_, (1, 1, 3), (1, 1, 1), (0, 0, 1))
 35            return torch.min(torch.min(p1, p2), p3)
 36
 37    def soft_dilate(self, input_: torch.Tensor):
 38
 39        if len(input_.shape) == 4:
 40            return F.max_pool2d(input_, (3, 3), (1, 1), (1, 1))
 41        elif len(input_.shape) == 5:
 42            return F.max_pool3d(input_, (3, 3, 3), (1, 1, 1), (1, 1, 1))
 43
 44    def soft_open(self, input_: torch.Tensor):
 45
 46        return self.soft_dilate(self.soft_erode(input_))
 47
 48    def soft_skel(self, input_: torch.Tensor):
 49
 50        input1 = self.soft_open(input_)
 51        skel = F.relu(input_ - input1)
 52
 53        for j in range(self.num_iter):
 54            input_ = self.soft_erode(input_)
 55            input1 = self.soft_open(input_)
 56            delta = F.relu(input_-input1)
 57            skel = skel + F.relu(delta - skel * delta)
 58
 59        return skel
 60
 61    def forward(self, input_: torch.Tensor):
 62        """Skeletonize the input prediction.
 63
 64        Args:
 65            input_: The input logits.
 66
 67        Returns:
 68            The skeletonization.
 69        """
 70        return self.soft_skel(input_)
 71
 72
 73def cldice_score(
 74    input_: torch.Tensor,
 75    target: torch.Tensor,
 76    num_iter: int = 5,
 77    invert: bool = False,
 78    eps: float = 1e-7,
 79) -> torch.Tensor:
 80    """Adapted from .dice.py `dice_score`. Compute the soft clDice score between input and target.
 81
 82    Args:
 83        input_: The input tensor.
 84        target: The target tensor.
 85        num_iter: Number of iterations for soft-skeletonization.
 86        invert: Whether to invert the returned dice score to obtain the cldice error instead of the cldice score.
 87        channelwise: Not implemented; whether to return the dice score independently per channel.
 88        reduce_channel: Not implemented; how to return the dice score over the channel axis.
 89        eps: The epsilon value added to the denominator for numerical stability.
 90
 91    Returns:
 92        The clDice score.
 93    """
 94    if input_.shape != target.shape:
 95        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
 96
 97    soft_skeletonize = SoftSkeletonize(num_iter=num_iter)
 98    skel_input = soft_skeletonize(input_)
 99    skel_target = soft_skeletonize(target)
100
101    t_prec = (skel_input * target).sum() / (skel_input.sum()).clamp(min=eps)
102    t_sens = (skel_target * input_).sum() / (skel_target.sum()).clamp(min=eps)
103    score = 2.*(t_prec*t_sens)/(t_prec+t_sens).clamp(min=eps)
104
105    if invert:
106        score = 1. - score
107
108    return score
109
110
class SoftclDiceLoss(nn.Module):
    """Soft clDice loss for segmentation of tubular structures.

    The soft clDice loss computes a topology-aware loss by computing the
    soft skeleton of both the prediction and target and measuring the
    overlap of the two skeletons. This teaches the model to learn skeletons
    directly. In the clDice paper, the authors recommend using the combined
    soft-Dice and soft-clDice loss to learn topology-aware segmentations,
    which is implemented below as `CombinedclDiceLoss`.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        eps: The epsilon value added to the denominator for numerical stability.
        exclude_background: Whether to exclude background channel 0 from the loss computation.
            Useful for multi-class segmentation.
    """
    def __init__(self, num_iter: int = 5, eps: float = 1e-7,
                 exclude_background: bool = False):
        super().__init__()

        self.num_iter = num_iter
        self.eps = eps
        self.exclude_background = exclude_background
        # Kept so the loss can be serialized / re-instantiated from its constructor kwargs.
        self.init_kwargs = {"num_iter": num_iter, "eps": eps, "exclude_background": exclude_background}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the soft clDice loss between the input logits and binary target.

        Args:
            input_: The input logits.
            target: The binary target.

        Returns:
            The soft clDice loss (the inverted clDice score).

        Raises:
            ValueError: If input and target do not have the same shape.
        """
        if input_.shape != target.shape:
            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

        if self.exclude_background:
            # Slice only the channel axis, so that both 2d data (4d tensors)
            # and 3d data (5d tensors) are supported.
            target = target[:, 1:]
            input_ = input_[:, 1:]

        return cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)
162
163
164# TODO implement `channelwise` for multiclass segmentation
165# TODO consider if `exclude_background` is needed for multiclass segmentation
class CombinedclDiceLoss(SoftclDiceLoss):
    """Combined soft-Dice and soft-clDice loss for segmentation of tubular structures.

    The soft-clDice loss computes a topology-aware loss by computing the
    soft skeleton of both the prediction and target and measuring the
    overlap of the two skeletons. This encourages the model to preserve the
    connectivity and topology of tubular structures. The final loss is a
    weighted combination of soft Dice and clDice, controlled by alpha.

    Args:
        num_iter: Number of iterations for soft-skeletonization.
        alpha: The weight for combining the soft Dice and soft clDice loss.
        eps: The epsilon value added to the denominator for numerical stability.
        exclude_background: Whether to exclude background channel 0 from the
            loss computation. Useful for multi-class segmentation.
    """
    def __init__(self, num_iter: int = 5, alpha: float = 0.5, eps: float = 1e-7,
                 exclude_background: bool = False):
        super().__init__(num_iter=num_iter, eps=eps, exclude_background=exclude_background)

        self.alpha = alpha
        # Kept so the loss can be serialized / re-instantiated from its constructor kwargs.
        self.init_kwargs = {"num_iter": num_iter, "alpha": alpha, "eps": eps, "exclude_background": exclude_background}

    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the combined soft Dice and soft clDice loss.

        Args:
            input_: The input logits.
            target: The binary target.

        Returns:
            The combined loss: (1 - alpha) * dice + alpha * cldice,
            where both terms are inverted scores (i.e. losses).

        Raises:
            ValueError: If input and target do not have the same shape.
        """
        if input_.shape != target.shape:
            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")

        if self.exclude_background:
            # Slice only the channel axis, so that both 2d data (4d tensors)
            # and 3d data (5d tensors) are supported.
            target = target[:, 1:]
            input_ = input_[:, 1:]

        dice = dice_score(input_, target, invert=True, channelwise=False, eps=self.eps)
        cldice = cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)

        return (1.0 - self.alpha) * dice + self.alpha * cldice
class SoftSkeletonize(torch.nn.modules.module.Module):
12class SoftSkeletonize(torch.nn.Module):
13    """`SoftSkeletonize` is a differentiable approximation for skeletonization,
14        which applies iterative min- and max-pooling as a proxy for
15        morphological erosion and dilation.
16
17    Args:
18        num_iter: Number of iterations for soft-skeletonization.
19            Should be greater or equal to than the maximum observed radius.
20    """
21    def __init__(self, num_iter: int = 5):
22
23        super(SoftSkeletonize, self).__init__()
24        self.num_iter = num_iter
25
26    def soft_erode(self, input_: torch.Tensor):
27
28        if len(input_.shape) == 4:
29            p1 = -F.max_pool2d(-input_, (3, 1), (1, 1), (1, 0))
30            p2 = -F.max_pool2d(-input_, (1, 3), (1, 1), (0, 1))
31            return torch.min(p1, p2)
32        elif len(input_.shape) == 5:
33            p1 = -F.max_pool3d(-input_, (3, 1, 1), (1, 1, 1), (1, 0, 0))
34            p2 = -F.max_pool3d(-input_, (1, 3, 1), (1, 1, 1), (0, 1, 0))
35            p3 = -F.max_pool3d(-input_, (1, 1, 3), (1, 1, 1), (0, 0, 1))
36            return torch.min(torch.min(p1, p2), p3)
37
38    def soft_dilate(self, input_: torch.Tensor):
39
40        if len(input_.shape) == 4:
41            return F.max_pool2d(input_, (3, 3), (1, 1), (1, 1))
42        elif len(input_.shape) == 5:
43            return F.max_pool3d(input_, (3, 3, 3), (1, 1, 1), (1, 1, 1))
44
45    def soft_open(self, input_: torch.Tensor):
46
47        return self.soft_dilate(self.soft_erode(input_))
48
49    def soft_skel(self, input_: torch.Tensor):
50
51        input1 = self.soft_open(input_)
52        skel = F.relu(input_ - input1)
53
54        for j in range(self.num_iter):
55            input_ = self.soft_erode(input_)
56            input1 = self.soft_open(input_)
57            delta = F.relu(input_-input1)
58            skel = skel + F.relu(delta - skel * delta)
59
60        return skel
61
62    def forward(self, input_: torch.Tensor):
63        """Skeletonize the input prediction.
64
65        Args:
66            input_: The input logits.
67
68        Returns:
69            The skeletonization.
70        """
71        return self.soft_skel(input_)

SoftSkeletonize is a differentiable approximation for skeletonization, which applies iterative min- and max-pooling as a proxy for morphological erosion and dilation.

Arguments:
  • num_iter: Number of iterations for soft-skeletonization. Should be greater than or equal to the maximum observed radius.
SoftSkeletonize(num_iter: int = 5)
21    def __init__(self, num_iter: int = 5):
22
23        super(SoftSkeletonize, self).__init__()
24        self.num_iter = num_iter

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_iter
def soft_erode(self, input_: torch.Tensor):
26    def soft_erode(self, input_: torch.Tensor):
27
28        if len(input_.shape) == 4:
29            p1 = -F.max_pool2d(-input_, (3, 1), (1, 1), (1, 0))
30            p2 = -F.max_pool2d(-input_, (1, 3), (1, 1), (0, 1))
31            return torch.min(p1, p2)
32        elif len(input_.shape) == 5:
33            p1 = -F.max_pool3d(-input_, (3, 1, 1), (1, 1, 1), (1, 0, 0))
34            p2 = -F.max_pool3d(-input_, (1, 3, 1), (1, 1, 1), (0, 1, 0))
35            p3 = -F.max_pool3d(-input_, (1, 1, 3), (1, 1, 1), (0, 0, 1))
36            return torch.min(torch.min(p1, p2), p3)
def soft_dilate(self, input_: torch.Tensor):
38    def soft_dilate(self, input_: torch.Tensor):
39
40        if len(input_.shape) == 4:
41            return F.max_pool2d(input_, (3, 3), (1, 1), (1, 1))
42        elif len(input_.shape) == 5:
43            return F.max_pool3d(input_, (3, 3, 3), (1, 1, 1), (1, 1, 1))
def soft_open(self, input_: torch.Tensor):
45    def soft_open(self, input_: torch.Tensor):
46
47        return self.soft_dilate(self.soft_erode(input_))
def soft_skel(self, input_: torch.Tensor):
49    def soft_skel(self, input_: torch.Tensor):
50
51        input1 = self.soft_open(input_)
52        skel = F.relu(input_ - input1)
53
54        for j in range(self.num_iter):
55            input_ = self.soft_erode(input_)
56            input1 = self.soft_open(input_)
57            delta = F.relu(input_-input1)
58            skel = skel + F.relu(delta - skel * delta)
59
60        return skel
def forward(self, input_: torch.Tensor):
62    def forward(self, input_: torch.Tensor):
63        """Skeletonize the input prediction.
64
65        Args:
66            input_: The input logits.
67
68        Returns:
69            The skeletonization.
70        """
71        return self.soft_skel(input_)

Skeletonize the input prediction.

Arguments:
  • input_: The input logits.
Returns:

The skeletonization.

def cldice_score( input_: torch.Tensor, target: torch.Tensor, num_iter: int = 5, invert: bool = False, eps: float = 1e-07) -> torch.Tensor:
 74def cldice_score(
 75    input_: torch.Tensor,
 76    target: torch.Tensor,
 77    num_iter: int = 5,
 78    invert: bool = False,
 79    eps: float = 1e-7,
 80) -> torch.Tensor:
 81    """Adapted from .dice.py `dice_score`. Compute the soft clDice score between input and target.
 82
 83    Args:
 84        input_: The input tensor.
 85        target: The target tensor.
 86        num_iter: Number of iterations for soft-skeletonization.
 87        invert: Whether to invert the returned dice score to obtain the cldice error instead of the cldice score.
 88        channelwise: Not implemented; whether to return the dice score independently per channel.
 89        reduce_channel: Not implemented; how to return the dice score over the channel axis.
 90        eps: The epsilon value added to the denominator for numerical stability.
 91
 92    Returns:
 93        The clDice score.
 94    """
 95    if input_.shape != target.shape:
 96        raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
 97
 98    soft_skeletonize = SoftSkeletonize(num_iter=num_iter)
 99    skel_input = soft_skeletonize(input_)
100    skel_target = soft_skeletonize(target)
101
102    t_prec = (skel_input * target).sum() / (skel_input.sum()).clamp(min=eps)
103    t_sens = (skel_target * input_).sum() / (skel_target.sum()).clamp(min=eps)
104    score = 2.*(t_prec*t_sens)/(t_prec+t_sens).clamp(min=eps)
105
106    if invert:
107        score = 1. - score
108
109    return score

Adapted from .dice.py dice_score. Compute the soft clDice score between input and target.

Arguments:
  • input_: The input tensor.
  • target: The target tensor.
  • num_iter: Number of iterations for soft-skeletonization.
  • invert: Whether to invert the returned dice score to obtain the cldice error instead of the cldice score.
  • channelwise: Not implemented; whether to return the dice score independently per channel.
  • reduce_channel: Not implemented; how to return the dice score over the channel axis.
  • eps: The epsilon value added to the denominator for numerical stability.
Returns:

The clDice score.

class SoftclDiceLoss(torch.nn.modules.module.Module):
112class SoftclDiceLoss(nn.Module):
113    """Combined soft Dice and clDice loss for segmentation of tubular structures.
114
115        The soft clDice loss computes topology-aware loss by computing the
116        soft skeleton of both the prediction and target
117        and measuring overlap of the two skeletons. Teaches the model to learn
118        skeletons directly. In the clDice paper, the authors recommend using
119        the combined soft-Dice and soft-clDice loss to learn topology-aware
120        segmentations, which is implemented below as `CombinedclDiceLoss`.
121
122    Args:
123        num_iter: Number of iterations for soft-skeletonization.
124        eps: The epsilon value added to the denominator for numerical
125            stability.
126        exclude_background: Whether to exclude background channel 0 from the
127            loss computation.
128            Useful for multi-class segmentation.
129        channelwise: Not implemented; Whether to return the dice score
130            independently per channel.
131        reduce_channel: Not implemented; The epsilon value added to the
132            denominator for numerical stability.
133    """
134    def __init__(self, num_iter: int = 5, eps: float = 1e-7,
135                 exclude_background: bool = False):
136        super(SoftclDiceLoss, self).__init__()
137
138        self.num_iter = num_iter
139        self.eps = eps
140        self.exclude_background = exclude_background
141        self.init_kwargs = {"num_iter": num_iter, "eps": eps, "exclude_background": exclude_background}
142
143    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
144        """Compute soft clDice score between the input logits and binary target.
145
146        Args:
147            input_: The input logits.
148            target: The binary target.
149
150        Returns:
151            The soft clDice score.
152        """
153        if input_.shape != target.shape:
154            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
155
156        if self.exclude_background:
157            target = target[:, 1:, :, :]
158            input_ = input_[:, 1:, :, :]
159
160        cldice = cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)
161
162        return cldice

Soft clDice loss for segmentation of tubular structures.

The soft clDice loss computes topology-aware loss by computing the
soft skeleton of both the prediction and target
and measuring overlap of the two skeletons. Teaches the model to learn
skeletons directly. In the clDice paper, the authors recommend using
the combined soft-Dice and soft-clDice loss to learn topology-aware
segmentations, which is implemented below as `CombinedclDiceLoss`.
Arguments:
  • num_iter: Number of iterations for soft-skeletonization.
  • eps: The epsilon value added to the denominator for numerical stability.
  • exclude_background: Whether to exclude background channel 0 from the loss computation. Useful for multi-class segmentation.
  • channelwise: Not implemented; Whether to return the dice score independently per channel.
  • reduce_channel: Not implemented; how to reduce the dice score over the channel axis.
SoftclDiceLoss( num_iter: int = 5, eps: float = 1e-07, exclude_background: bool = False)
134    def __init__(self, num_iter: int = 5, eps: float = 1e-7,
135                 exclude_background: bool = False):
136        super(SoftclDiceLoss, self).__init__()
137
138        self.num_iter = num_iter
139        self.eps = eps
140        self.exclude_background = exclude_background
141        self.init_kwargs = {"num_iter": num_iter, "eps": eps, "exclude_background": exclude_background}

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_iter
eps
exclude_background
init_kwargs
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
143    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
144        """Compute soft clDice score between the input logits and binary target.
145
146        Args:
147            input_: The input logits.
148            target: The binary target.
149
150        Returns:
151            The soft clDice score.
152        """
153        if input_.shape != target.shape:
154            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
155
156        if self.exclude_background:
157            target = target[:, 1:, :, :]
158            input_ = input_[:, 1:, :, :]
159
160        cldice = cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)
161
162        return cldice

Compute soft clDice score between the input logits and binary target.

Arguments:
  • input_: The input logits.
  • target: The binary target.
Returns:

The soft clDice score.

class CombinedclDiceLoss(SoftclDiceLoss):
167class CombinedclDiceLoss(SoftclDiceLoss):
168    """Combined soft-Dice and soft-clDice loss for segmentation of tubular structures.
169
170        The soft-clDice loss computes topology-aware loss by computing the
171        soft skeleton of both the prediction and target and measuring overlap
172        of the two skeletons. This encourages the model to preserve the
173        connectivity and topology of tubular structures. The final loss is a
174        weighted combination of soft Dice and clDice, controlled by alpha.
175
176    Args:
177        num_iter: Number of iterations for soft-skeletonization.
178        alpha: The weight for combining the soft Dice and soft clDice loss.
179        eps: The epsilon value added to the denominator for numerical
180            stability.
181        exclude_background: Whether to exclude background channel 0 from the
182            loss computation. Useful for multi-class segmentation.
183        invert: Not implemented; Whether to invert the returned dice score to
184            obtain the dice error instead of the dice score.
185        channelwise: Not implemented; Whether to return the dice score
186            independently per channel.
187        reduce_chnanel: Not implemented; How to return the dice score over the
188            channel axis.
189
190    """
191    def __init__(self, num_iter: int = 5, alpha: float = 0.5, eps: float = 1e-7,
192                 exclude_background: bool = False):
193        super(CombinedclDiceLoss, self).__init__(num_iter=num_iter, eps=eps, exclude_background=exclude_background)
194
195        self.alpha = alpha
196        self.init_kwargs = {"num_iter": num_iter, "alpha": alpha, "eps": eps, "exclude_background": exclude_background}
197
198    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
199        """Compute combined soft Dice and clDice loss.
200
201        Args:
202            input_: The input logits.
203            target: The binary target.
204
205        Returns:
206            Combined clDice loss.
207        """
208        if input_.shape != target.shape:
209            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
210
211        if self.exclude_background:
212            target = target[:, 1:, :, :]
213            input_ = input_[:, 1:, :, :]
214        dice = dice_score(input_, target, invert=True, channelwise=False, eps=self.eps)
215        cldice = cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)
216
217        return (1.0-self.alpha)*dice+self.alpha*cldice

Combined soft-Dice and soft-clDice loss for segmentation of tubular structures.

The soft-clDice loss computes topology-aware loss by computing the
soft skeleton of both the prediction and target and measuring overlap
of the two skeletons. This encourages the model to preserve the
connectivity and topology of tubular structures. The final loss is a
weighted combination of soft Dice and clDice, controlled by alpha.
Arguments:
  • num_iter: Number of iterations for soft-skeletonization.
  • alpha: The weight for combining the soft Dice and soft clDice loss.
  • eps: The epsilon value added to the denominator for numerical stability.
  • exclude_background: Whether to exclude background channel 0 from the loss computation. Useful for multi-class segmentation.
  • invert: Not implemented; Whether to invert the returned dice score to obtain the dice error instead of the dice score.
  • channelwise: Not implemented; Whether to return the dice score independently per channel.
  • reduce_channel: Not implemented; How to return the dice score over the channel axis.
CombinedclDiceLoss( num_iter: int = 5, alpha: float = 0.5, eps: float = 1e-07, exclude_background: bool = False)
191    def __init__(self, num_iter: int = 5, alpha: float = 0.5, eps: float = 1e-7,
192                 exclude_background: bool = False):
193        super(CombinedclDiceLoss, self).__init__(num_iter=num_iter, eps=eps, exclude_background=exclude_background)
194
195        self.alpha = alpha
196        self.init_kwargs = {"num_iter": num_iter, "alpha": alpha, "eps": eps, "exclude_background": exclude_background}

Initialize internal Module state, shared by both nn.Module and ScriptModule.

alpha
init_kwargs
def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
198    def forward(self, input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
199        """Compute combined soft Dice and clDice loss.
200
201        Args:
202            input_: The input logits.
203            target: The binary target.
204
205        Returns:
206            Combined clDice loss.
207        """
208        if input_.shape != target.shape:
209            raise ValueError(f"Expect input and target of same shape, got: {input_.shape}, {target.shape}.")
210
211        if self.exclude_background:
212            target = target[:, 1:, :, :]
213            input_ = input_[:, 1:, :, :]
214        dice = dice_score(input_, target, invert=True, channelwise=False, eps=self.eps)
215        cldice = cldice_score(input_, target, num_iter=self.num_iter, invert=True, eps=self.eps)
216
217        return (1.0-self.alpha)*dice+self.alpha*cldice

Compute combined soft Dice and clDice loss.

Arguments:
  • input_: The input logits.
  • target: The binary target.
Returns:

Combined clDice loss.