torch_em.loss.wrapper
import torch
import torch.nn as nn


class LossWrapper(nn.Module):
    """Wrapper around a torch loss function.

    Applies transformations to prediction and/or target before passing them to the loss.
    """
    def __init__(self, loss, transform):
        super().__init__()
        self.loss = loss

        if not callable(transform):
            raise ValueError("transform has to be callable.")
        self.transform = transform
        self.init_kwargs = {"loss": loss, "transform": transform}

    def apply_transform(self, prediction, target, **kwargs):
        # Check if prediction and target are lists or tuples.
        # If they are, apply the transform to each element individually.
        if isinstance(prediction, (list, tuple)):
            assert isinstance(target, (list, tuple))
            transformed_prediction, transformed_target = [], []
            for pred, targ in zip(prediction, target):
                tr_pred, tr_targ = self.transform(pred, targ, **kwargs)
                transformed_prediction.append(tr_pred)
                transformed_target.append(tr_targ)
            return transformed_prediction, transformed_target
        # Tensor input: apply the transform directly.
        else:
            prediction, target = self.transform(prediction, target, **kwargs)
            return prediction, target

    def forward(self, prediction, target, **kwargs):
        prediction, target = self.apply_transform(prediction, target, **kwargs)
        loss = self.loss(prediction, target)
        return loss


#
# Loss transformations
#


class ApplyMask:
    # _crop and _multiply are plain functions (no self); they are looked up
    # from MASKING_FUNCS and called with explicit arguments.
    def _crop(prediction, target, mask, channel_dim):
        if mask.shape[channel_dim] != 1:
            raise ValueError(
                "_crop only supports a mask with a singleton channel axis. "
                "Please consider using masking_method=multiply."
            )
        mask = mask.type(torch.bool)
        # Remove the singleton channel axis from the mask.
        mask = mask.squeeze(channel_dim)
        # Move the channel axis of prediction and target to the end.
        prediction = prediction.moveaxis(channel_dim, -1)
        target = target.moveaxis(channel_dim, -1)
        # The output has shape N x C, which is correct for torch_em.loss.dice.flatten_samples.
        return prediction[mask], target[mask]

    def _multiply(prediction, target, mask, channel_dim):
        prediction = prediction * mask
        target = target * mask
        return prediction, target

    MASKING_FUNCS = {
        "crop": _crop,
        "multiply": _multiply,
    }

    def __init__(self, masking_method="crop", channel_dim=1):
        if masking_method not in self.MASKING_FUNCS:
            raise ValueError(f"{masking_method} is not available, please use one of {list(self.MASKING_FUNCS)}.")
        self.masking_func = self.MASKING_FUNCS[masking_method]
        self.channel_dim = channel_dim

        self.init_kwargs = {
            "masking_method": masking_method,
            "channel_dim": channel_dim,
        }

    def __call__(self, prediction, target, mask):
        mask.requires_grad = False
        return self.masking_func(prediction, target, mask, self.channel_dim)


class ApplyAndRemoveMask(ApplyMask):
    def __call__(self, prediction, target):
        assert target.dim() == prediction.dim(), f"{target.dim()}, {prediction.dim()}"
        assert target.size(1) == 2 * prediction.size(1), f"{target.size(1)}, {prediction.size(1)}"
        assert target.shape[2:] == prediction.shape[2:], f"{str(target.shape)}, {str(prediction.shape)}"
        # The second half of the target channels contains the mask.
        separating_channel = target.size(1) // 2
        mask = target[:, separating_channel:]
        target = target[:, :separating_channel]
        prediction, target = super().__call__(prediction, target, mask)
        return prediction, target


class MaskIgnoreLabel(ApplyMask):
    def __init__(self, ignore_label=-1, masking_method="crop", channel_dim=1):
        super().__init__(masking_method, channel_dim)
        self.ignore_label = ignore_label
        self.init_kwargs["ignore_label"] = ignore_label

    def __call__(self, prediction, target):
        # Mask out all entries that are equal to the ignore label.
        mask = (target != self.ignore_label)
        prediction, target = super().__call__(prediction, target, mask)
        return prediction, target
class LossWrapper(torch.nn.modules.module.Module):
Wrapper around a torch loss function.

Applies transformations to prediction and/or target before passing them to the loss.
LossWrapper(loss, transform)
Creates the loss wrapper from a torch loss function and a transform. The transform must be callable (otherwise a ValueError is raised); it is applied to prediction and target before the loss is computed.
def apply_transform(self, prediction, target, **kwargs):
Applies the transform to prediction and target. If they are passed as lists or tuples, the transform is applied to each pair of elements individually.
def forward(self, prediction, target, **kwargs):
Applies the transform to prediction and target and then computes the loss on the transformed tensors.
Defines the computation performed at every call. Should be overridden by all subclasses. Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
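For example, LossWrapper can combine any torch loss with a custom transform. A minimal sketch, where the normalize transform is a hypothetical callable with the expected (prediction, target) signature:

import torch
import torch.nn as nn

from torch_em.loss.wrapper import LossWrapper


# Hypothetical transform: any callable taking (prediction, target)
# and returning the transformed pair works.
def normalize(prediction, target):
    return torch.sigmoid(prediction), target.float()


loss_function = LossWrapper(loss=nn.MSELoss(), transform=normalize)

prediction = torch.randn(4, 1, 32, 32)
target = torch.randint(0, 2, (4, 1, 32, 32))
loss = loss_function(prediction, target)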
class ApplyMask:

Transformation that applies a mask to prediction and target before the loss is computed: 'crop' selects the unmasked entries (and requires a mask with a singleton channel axis), while 'multiply' multiplies prediction and target with the mask.
ApplyMask(masking_method='crop', channel_dim=1)
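A short sketch of the two masking methods, with illustrative shapes: 'crop' flattens the unmasked entries to shape (num_unmasked, channels), while 'multiply' zeroes the masked entries and preserves the input shape:

import torch

from torch_em.loss.wrapper import ApplyMask

prediction = torch.randn(2, 3, 8, 8)
target = torch.randn(2, 3, 8, 8)
mask = torch.randint(0, 2, (2, 1, 8, 8))  # singleton channel axis

# 'crop' selects the unmasked entries; the result has shape (num_unmasked, 3).
pred_crop, targ_crop = ApplyMask(masking_method="crop")(prediction, target, mask)

# 'multiply' zeroes the masked entries; the shape (2, 3, 8, 8) is preserved.
pred_mult, targ_mult = ApplyMask(masking_method="multiply")(prediction, target, mask)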
class ApplyAndRemoveMask(ApplyMask):

Expects the mask to be stacked along the channel axis of the target: the first half of the target channels holds the actual target, the second half holds the mask. Splits the two apart and then applies the mask as in ApplyMask.
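A sketch of the expected target layout, assuming a single-channel prediction so that the default 'crop' masking applies:

import torch

from torch_em.loss.wrapper import ApplyAndRemoveMask

prediction = torch.randn(2, 1, 8, 8)
target_values = torch.randn(2, 1, 8, 8)
mask = torch.randint(0, 2, (2, 1, 8, 8)).float()
# The target carries the mask in its second channel half: shape (2, 2, 8, 8).
target = torch.cat([target_values, mask], dim=1)

pred_masked, targ_masked = ApplyAndRemoveMask()(prediction, target)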
class MaskIgnoreLabel(ApplyMask):

MaskIgnoreLabel(ignore_label=-1, masking_method='crop', channel_dim=1)

Masks out all entries of the target that are equal to the ignore label before the loss is computed.
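Putting the pieces together, a minimal sketch that masks an ignore label inside a wrapped loss; the shapes and the choice of MSELoss are illustrative:

import torch
import torch.nn as nn

from torch_em.loss.wrapper import LossWrapper, MaskIgnoreLabel

loss_function = LossWrapper(loss=nn.MSELoss(), transform=MaskIgnoreLabel(ignore_label=-1))

prediction = torch.randn(4, 1, 32, 32)
target = torch.randint(-1, 2, (4, 1, 32, 32)).float()  # -1 marks pixels to ignore
loss = loss_function(prediction, target)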