torch_em.loss.wrapper
from typing import Callable, Sequence, Tuple, Union

import torch
import torch.nn as nn


class LossWrapper(nn.Module):
    """A wrapper around a torch loss function.

    Applies transformations to prediction and/or target before passing it to the loss.

    Args:
        loss: The loss function.
        transform: The transformation applied to prediction and/or target.
            Must take both the prediction and target as arguments and then return them, potentially transformed.
    """
    def __init__(self, loss: nn.Module, transform: Callable):
        super().__init__()
        self.loss = loss

        if not callable(transform):
            raise ValueError("transform has to be callable.")
        self.transform = transform
        self.init_kwargs = {'loss': loss, 'transform': transform}

    def apply_transform(self, prediction, target, **kwargs):
        """@private
        """
        # Check if the prediction and target are lists.
        # If they are, apply the transform to each element individually.
        if isinstance(prediction, (list, tuple)):
            assert isinstance(target, (list, tuple))
            transformed_prediction, transformed_target = [], []
            for pred, targ in zip(prediction, target):
                tr_pred, tr_targ = self.transform(pred, targ, **kwargs)
                transformed_prediction.append(tr_pred)
                transformed_target.append(tr_targ)
            return transformed_prediction, transformed_target
        # Otherwise, we expect that prediction and target are both tensors.
        else:
            prediction, target = self.transform(prediction, target, **kwargs)
            return prediction, target

    def forward(
        self,
        prediction: Union[Sequence[torch.Tensor], torch.Tensor],
        target: Union[Sequence[torch.Tensor], torch.Tensor],
        **kwargs
    ) -> torch.Tensor:
        """Apply the transformations to prediction and/or target before computing the loss.

        Args:
            prediction: The prediction.
            target: The target.
            kwargs: Additional keyword arguments for the transformation.

        Returns:
            The loss.
        """
        prediction, target = self.apply_transform(prediction, target, **kwargs)
        loss = self.loss(prediction, target)
        return loss


#
# Loss transformations
#
def _crop(prediction, target, mask, channel_dim):
    if mask.shape[channel_dim] != 1:
        raise ValueError(
            "_crop only supports a mask with a singleton channel axis. Please consider using masking_method=multiply."
        )
    mask = mask.type(torch.bool)
    # remove singleton axis
    mask = mask.squeeze(channel_dim)
    # move channel axis to end
    prediction = prediction.moveaxis(channel_dim, -1)
    target = target.moveaxis(channel_dim, -1)
    # output has shape N x C
    # correct for torch_em.loss.dice.flatten_samples
    return prediction[mask], target[mask]


def _multiply(prediction, target, mask, channel_dim):
    prediction = prediction * mask
    target = target * mask
    return prediction, target


class ApplyMask:
    """Apply a mask to prediction and target, so that only values in the mask are taken into account for the loss.

    Supports two different masking methods:
    - 'crop': Crop away the mask from the prediction and target. This only works if the mask just has a single channel,
      and if the loss function does not require spatial inputs.
    - 'multiply': Multiply the prediction and target with zeros outside of the mask.

    Args:
        masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
        channel_dim: The dimension of the channel axis.
    """
    MASKING_FUNCS = {"crop": _crop, "multiply": _multiply}

    def __init__(self, masking_method: str = "crop", channel_dim: int = 1):
        if masking_method not in self.MASKING_FUNCS.keys():
            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS.keys())}.")
        self.masking_func = self.MASKING_FUNCS[masking_method]
        self.channel_dim = channel_dim
        self.init_kwargs = {"masking_method": masking_method, "channel_dim": channel_dim}

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mask predictions.

        Args:
            prediction: The prediction tensor.
            target: The target tensor.
            mask: The mask tensor.

        Returns:
            The masked prediction.
            The masked target.
        """
        mask.requires_grad = False
        return self.masking_func(prediction, target, mask, self.channel_dim)


class ApplyAndRemoveMask(ApplyMask):
    """Extract the mask from extra channels of the target tensor and use it to mask the prediction.

    Supports the same masking methods as `ApplyMask`.
    """
    def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Remove the masking channels from the target and then apply the mask.

        Args:
            prediction: The prediction tensor.
            target: The target tensor, with extra channels that contain the mask.

        Returns:
            The masked prediction.
            The masked target, with the extra channels removed.
        """
        assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
        assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
        assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"
        seperating_channel = target.size(1) // 2
        mask = target[:, seperating_channel:]
        target = target[:, :seperating_channel]
        prediction, target = super().__call__(prediction, target, mask)
        return prediction, target


class MaskIgnoreLabel(ApplyMask):
    """Mask the ignore label in the target.

    Supports the same masking methods as `ApplyMask`.

    Args:
        ignore_label: The ignore label, which will be masked.
        masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
        channel_dim: The dimension of the channel axis.
    """
    def __init__(self, ignore_label: int = -1, masking_method: str = "crop", channel_dim: int = 1):
        super().__init__(masking_method, channel_dim)
        self.ignore_label = ignore_label
        self.init_kwargs["ignore_label"] = ignore_label

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mask the ignore label in the prediction and target.

        Args:
            prediction: The prediction tensor.
            target: The target tensor.

        Returns:
            The masked prediction.
            The masked target.
        """
        mask = (target != self.ignore_label)
        prediction, target = super().__call__(prediction, target, mask)
        return prediction, target
class LossWrapper(torch.nn.modules.module.Module):
A wrapper around a torch loss function.
Applies transformations to prediction and/or target before passing it to the loss.
Arguments:
- loss: The loss function.
- transform: The transformation applied to prediction and/or target. Must take both the prediction and target as arguments and then return them, potentially transformed.
LossWrapper(loss: torch.nn.modules.module.Module, transform: Callable)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
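For illustration, a minimal sketch of wrapping a standard torch loss with a custom transform; the sigmoid_prediction function below is hypothetical and only demonstrates the expected transform signature (take prediction and target, return both):

import torch
import torch.nn as nn

from torch_em.loss.wrapper import LossWrapper


# Hypothetical transform: applies a sigmoid to the prediction and leaves the target unchanged.
def sigmoid_prediction(prediction, target):
    return torch.sigmoid(prediction), target


wrapped_loss = LossWrapper(loss=nn.MSELoss(), transform=sigmoid_prediction)

prediction = torch.randn(4, 1, 64, 64)    # raw network output, N x C x H x W
target = torch.rand(4, 1, 64, 64)
value = wrapped_loss(prediction, target)  # the transform runs before nn.MSELoss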
def forward(self, prediction: Union[Sequence[torch.Tensor], torch.Tensor], target: Union[Sequence[torch.Tensor], torch.Tensor], **kwargs) -> torch.Tensor:
Apply the transformations to prediction and/or target before computing the loss.
Arguments:
- prediction: The prediction.
- target: The target.
- kwargs: Additional keyword arguments for the transformation.
Returns:
The loss.
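Keyword arguments passed to the forward call are forwarded to the transform. A small sketch, again with a hypothetical transform that accepts an extra keyword argument:

import torch
import torch.nn as nn

from torch_em.loss.wrapper import LossWrapper


# Hypothetical transform with an additional keyword argument.
def scale_target(prediction, target, scale=1.0):
    return prediction, target * scale


wrapped_loss = LossWrapper(loss=nn.L1Loss(), transform=scale_target)
prediction, target = torch.rand(2, 1, 32, 32), torch.rand(2, 1, 32, 32)
value = wrapped_loss(prediction, target, scale=0.5)  # 'scale' is passed through to scale_target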
class ApplyMask:
Apply a mask to prediction and target, so that only values in the mask are taken into account for the loss.
Supports two different masking methods:
- 'crop': Crop away the mask from the prediction and target. This only works if the mask just has a single channel, and if the loss function does not require spatial inputs.
- 'multiply': Multiply the prediction and target with zeros outside of the mask.
Arguments:
- masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
- channel_dim: The dimension of the channel axis.
ApplyMask(masking_method: str = 'crop', channel_dim: int = 1)
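A short sketch of applying the mask directly (the tensors here are made up for illustration); with 'crop' the mask needs a singleton channel axis:

import torch

from torch_em.loss.wrapper import ApplyMask

apply_mask = ApplyMask(masking_method="crop", channel_dim=1)

prediction = torch.rand(2, 3, 32, 32)        # N x C x H x W
target = torch.rand(2, 3, 32, 32)
mask = torch.rand(2, 1, 32, 32) > 0.5        # singleton channel axis, as required by 'crop'

masked_prediction, masked_target = apply_mask(prediction, target, mask)
print(masked_prediction.shape)               # (number of masked values) x C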
class ApplyAndRemoveMask(ApplyMask):
Extract the mask from extra channels of the target tensor and use it to mask the prediction.

Supports the same masking methods as ApplyMask.
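A sketch of the expected target layout, assuming a target whose first half of channels holds the labels and whose second half holds the mask; 'multiply' is used here because the extracted mask has as many channels as the prediction:

import torch
import torch.nn as nn

from torch_em.loss.wrapper import ApplyAndRemoveMask, LossWrapper

wrapped_loss = LossWrapper(
    loss=nn.BCEWithLogitsLoss(),
    transform=ApplyAndRemoveMask(masking_method="multiply"),
)

prediction = torch.randn(2, 2, 16, 16)              # N x C x H x W
labels = (torch.rand(2, 2, 16, 16) > 0.5).float()   # the actual target, C channels
mask = (torch.rand(2, 2, 16, 16) > 0.25).float()    # 1 inside the valid region, C channels
target = torch.cat([labels, mask], dim=1)           # N x 2C x H x W
value = wrapped_loss(prediction, target)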
class MaskIgnoreLabel(ApplyMask):
Mask the ignore label in the target.

Supports the same masking methods as ApplyMask.
Arguments:
- ignore_label: The ignore label, which will be masked.
- masking_method: The masking method to use. Can be one of 'crop' or 'multiply'.
- channel_dim: The dimension of the channel axis.
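A minimal sketch of excluding an ignore label from the loss; with the default 'crop' method the target (and hence the mask) must have a single channel:

import torch
import torch.nn as nn

from torch_em.loss.wrapper import LossWrapper, MaskIgnoreLabel

wrapped_loss = LossWrapper(
    loss=nn.MSELoss(),
    transform=MaskIgnoreLabel(ignore_label=-1, masking_method="crop"),
)

prediction = torch.rand(4, 1, 64, 64)
target = torch.rand(4, 1, 64, 64)
target[:, :, :16, :] = -1                 # this region is excluded from the loss
value = wrapped_loss(prediction, target)  # computed only over pixels where target != -1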