torch_em.loss.affinity_side_loss
import numpy as np
import torch
import torch.nn as nn
from .dice import dice_score


def shift_tensor(tensor, offset):
    """Shift a tensor by the given (spatial) offset, replicating border values.

    Arguments:
        tensor [torch.Tensor] - 4D (=2 spatial dims) or 5D (=3 spatial dims) tensor.
            Needs to be of float type.
        offset (tuple) - 2d or 3d spatial offset used for shifting the tensor.
            A positive entry moves the content towards higher indices along its axis,
            a negative entry towards lower indices.
    Returns:
        torch.Tensor - the shifted tensor; it has the same shape as `tensor`.
    """
    n_spatial = len(offset)
    assert n_spatial in (2, 3)

    # torch padding modules expect the padding for the LAST axis first (reverse spatial
    # order), as (before, after) pairs. A positive offset pads before, a negative one after.
    pad_widths = tuple(
        width
        for off in reversed(offset)
        for width in (max(0, off), max(0, -off))
    )

    # crop back to the original size, in normal spatial order
    def _crop(off):
        if off > 0:
            return slice(None, -off)
        if off < 0:
            return slice(-off, None)
        return slice(None)

    # leading (batch / channel) axes are neither padded nor cropped
    leading = [slice(None)] * (tensor.dim() - n_spatial)
    crop = tuple(leading + [_crop(off) for off in offset])

    # ReplicationPad is used because, per the original author, ReflectionPad lacked a 3d
    # variant and torch.nn.functional.pad was awkward to use here
    pad_cls = nn.ReplicationPad2d if n_spatial == 2 else nn.ReplicationPad3d
    shifted = pad_cls(pad_widths)(tensor)[crop]
    assert shifted.shape == tensor.shape
    return shifted


def invert_offsets(offsets):
    """Return a copy of `offsets` with every component negated."""
    return [[-off for off in offset] for offset in offsets]


def _gather_at_offset(tensor, offset):
    """Return `out` with `out[..., i, j, ...] = tensor[..., i + o0, j + o1, ...]` for `offset = (o0, o1, ...)`.

    The offset applies to the trailing `len(offset)` axes. Indices that fall outside the
    tensor are clamped to the border, so border values are replicated. Unlike the
    padding-based `shift_tensor`, this works for any dtype, including integer tensors.
    """
    first_axis = tensor.dim() - len(offset)
    for axis, off in enumerate(offset, start=first_axis):
        off = int(off)
        if off == 0:
            continue
        size = tensor.shape[axis]
        index = (torch.arange(size, device=tensor.device) + off).clamp(0, size - 1)
        tensor = tensor.index_select(axis, index)
    return tensor


def segmentation_to_affinities(segmentation, offsets):
    """Transform segmentation to affinities.

    The affinity for an offset is 1 where a pixel and the pixel `offset` away (clamped to
    the image border) carry the same segment id, and 0 otherwise.

    Arguments:
        segmentation [torch.Tensor] - 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor.
            The channel axis (= dimension 1) needs to be a singleton. May be an integer or float tensor.
        offsets [list[tuple]] - list of offsets for which to compute the affinities; each offset
            needs one entry per spatial dimension.
    Returns:
        torch.Tensor - affinities with one channel per offset (axis 1), containing 0. and 1.
            The dtype is float32, or float64 for a float64 segmentation.
    """
    assert segmentation.shape[1] == 1, f"{segmentation.shape}"
    offsets = list(offsets)
    n_spatial = segmentation.dim() - 2
    assert all(len(off) == n_spatial for off in offsets), f"{segmentation.shape}, {offsets}"
    # Compare the segment ids in their own dtype. Casting them to float32 first (as the
    # padding-based shift requires) would merge distinct ids above 2**24, which float32
    # cannot represent exactly.
    neighbours = torch.cat([_gather_at_offset(segmentation, off) for off in offsets], dim=1)
    # same output dtype as subtracting a float32 tensor from `segmentation` would give
    out_dtype = torch.promote_types(segmentation.dtype, torch.float32)
    return (segmentation == neighbours).to(out_dtype)


def embeddings_to_affinities(embeddings, offsets, delta):
    """Transform embeddings to affinities.

    For every offset, each embedding vector is compared with the embedding vector `offset`
    away (border values are replicated). Their euclidean distance d is mapped to the affinity
    ``max(0, (2 * delta - d) / (2 * delta)) ** 2``. Identical embeddings give 1, and
    embeddings at least ``2 * delta`` apart give 0.

    Arguments:
        embeddings [torch.Tensor] - float embedding tensor; the distance is taken over axis 1.
            Presumably (batch, embedding_dim, *spatial) - TODO confirm.
        offsets [list[tuple]] - spatial offsets for which to compute the affinities.
        delta [float] - distance margin; affinities reach zero at a distance of 2 * delta.
    Returns:
        torch.Tensor - affinities with one channel per offset, stacked along axis 1.
    """
    # shifting by the inverted offset aligns, at every position, the embedding that lies
    # `offset` away with the embedding at that position
    neighbours = torch.stack(
        [shift_tensor(embeddings, inv_off) for inv_off in invert_offsets(offsets)], dim=1
    )
    # euclidean distance over the embedding axis (axis 1 of `embeddings`, axis 2 after stacking)
    distance = torch.norm(embeddings.unsqueeze(1) - neighbours, dim=2)
    margin = 2 * delta
    affinities = torch.clamp((margin - distance) / margin, min=0)
    return affinities ** 2


class AffinitySideLoss(nn.Module):
    """Dice loss between affinities derived from embeddings and from a segmentation.

    On every call, `n_samples` random offsets are drawn. Affinities for these offsets are
    computed from the predicted embeddings and from the target segmentation. Both are
    inverted so that boundaries become foreground, and then compared with `dice_score`.

    Arguments:
        offset_ranges [list[tuple]] - one (low, high) pair per spatial dimension. Each offset
            component is drawn with `np.random.randint(low, high)`, i.e. from [low, high)
            (`high` itself is never drawn).
        n_samples [int] - number of offsets sampled per call.
        delta [float] - distance margin passed to `embeddings_to_affinities`.
    """

    def __init__(self, offset_ranges, n_samples, delta):
        assert all(len(orange) == 2 for orange in offset_ranges)
        super().__init__()
        self.ndim = len(offset_ranges)
        self.offset_ranges = offset_ranges
        self.n_samples = n_samples
        self.delta = delta

    # Implemented as `forward` (instead of overriding `__call__`) so that calling the module
    # goes through nn.Module.__call__, which also runs registered forward hooks.
    def forward(
        self,
        input_,
        target,
        ignore_labels=None,
        ignore_in_variance_term=None,
        ignore_in_distance_term=None,
    ):
        """Compute the affinity side loss.

        Arguments:
            input_ [torch.Tensor] - predicted embeddings, (batch, channels, *spatial).
            target [torch.Tensor] - segmentation with a singleton channel axis and the same
                spatial shape as `input_`.
            ignore_labels, ignore_in_variance_term, ignore_in_distance_term - currently unused;
                masking of ignore labels is not implemented yet.
        Returns:
            The result of `dice_score(affs, target_affs, invert=True)`.
        """
        assert input_.dim() == target.dim(), f"{input_.dim()}, {target.dim()}"
        assert input_.shape[2:] == target.shape[2:]
        # one offset range per spatial axis; the first two axes are batch and channel
        assert input_.dim() - 2 == self.ndim, f"{input_.dim()}, {self.ndim}"

        offsets = [
            [np.random.randint(low, high) for low, high in self.offset_ranges]
            for _ in range(self.n_samples)
        ]

        # invert the affinities and the target affinities so that boundaries become
        # foreground, which is beneficial for the dice loss
        affs = 1. - embeddings_to_affinities(input_, offsets, self.delta)
        target_affs = 1. - segmentation_to_affinities(target, offsets)
        assert affs.shape == target_affs.shape, f"{affs.shape}, {target_affs.shape}"

        # TODO implement masking the ignore labels
        loss = dice_score(affs, target_affs, invert=True)
        return loss
def shift_tensor(tensor, offset):
    """Shift a tensor by the given (spatial) offset, replicating border values.

    Arguments:
        tensor [torch.Tensor] - 4D (=2 spatial dims) or 5D (=3 spatial dims) tensor.
            Needs to be of float type.
        offset (tuple) - 2d or 3d spatial offset used for shifting the tensor.
            A positive entry moves the content towards higher indices along its axis,
            a negative entry towards lower indices.
    Returns:
        torch.Tensor - the shifted tensor; it has the same shape as `tensor`.
    """
    n_spatial = len(offset)
    assert n_spatial in (2, 3)

    # torch padding modules expect the padding for the LAST axis first (reverse spatial
    # order), as (before, after) pairs. A positive offset pads before, a negative one after.
    pad_widths = tuple(
        width
        for off in reversed(offset)
        for width in (max(0, off), max(0, -off))
    )

    # crop back to the original size, in normal spatial order
    def _crop(off):
        if off > 0:
            return slice(None, -off)
        if off < 0:
            return slice(-off, None)
        return slice(None)

    # leading (batch / channel) axes are neither padded nor cropped
    leading = [slice(None)] * (tensor.dim() - n_spatial)
    crop = tuple(leading + [_crop(off) for off in offset])

    # ReplicationPad is used because, per the original author, ReflectionPad lacked a 3d
    # variant and torch.nn.functional.pad was awkward to use here
    pad_cls = nn.ReplicationPad2d if n_spatial == 2 else nn.ReplicationPad3d
    shifted = pad_cls(pad_widths)(tensor)[crop]
    assert shifted.shape == tensor.shape
    return shifted
Shift a tensor by the given (spatial) offset.
Arguments:
- tensor [torch.Tensor] - 4D (=2 spatial dims) or 5D (=3 spatial dims) tensor. Needs to be of float type.
- offset (tuple) - 2d or 3d spatial offset used for shifting the tensor
def _gather_at_offset(tensor, offset):
    """Return `out` with `out[..., i, j, ...] = tensor[..., i + o0, j + o1, ...]` for `offset = (o0, o1, ...)`.

    The offset applies to the trailing `len(offset)` axes. Indices that fall outside the
    tensor are clamped to the border, so border values are replicated. Unlike a
    padding-based shift, this works for any dtype, including integer tensors.
    """
    first_axis = tensor.dim() - len(offset)
    for axis, off in enumerate(offset, start=first_axis):
        off = int(off)
        if off == 0:
            continue
        size = tensor.shape[axis]
        index = (torch.arange(size, device=tensor.device) + off).clamp(0, size - 1)
        tensor = tensor.index_select(axis, index)
    return tensor


def segmentation_to_affinities(segmentation, offsets):
    """Transform segmentation to affinities.

    The affinity for an offset is 1 where a pixel and the pixel `offset` away (clamped to
    the image border) carry the same segment id, and 0 otherwise.

    Arguments:
        segmentation [torch.Tensor] - 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor.
            The channel axis (= dimension 1) needs to be a singleton. May be an integer or float tensor.
        offsets [list[tuple]] - list of offsets for which to compute the affinities; each offset
            needs one entry per spatial dimension.
    Returns:
        torch.Tensor - affinities with one channel per offset (axis 1), containing 0. and 1.
            The dtype is float32, or float64 for a float64 segmentation.
    """
    assert segmentation.shape[1] == 1, f"{segmentation.shape}"
    offsets = list(offsets)
    n_spatial = segmentation.dim() - 2
    assert all(len(off) == n_spatial for off in offsets), f"{segmentation.shape}, {offsets}"
    # Compare the segment ids in their own dtype. Casting them to float32 first (as the
    # padding-based shift_tensor requires) would merge distinct ids above 2**24, which
    # float32 cannot represent exactly.
    neighbours = torch.cat([_gather_at_offset(segmentation, off) for off in offsets], dim=1)
    # same output dtype as subtracting a float32 tensor from `segmentation` would give
    out_dtype = torch.promote_types(segmentation.dtype, torch.float32)
    return (segmentation == neighbours).to(out_dtype)
Transform segmentation to affinities.
Arguments:
- segmentation [torch.Tensor] - 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor. The channel axis (= dimension 1) needs to be a singleton.
- offsets [list[tuple]] - list of offsets for which to compute the affinities.
def embeddings_to_affinities(embeddings, offsets, delta):
    """Transform embeddings to affinities.

    For every offset, each embedding vector is compared with the embedding vector `offset`
    away (border values are replicated by `shift_tensor`). Their euclidean distance d is
    mapped to the affinity ``max(0, (2 * delta - d) / (2 * delta)) ** 2``. Identical
    embeddings give 1, and embeddings at least ``2 * delta`` apart give 0.

    Arguments:
        embeddings [torch.Tensor] - float embedding tensor; the distance is taken over axis 1.
            Presumably (batch, embedding_dim, *spatial) - TODO confirm.
        offsets [list[tuple]] - spatial offsets for which to compute the affinities.
        delta [float] - distance margin; affinities reach zero at a distance of 2 * delta.
    Returns:
        torch.Tensor - affinities with one channel per offset, stacked along axis 1.
    """
    # shifting by the inverted offset aligns, at every position, the embedding that lies
    # `offset` away with the embedding at that position
    neighbours = torch.stack(
        [shift_tensor(embeddings, inv_off) for inv_off in invert_offsets(offsets)], dim=1
    )
    # euclidean distance over the embedding axis (axis 1 of `embeddings`, axis 2 after stacking)
    distance = torch.norm(embeddings.unsqueeze(1) - neighbours, dim=2)
    margin = 2 * delta
    affinities = torch.clamp((margin - distance) / margin, min=0)
    return affinities ** 2
Transform embeddings to affinities.
class AffinitySideLoss(nn.Module):
    """Dice loss between affinities derived from embeddings and from a segmentation.

    On every call, `n_samples` random offsets are drawn. Affinities for these offsets are
    computed from the predicted embeddings and from the target segmentation. Both are
    inverted so that boundaries become foreground, and then compared with `dice_score`.

    Arguments:
        offset_ranges [list[tuple]] - one (low, high) pair per spatial dimension. Each offset
            component is drawn with `np.random.randint(low, high)`, i.e. from [low, high)
            (`high` itself is never drawn).
        n_samples [int] - number of offsets sampled per call.
        delta [float] - distance margin passed to `embeddings_to_affinities`.
    """

    def __init__(self, offset_ranges, n_samples, delta):
        assert all(len(orange) == 2 for orange in offset_ranges)
        super().__init__()
        self.ndim = len(offset_ranges)
        self.offset_ranges = offset_ranges
        self.n_samples = n_samples
        self.delta = delta

    # Implemented as `forward` (instead of overriding `__call__`) so that calling the module
    # goes through nn.Module.__call__, which also runs registered forward hooks.
    def forward(
        self,
        input_,
        target,
        ignore_labels=None,
        ignore_in_variance_term=None,
        ignore_in_distance_term=None,
    ):
        """Compute the affinity side loss.

        Arguments:
            input_ [torch.Tensor] - predicted embeddings, (batch, channels, *spatial).
            target [torch.Tensor] - segmentation with a singleton channel axis and the same
                spatial shape as `input_`.
            ignore_labels, ignore_in_variance_term, ignore_in_distance_term - currently unused;
                masking of ignore labels is not implemented yet.
        Returns:
            The result of `dice_score(affs, target_affs, invert=True)`.
        """
        assert input_.dim() == target.dim(), f"{input_.dim()}, {target.dim()}"
        assert input_.shape[2:] == target.shape[2:]
        # one offset range per spatial axis; the first two axes are batch and channel
        assert input_.dim() - 2 == self.ndim, f"{input_.dim()}, {self.ndim}"

        offsets = [
            [np.random.randint(low, high) for low, high in self.offset_ranges]
            for _ in range(self.n_samples)
        ]

        # invert the affinities and the target affinities so that boundaries become
        # foreground, which is beneficial for the dice loss
        affs = 1. - embeddings_to_affinities(input_, offsets, self.delta)
        target_affs = 1. - segmentation_to_affinities(target, offsets)
        assert affs.shape == target_affs.shape, f"{affs.shape}, {target_affs.shape}"

        # TODO implement masking the ignore labels
        loss = dice_score(affs, target_affs, invert=True)
        return loss
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Example module: two stacked 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        # the parent __init__ must run before submodules are assigned
        super().__init__()
        # assigning modules as attributes registers them as submodules
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to the single-channel input `x`."""
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
Attribute `training` (bool): whether this module is in training or evaluation mode.
def __init__(self, offset_ranges, n_samples, delta):
    """Store the loss configuration.

    Arguments:
        offset_ranges [list[tuple]] - one (low, high) range per spatial dimension;
            every entry must have exactly two elements.
        n_samples [int] - number of offsets sampled each time the loss is evaluated.
        delta [float] - distance margin used when turning embeddings into affinities.
    """
    # validate the ranges before initializing the module, so a malformed config fails early
    assert all(len(pair) == 2 for pair in offset_ranges)
    super().__init__()
    # one offset range per spatial axis
    self.ndim = len(offset_ranges)
    self.offset_ranges, self.n_samples, self.delta = offset_ranges, n_samples, delta
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile