torch_em.loss.distance_based
from typing import Optional

import torch.nn as nn

from .dice import DiceLoss


class DistanceLoss(nn.Module):
    """Loss for distance based instance segmentation.

    Expects input and targets with three channels: foreground and two distance channels.
    Typically the distance channels are centroid and inverted boundary distance.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
        foreground_loss: the loss for comparing foreground predictions and target.
            If None, a new DiceLoss is created for this instance.
        distance_loss: the loss for comparing distance predictions and target.
            If None, a new nn.MSELoss(reduction="mean") is created for this instance.
    """
    def __init__(
        self,
        mask_distances_in_bg: bool = True,
        foreground_loss: Optional[nn.Module] = None,
        distance_loss: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        # Build the default losses here rather than in the signature: a module used as a
        # default argument is created once at definition time and shared by all instances.
        self.foreground_loss = DiceLoss() if foreground_loss is None else foreground_loss
        self.distance_loss = nn.MSELoss(reduction="mean") if distance_loss is None else distance_loss
        self.mask_distances_in_bg = mask_distances_in_bg

        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}

    def forward(self, input_, target):
        """Compute the foreground loss plus the two distance losses.

        Args:
            input_: the prediction, with the same shape as target and 3 channels in axis 1.
            target: the target; channel 0 is the foreground, channels 1 and 2 are the distances.

        Returns:
            The sum of the foreground loss (channel 0) and the distance losses (channels 1 and 2).
        """
        assert input_.shape == target.shape, input_.shape
        assert input_.shape[1] == 3, input_.shape

        # IMPORTANT: preserve the channels!
        # Otherwise the Dice Loss will do all kinds of shenanigans.
        # Because it always interprets the first axis as channel,
        # and treats it differently (sums over it independently).
        # This will lead to a very large dice loss that dominates over everything else.
        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
        overall_loss = self.foreground_loss(fg_input, fg_target)

        # Channels 1 and 2 hold the two distance maps (typically centroid and inverted boundary distance).
        for channel in (1, 2):
            dist_input = input_[:, channel:channel + 1]
            dist_target = target[:, channel:channel + 1]
            if self.mask_distances_in_bg:
                # Multiply prediction and target by the foreground target so both are zero
                # wherever the foreground target is zero.
                dist_input, dist_target = dist_input * fg_target, dist_target * fg_target
            overall_loss = overall_loss + self.distance_loss(dist_input, dist_target)
        return overall_loss


class DiceBasedDistanceLoss(DistanceLoss):
    """Same as DistanceLoss, but uses the Dice loss for both the foreground and the distance terms.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
    """
    def __init__(self, mask_distances_in_bg: bool = True) -> None:
        super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())
class
DistanceLoss(torch.nn.modules.module.Module):
class DistanceLoss(nn.Module):
    """Loss for distance based instance segmentation.

    Expects input and targets with three channels: foreground and two distance channels.
    Typically the distance channels are centroid and inverted boundary distance.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
        foreground_loss: the loss for comparing foreground predictions and target.
            If None, a new DiceLoss is created for this instance.
        distance_loss: the loss for comparing distance predictions and target.
            If None, a new nn.MSELoss(reduction="mean") is created for this instance.
    """
    def __init__(
        self,
        mask_distances_in_bg: bool = True,
        foreground_loss: Optional[nn.Module] = None,
        distance_loss: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        # Build the default losses here rather than in the signature: a module used as a
        # default argument is created once at definition time and shared by all instances.
        self.foreground_loss = DiceLoss() if foreground_loss is None else foreground_loss
        self.distance_loss = nn.MSELoss(reduction="mean") if distance_loss is None else distance_loss
        self.mask_distances_in_bg = mask_distances_in_bg

        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}

    def forward(self, input_, target):
        """Compute the foreground loss plus the two distance losses.

        Args:
            input_: the prediction, with the same shape as target and 3 channels in axis 1.
            target: the target; channel 0 is the foreground, channels 1 and 2 are the distances.

        Returns:
            The sum of the foreground loss (channel 0) and the distance losses (channels 1 and 2).
        """
        assert input_.shape == target.shape, input_.shape
        assert input_.shape[1] == 3, input_.shape

        # IMPORTANT: preserve the channels!
        # Otherwise the Dice Loss will do all kinds of shenanigans.
        # Because it always interprets the first axis as channel,
        # and treats it differently (sums over it independently).
        # This will lead to a very large dice loss that dominates over everything else.
        fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
        overall_loss = self.foreground_loss(fg_input, fg_target)

        # Channels 1 and 2 hold the two distance maps (typically centroid and inverted boundary distance).
        for channel in (1, 2):
            dist_input = input_[:, channel:channel + 1]
            dist_target = target[:, channel:channel + 1]
            if self.mask_distances_in_bg:
                # Multiply prediction and target by the foreground target so both are zero
                # wherever the foreground target is zero.
                dist_input, dist_target = dist_input * fg_target, dist_target * fg_target
            overall_loss = overall_loss + self.distance_loss(dist_input, dist_target)
        return overall_loss
Loss for distance based instance segmentation.
Expects input and targets with three channels: foreground and two distance channels. Typically the distance channels are centroid and inverted boundary distance.
Arguments:
- mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
- foreground_loss: the loss for comparing foreground predictions and target.
- distance_loss: the loss for comparing distance predictions and target.
DistanceLoss( mask_distances_in_bg: bool = True, foreground_loss: torch.nn.modules.module.Module = DiceLoss(), distance_loss: torch.nn.modules.module.Module = MSELoss())
def __init__(
    self,
    mask_distances_in_bg: bool = True,
    foreground_loss: Optional[nn.Module] = None,
    distance_loss: Optional[nn.Module] = None,
) -> None:
    """Set up the foreground and distance losses.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
        foreground_loss: the foreground loss; if None, a new DiceLoss is created for this instance.
        distance_loss: the distance loss; if None, a new nn.MSELoss(reduction="mean") is created.
    """
    super().__init__()

    # Build the default losses here rather than in the signature: a module used as a
    # default argument is created once at definition time and shared by all instances.
    self.foreground_loss = DiceLoss() if foreground_loss is None else foreground_loss
    self.distance_loss = nn.MSELoss(reduction="mean") if distance_loss is None else distance_loss
    self.mask_distances_in_bg = mask_distances_in_bg

    self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}
Stores the foreground loss, the distance loss and the `mask_distances_in_bg` flag used by `forward`, and records `mask_distances_in_bg` in `init_kwargs`.
def
forward(self, input_, target):
def forward(self, input_, target):
    """Compute the foreground loss plus the two distance losses.

    Args:
        input_: the prediction, with the same shape as target and 3 channels in axis 1.
        target: the target; channel 0 is the foreground, channels 1 and 2 are the distances.

    Returns:
        The sum of the foreground loss (channel 0) and the distance losses (channels 1 and 2).
    """
    assert input_.shape == target.shape, input_.shape
    assert input_.shape[1] == 3, input_.shape

    # IMPORTANT: preserve the channels!
    # Otherwise the Dice Loss will do all kinds of shenanigans.
    # Because it always interprets the first axis as channel,
    # and treats it differently (sums over it independently).
    # This will lead to a very large dice loss that dominates over everything else.
    fg_input, fg_target = input_[:, 0:1], target[:, 0:1]
    overall_loss = self.foreground_loss(fg_input, fg_target)

    # Channels 1 and 2 hold the two distance maps (typically centroid and inverted boundary distance).
    for channel in (1, 2):
        dist_input = input_[:, channel:channel + 1]
        dist_target = target[:, channel:channel + 1]
        if self.mask_distances_in_bg:
            # Multiply prediction and target by the foreground target so both are zero
            # wherever the foreground target is zero.
            dist_input, dist_target = dist_input * fg_target, dist_target * fg_target
        overall_loss = overall_loss + self.distance_loss(dist_input, dist_target)
    return overall_loss
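Computes the sum of the foreground loss on channel 0 and the distance loss on channels 1 and 2.
If `mask_distances_in_bg` is set, the distance predictions and targets are first multiplied by the foreground target.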
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class DiceBasedDistanceLoss(DistanceLoss):
    """Same as DistanceLoss, but uses the Dice loss for both the foreground and the distance terms.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
    """
    def __init__(self, mask_distances_in_bg: bool = True) -> None:
        # Fresh DiceLoss instances per object, so nothing is shared between losses.
        super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())
Same as DistanceLoss, but uses the Dice loss for both the foreground and the distance losses.
DiceBasedDistanceLoss(mask_distances_in_bg: bool)
def __init__(self, mask_distances_in_bg: bool = True) -> None:
    """Set up a DistanceLoss that uses DiceLoss for the foreground and the distance terms.

    Args:
        mask_distances_in_bg: whether to mask the loss for distance predictions in the background.
    """
    super().__init__(mask_distances_in_bg, foreground_loss=DiceLoss(), distance_loss=DiceLoss())
Initializes a DistanceLoss that uses a DiceLoss for both the foreground and the distance losses.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile