torch_em.loss.distance_based

import torch.nn as nn

from .dice import DiceLoss


class DistanceLoss(nn.Module):
    """Loss for distance-based instance segmentation.

    Expects input and target with three channels: a foreground channel and two distance channels.
    Typically the distance channels are the centroid distance and the inverted boundary distance.

    Args:
        mask_distances_in_bg: Whether to mask the loss for the distance predictions in the background.
        foreground_loss: The loss for comparing the foreground predictions and target.
        distance_loss: The loss for comparing the distance predictions and target.
    """
    def __init__(
        self,
        mask_distances_in_bg: bool = True,
        foreground_loss: nn.Module = DiceLoss(),
        distance_loss: nn.Module = nn.MSELoss(reduction="mean")
    ) -> None:
        super().__init__()

        self.foreground_loss = foreground_loss
        self.distance_loss = distance_loss
        self.mask_distances_in_bg = mask_distances_in_bg

        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}

    def forward(self, input_, target):
        assert input_.shape == target.shape, input_.shape
        assert input_.shape[1] == 3, input_.shape

        # IMPORTANT: preserve the channel axis when slicing!
        # The Dice loss always interprets the first axis as the channel axis and sums over it
        # independently. Dropping the channel axis would therefore lead to a very large dice loss
        # that dominates over everything else.
        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
        fg_loss = self.foreground_loss(fg_input, fg_target)

        cdist_input, cdist_target = input_[:, 1:2], target[:, 1:2]
        if self.mask_distances_in_bg:
            mask = fg_target
            cdist_loss = self.distance_loss(cdist_input * mask, cdist_target * mask)
        else:
            cdist_loss = self.distance_loss(cdist_input, cdist_target)

        bdist_input, bdist_target = input_[:, 2:3], target[:, 2:3]
        if self.mask_distances_in_bg:
            mask = fg_target
            bdist_loss = self.distance_loss(bdist_input * mask, bdist_target * mask)
        else:
            bdist_loss = self.distance_loss(bdist_input, bdist_target)

        overall_loss = fg_loss + cdist_loss + bdist_loss
        return overall_loss


class DiceBasedDistanceLoss(DistanceLoss):
    """Similar to DistanceLoss, but uses the Dice loss for all loss terms.
    """
    def __init__(self, mask_distances_in_bg: bool) -> None:
        super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())
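A minimal usage sketch (assuming torch_em is installed; the tensor shapes and random values are placeholders, chosen only to match the expected layout of one foreground channel followed by two distance channels):

import torch
from torch_em.loss.distance_based import DistanceLoss

loss_fn = DistanceLoss(mask_distances_in_bg=True)

# Dummy prediction and target: batch of 4, three channels
# (foreground, centroid distance, inverted boundary distance), 128 x 128 pixels.
pred = torch.rand(4, 3, 128, 128, requires_grad=True)
target = torch.rand(4, 3, 128, 128)

loss = loss_fn(pred, target)  # scalar: foreground loss + centroid distance loss + boundary distance loss
loss.backward()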
class DistanceLoss(torch.nn.modules.module.Module):

Loss for distance-based instance segmentation.

Expects input and target with three channels: a foreground channel and two distance channels. Typically the distance channels are the centroid distance and the inverted boundary distance.

Arguments:
  • mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
  • foreground_loss: the loss for comparing foreground predictions and target.
  • distance_loss: the loss for comparing distance predictions and target.
DistanceLoss(mask_distances_in_bg: bool = True, foreground_loss: torch.nn.modules.module.Module = DiceLoss(), distance_loss: torch.nn.modules.module.Module = MSELoss())

Creates the loss from the given foreground loss, distance loss, and background-masking setting.

Instance attributes: foreground_loss, distance_loss, mask_distances_in_bg, init_kwargs.

def forward(self, input_, target):

Computes the combined loss for a three-channel prediction and target: the foreground loss on the first channel plus the distance losses on the centroid and boundary distance channels. If mask_distances_in_bg is set, both distance terms are computed on predictions and targets multiplied by the target foreground channel, so background pixels contribute no error to the distance losses.

As with any nn.Module, call the module instance rather than forward directly, so that the registered hooks are run.
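A small toy illustration of the masking effect (the tensors below are hypothetical and only serve to show that masked background pixels contribute no error):

import torch
import torch.nn as nn

# One sample, one distance channel, 2 x 2 pixels.
fg_target = torch.tensor([[[[1., 0.], [1., 0.]]]])         # foreground channel of the target
cdist_input = torch.tensor([[[[0.8, 0.9], [0.2, 0.5]]]])   # predicted centroid distances
cdist_target = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])  # target centroid distances

mse = nn.MSELoss(reduction="mean")
masked = mse(cdist_input * fg_target, cdist_target * fg_target)  # background column zeroed on both sides
unmasked = mse(cdist_input, cdist_target)
print(masked.item(), unmasked.item())  # ~0.02 vs. ~0.285: background errors only enter the unmasked loss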

class DiceBasedDistanceLoss(DistanceLoss):

Similar to DistanceLoss, but uses the Dice loss for all loss terms.

DiceBasedDistanceLoss(mask_distances_in_bg: bool)

Creates the loss with a Dice loss for the foreground term and for both distance terms.
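A brief usage sketch (again with placeholder tensors; note that mask_distances_in_bg has no default here and must be passed explicitly):

import torch
from torch_em.loss.distance_based import DiceBasedDistanceLoss

loss_fn = DiceBasedDistanceLoss(mask_distances_in_bg=True)

pred = torch.rand(2, 3, 64, 64, requires_grad=True)
target = torch.rand(2, 3, 64, 64)
loss = loss_fn(pred, target)  # all three terms are Dice losses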
