torch_em.loss.combined_loss
from typing import List, Optional

import torch


class CombinedLoss(torch.nn.Module):
    """Combination of multiple losses.

    The combined value is the weighted sum ``sum_i loss_weights[i] * losses[i](x, y)``.

    Args:
        losses: The loss functions to combine.
        loss_weights: The weights for the loss functions, one per loss.
            If not given, every loss is weighted by ``1 / len(losses)``.

    Raises:
        ValueError: If ``loss_weights`` does not contain exactly one weight per loss.
    """
    def __init__(self, *losses: torch.nn.Module, loss_weights: Optional[List[float]] = None):
        super().__init__()
        # ModuleList registers the losses as submodules (parameters, device moves, state_dict).
        self.losses = torch.nn.ModuleList(losses)
        n_losses = len(self.losses)
        if loss_weights is None:
            # Uniform weights, i.e. the mean of the losses. Without any losses there is
            # nothing to weight; the weights stay unset and forward refuses to run.
            self.loss_weights = [1.0 / n_losses] * n_losses if n_losses > 0 else None
        else:
            # Explicit check instead of `assert`: asserts are stripped under `python -O`,
            # and a length mismatch would then be silently truncated by `zip` in forward.
            if len(loss_weights) != n_losses:
                raise ValueError(
                    f"Expected {n_losses} loss weights (one per loss), got {len(loss_weights)}."
                )
            self.loss_weights = loss_weights

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the combined loss.

        Args:
            x: The prediction.
            y: The target.

        Returns:
            The weighted sum of the individual loss values.

        Raises:
            RuntimeError: If the loss was constructed without any loss functions.
        """
        if self.loss_weights is None:
            raise RuntimeError("CombinedLoss was created without any losses, so there is nothing to compute.")
        return sum(loss(x, y) * weight for loss, weight in zip(self.losses, self.loss_weights))
class
CombinedLoss(torch.nn.modules.module.Module):
class CombinedLoss(torch.nn.Module):
    """Combination of multiple losses.

    The combined value is the weighted sum ``sum_i loss_weights[i] * losses[i](x, y)``.

    Args:
        losses: The loss functions to combine.
        loss_weights: The weights for the loss functions, one per loss (may be None).
            If not given, every loss is weighted by ``1 / len(losses)``.

    Raises:
        ValueError: If ``loss_weights`` does not contain exactly one weight per loss.
    """
    def __init__(self, *losses: torch.nn.Module, loss_weights: List[float] = None):
        super().__init__()
        # ModuleList registers the losses as submodules (parameters, device moves, state_dict).
        self.losses = torch.nn.ModuleList(losses)
        n_losses = len(self.losses)
        if loss_weights is None:
            # Uniform weights, i.e. the mean of the losses. Without any losses there is
            # nothing to weight; the weights stay unset and forward refuses to run.
            self.loss_weights = [1.0 / n_losses] * n_losses if n_losses > 0 else None
        else:
            # Explicit check instead of `assert`: asserts are stripped under `python -O`,
            # and a length mismatch would then be silently truncated by `zip` in forward.
            if len(loss_weights) != n_losses:
                raise ValueError(
                    f"Expected {n_losses} loss weights (one per loss), got {len(loss_weights)}."
                )
            self.loss_weights = loss_weights

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the combined loss.

        Args:
            x: The prediction.
            y: The target.

        Returns:
            The weighted sum of the individual loss values.

        Raises:
            RuntimeError: If the loss was constructed without any loss functions.
        """
        if self.loss_weights is None:
            raise RuntimeError("CombinedLoss was created without any losses, so there is nothing to compute.")
        return sum(loss(x, y) * weight for loss, weight in zip(self.losses, self.loss_weights))
Combination of multiple losses.
Arguments:
- losses: The loss functions to combine.
- loss_weights: The weights for the loss functions, one per loss. If not given, each loss is weighted by 1 / len(losses).
CombinedLoss( *losses: torch.nn.modules.module.Module, loss_weights: List[float] = None)
def __init__(self, *losses: torch.nn.Module, loss_weights: List[float] = None):
    """Register the losses as submodules and set their weights.

    Args:
        losses: The loss functions to combine.
        loss_weights: The weights for the loss functions, one per loss (may be None).
            If not given, every loss is weighted by ``1 / len(losses)``.

    Raises:
        ValueError: If ``loss_weights`` does not contain exactly one weight per loss.
    """
    super().__init__()
    # ModuleList registers the losses as submodules (parameters, device moves, state_dict).
    self.losses = torch.nn.ModuleList(losses)
    n_losses = len(self.losses)
    if loss_weights is None:
        # Uniform weights, i.e. the mean of the losses. Without any losses there is
        # nothing to weight; the weights stay unset and forward refuses to run.
        self.loss_weights = [1.0 / n_losses] * n_losses if n_losses > 0 else None
    else:
        # Explicit check instead of `assert`: asserts are stripped under `python -O`,
        # and a length mismatch would then be silently truncated by `zip` in forward.
        if len(loss_weights) != n_losses:
            raise ValueError(
                f"Expected {n_losses} loss weights (one per loss), got {len(loss_weights)}."
            )
        self.loss_weights = loss_weights
Register the given losses as submodules and set their weights; if `loss_weights` is not given, each loss is weighted by 1 / len(losses).
def
forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the combined loss.

    Args:
        x: The prediction.
        y: The target.

    Returns:
        The weighted sum of the individual loss values.

    Raises:
        RuntimeError: If the loss was constructed without any loss functions.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if self.loss_weights is None:
        raise RuntimeError("CombinedLoss was created without any losses, so there is nothing to compute.")
    return sum(loss(x, y) * weight for loss, weight in zip(self.losses, self.loss_weights))
Compute the combined loss.
Arguments:
- x: The prediction.
- y: The target.
Returns:
The weighted sum of the individual loss values.