torch_em.loss.affinity_side_loss
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from .dice import dice_score


def shift_tensor(tensor: torch.Tensor, offset: List[int]) -> torch.Tensor:
    """Shift a tensor by the given spatial offset.

    The tensor is padded with replication padding on the side the content moves
    away from and then cropped on the opposite side, so the output has the same
    shape as the input.

    Args:
        tensor: A 4D (2 spatial dims) or 5D (3 spatial dims) tensor. Needs to be of float type.
        offset: A 2d or 3d spatial offset used for shifting the tensor.

    Returns:
        The shifted tensor.
    """
    ndim = len(offset)
    assert ndim in (2, 3)

    # The padding has to be given in inverse spatial order, as (before, after) pairs.
    # A positive offset shifts "to the right", which means padding at the left border;
    # a negative offset shifts "to the left", which means padding at the right border.
    padding = tuple(
        width for off in reversed(offset) for width in (max(0, off), max(0, -off))
    )

    def _crop(off):
        # Cut away the side opposite to the padding, so that the shape is restored.
        if off > 0:
            return slice(None, -off)
        if off < 0:
            return slice(-off, None)
        return slice(None)

    # Don't crop the leading dimensions (usually batch and/or channel dimension).
    n_leading = tensor.dim() - ndim
    slice_ = (slice(None),) * n_leading + tuple(_crop(off) for off in offset)

    # Pad the spatial part of the tensor with replication padding,
    # then slice the padded tensor to get the spatially shifted tensor.
    padder = nn.ReplicationPad2d if ndim == 2 else nn.ReplicationPad3d
    shifted = padder(padding)(tensor)[slice_]
    assert shifted.shape == tensor.shape

    return shifted


def invert_offsets(offsets):
    """@private
    """
    return [[-off for off in offset] for offset in offsets]


def segmentation_to_affinities(segmentation: torch.Tensor, offsets: List[List[int]]) -> torch.Tensor:
    """Transform segmentation to affinities.

    Args:
        segmentation: A 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor.
            The channel axis (= dimension 1) needs to be a singleton.
        offsets: List of offsets for which to compute the affinities.

    Returns:
        The affinities, with one channel per offset along dimension 1.
    """
    assert segmentation.shape[1] == 1, f"{segmentation.shape}"
    # Shift the segmentation and subtract the shifted tensor from the segmentation.
    # We need to shift in the opposite direction of the offsets, so we invert them before applying the shift.
    offsets_ = invert_offsets(offsets)
    # The float conversion does not depend on the offset, so it is done only once.
    segmentation_float = segmentation.float()
    shifted = torch.cat([shift_tensor(segmentation_float, off) for off in offsets_], dim=1)
    affs = (segmentation - shifted)
    # The affinities are 1, where we had the same segment id (the difference is 0) and 0 otherwise.
    affs.eq_(0.)
    return affs


def embeddings_to_affinities(embeddings: torch.Tensor, offsets: List[List[int]], delta: float) -> torch.Tensor:
    """Transform embeddings to affinities.

    Args:
        embeddings: The pixel-wise embeddings.
        offsets: The offsets for computing affinities.
        delta: The push force hinge used for training the embedding prediction network.

    Returns:
        The affinities, with one channel per offset along dimension 1.
    """
    # Shift the embeddings by the offsets and stack them along a new axis.
    # We need to shift in the opposite direction of the offsets, so we invert them before applying the shift.
    offsets_ = invert_offsets(offsets)
    shifted = torch.cat([shift_tensor(embeddings, off).unsqueeze(1) for off in offsets_], dim=1)
    # Subtract the shifted embeddings from the embeddings, take the norm and
    # transform to affinities based on the delta distance.
    affs = (2 * delta - torch.norm(embeddings.unsqueeze(1) - shifted, dim=2)) / (2 * delta)
    affs = torch.clamp(affs, min=0) ** 2
    return affs


class AffinitySideLoss(nn.Module):
    """Loss computed between affinities derived from predicted embeddings and a target segmentation.

    The offsets for the affinities will be derived randomly from the given `offset_ranges`.

    Args:
        offset_ranges: Ranges from which the offsets are sampled, one (low, high) pair per spatial dimension.
        n_samples: Number of offsets to sample per loss computation.
        delta: The push force hinge used for training the embedding prediction network.
    """
    def __init__(self, offset_ranges: List[Tuple[int, int]], n_samples: int, delta: float):
        assert all(len(orange) == 2 for orange in offset_ranges)
        super().__init__()
        self.ndim = len(offset_ranges)
        self.offset_ranges = offset_ranges
        self.n_samples = n_samples
        self.delta = delta

    # Implemented as `forward` (rather than overriding `__call__`), so that calling the
    # loss object goes through the regular `nn.Module` call machinery.
    def forward(
        self,
        input_: torch.Tensor,
        target: torch.Tensor,
        ignore_labels: Optional[List[int]] = None,
        ignore_in_variance_term: Optional[List[int]] = None,
        ignore_in_distance_term: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Compute loss between affinities derived from predicted embeddings and a target segmentation.

        Note: Support for the ignore labels is currently not implemented.

        Args:
            input_: The predicted embeddings.
            target: The target segmentation.
            ignore_labels: Ignore labels for the loss computation.
            ignore_in_variance_term: Ignore labels for the variance term.
            ignore_in_distance_term: Ignore labels for the distance term.

        Returns:
            The affinity loss value.
        """
        assert input_.dim() == target.dim(), f"{input_.dim()}, {target.dim()}"
        assert input_.shape[2:] == target.shape[2:]

        # Sample the offsets; they are cast to builtin ints before being used for padding and slicing.
        offsets = [[int(np.random.randint(orange[0], orange[1])) for orange in self.offset_ranges]
                   for _ in range(self.n_samples)]

        # We invert the affinities and the target affinities,
        # so that we get boundaries as foreground, which is beneficial for the dice loss.
        # Compute affinities from embeddings.
        affs = 1. - embeddings_to_affinities(input_, offsets, self.delta)

        # Compute groundtruth affinities from the target segmentation.
        target_affs = 1. - segmentation_to_affinities(target, offsets)
        assert affs.shape == target_affs.shape, f"{affs.shape}, {target_affs.shape}"

        # TODO implement masking the ignore labels
        # Compute the dice score between affinities and target affinities.
        return dice_score(affs, target_affs, invert=True)
def
shift_tensor(tensor: torch.Tensor, offset: List[int]) -> torch.Tensor:
def shift_tensor(tensor: torch.Tensor, offset: List[int]) -> torch.Tensor:
    """Shift a tensor by the given spatial offset.

    The tensor is padded with replication padding on the side the content moves
    away from and cropped on the opposite side, so the result has the same shape
    as the input.

    Args:
        tensor: A 4D (2 spatial dims) or 5D (3 spatial dims) tensor. Needs to be of float type.
        offset: A 2d or 3d spatial offset used for shifting the tensor.

    Returns:
        The shifted tensor.
    """
    n_spatial = len(offset)
    assert n_spatial in (2, 3)

    # The padding is specified as (before, after) pairs in inverse spatial order.
    # A positive offset moves the content "to the right" and thus pads the left border,
    # a negative offset moves it "to the left" and thus pads the right border.
    pad_widths = tuple(
        width for off in reversed(offset) for width in (max(0, off), max(0, -off))
    )

    def _crop(off):
        # Remove as many entries as were padded, on the opposite border.
        if off > 0:
            return slice(None, -off)
        if off < 0:
            return slice(-off, None)
        return slice(None)

    # The leading dimensions (usually batch and/or channel) are neither padded nor cropped.
    n_leading = tensor.dim() - n_spatial
    crop = (slice(None),) * n_leading + tuple(_crop(off) for off in offset)

    # Pad the spatial part of the tensor, then crop the padded tensor
    # in normal spatial order to obtain the shifted tensor.
    pad_cls = nn.ReplicationPad3d if n_spatial == 3 else nn.ReplicationPad2d
    shifted = pad_cls(pad_widths)(tensor)[crop]
    assert shifted.shape == tensor.shape

    return shifted
Shift a tensor by the given spatial offset.
Arguments:
- tensor: A 4D (2 spatial dims) or 5D (3 spatial dims) tensor. Needs to be of float type.
- offset: A 2d or 3d spatial offset used for shifting the tensor.
Returns:
The shifted tensor.
def
segmentation_to_affinities(segmentation: torch.Tensor, offsets: List[List[int]]) -> torch.Tensor:
def segmentation_to_affinities(segmentation: torch.Tensor, offsets: List[List[int]]) -> torch.Tensor:
    """Transform segmentation to affinities.

    For every offset the segmentation is compared with a shifted copy of itself.

    Args:
        segmentation: A 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor.
            The channel axis (= dimension 1) needs to be a singleton.
        offsets: List of offsets for which to compute the affinities.

    Returns:
        The affinities, with one channel per offset along dimension 1.
    """
    assert segmentation.shape[1] == 1, f"{segmentation.shape}"
    # Shift the segmentation and subtract the shifted tensor from the segmentation.
    # We need to shift in the opposite direction of the offsets, so we invert them before applying the shift.
    offsets_ = invert_offsets(offsets)
    # The float conversion does not depend on the offset, so it is done only once
    # instead of once per offset.
    segmentation_float = segmentation.float()
    shifted = torch.cat([shift_tensor(segmentation_float, off) for off in offsets_], dim=1)
    # The singleton channel of the segmentation broadcasts against the offset channels.
    affs = (segmentation - shifted)
    # The affinities are 1, where we had the same segment id (the difference is 0) and 0 otherwise.
    affs.eq_(0.)
    return affs
Transform segmentation to affinities.
Arguments:
- segmentation: A 4D (2 spatial dims) or 5D (3 spatial dims) segmentation tensor. The channel axis (= dimension 1) needs to be a singleton.
- offsets: List of offsets for which to compute the affinities.
Returns:
The affinities.
def
embeddings_to_affinities( embeddings: torch.Tensor, offsets: List[List[int]], delta: float) -> torch.Tensor:
def embeddings_to_affinities(embeddings: torch.Tensor, offsets: List[List[int]], delta: float) -> torch.Tensor:
    """Transform embeddings to affinities.

    The affinity for an offset is computed from the distance `d` between the embedding
    of a pixel and the embedding at the offset position as `clamp((2 * delta - d) / (2 * delta), min=0) ** 2`.

    Args:
        embeddings: The pixel-wise embeddings.
        offsets: The offsets for computing affinities.
        delta: The push force hinge used for training the embedding prediction network.

    Returns:
        The affinities, with one channel per offset along dimension 1.
    """
    # The shift has to go in the opposite direction of the offsets, hence the inversion.
    inverted = invert_offsets(offsets)

    # Give every shifted copy its own entry along a new axis 1.
    per_offset = [shift_tensor(embeddings, off).unsqueeze(1) for off in inverted]
    shifted = torch.cat(per_offset, dim=1)

    # Norm of the difference between the embeddings and their shifted copies,
    # reduced over axis 2 (the axis that followed the newly inserted offset axis).
    difference = embeddings.unsqueeze(1) - shifted
    distance = torch.norm(difference, dim=2)

    # Map the distances to affinities based on the delta distance.
    affs = (2 * delta - distance) / (2 * delta)
    affs = torch.clamp(affs, min=0) ** 2
    return affs
Transform embeddings to affinities.
Arguments:
- embeddings: The pixel-wise embeddings.
- offsets: The offsets for computing affinities.
- delta: The push force hinge used for training the embedding prediction network.
Returns:
The affinities.
class
AffinitySideLoss(torch.nn.modules.module.Module):
class AffinitySideLoss(nn.Module):
    """Loss computed between affinities derived from predicted embeddings and a target segmentation.

    The offsets for the affinities will be derived randomly from the given `offset_ranges`.

    Args:
        offset_ranges: Ranges from which the offsets are sampled, one (low, high) pair per spatial dimension.
        n_samples: Number of offsets to sample per loss computation.
        delta: The push force hinge used for training the embedding prediction network.
    """
    def __init__(self, offset_ranges: List[Tuple[int, int]], n_samples: int, delta: float):
        assert all(len(orange) == 2 for orange in offset_ranges)
        super().__init__()
        self.ndim = len(offset_ranges)
        self.offset_ranges = offset_ranges
        self.n_samples = n_samples
        self.delta = delta

    # Implemented as `forward` (rather than overriding `__call__`), so that calling the
    # loss object goes through the regular `nn.Module` call machinery.
    def forward(
        self,
        input_: torch.Tensor,
        target: torch.Tensor,
        ignore_labels: Optional[List[int]] = None,
        ignore_in_variance_term: Optional[List[int]] = None,
        ignore_in_distance_term: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Compute loss between affinities derived from predicted embeddings and a target segmentation.

        Note: Support for the ignore labels is currently not implemented.

        Args:
            input_: The predicted embeddings.
            target: The target segmentation.
            ignore_labels: Ignore labels for the loss computation.
            ignore_in_variance_term: Ignore labels for the variance term.
            ignore_in_distance_term: Ignore labels for the distance term.

        Returns:
            The affinity loss value.
        """
        assert input_.dim() == target.dim(), f"{input_.dim()}, {target.dim()}"
        assert input_.shape[2:] == target.shape[2:]

        # Sample the offsets; they are cast to builtin ints before being used for padding and slicing.
        offsets = [[int(np.random.randint(orange[0], orange[1])) for orange in self.offset_ranges]
                   for _ in range(self.n_samples)]

        # We invert the affinities and the target affinities,
        # so that we get boundaries as foreground, which is beneficial for the dice loss.
        # Compute affinities from embeddings.
        affs = 1. - embeddings_to_affinities(input_, offsets, self.delta)

        # Compute groundtruth affinities from the target segmentation.
        target_affs = 1. - segmentation_to_affinities(target, offsets)
        assert affs.shape == target_affs.shape, f"{affs.shape}, {target_affs.shape}"

        # TODO implement masking the ignore labels
        # Compute the dice score between affinities and target affinities.
        return dice_score(affs, target_affs, invert=True)
Loss computed between affinities derived from predicted embeddings and a target segmentation.
The offsets for the affinities are sampled randomly from the given `offset_ranges`.
Arguments:
- offset_ranges: Ranges from which the offsets are sampled, given as one (low, high) pair per spatial dimension.
- n_samples: Number of offsets to sample per loss computation.
- delta: The push force hinge used for training the embedding prediction network.
AffinitySideLoss(offset_ranges: List[Tuple[int, int]], n_samples: int, delta: float)
def __init__(self, offset_ranges: List[Tuple[int, int]], n_samples: int, delta: float):
    """Store the sampling configuration of the loss.

    Args:
        offset_ranges: One two-element range per spatial dimension.
        n_samples: Number of offsets to sample per loss computation.
        delta: The push force hinge used for training the embedding prediction network.
    """
    # Every range has to consist of exactly two entries.
    assert not any(len(bounds) != 2 for bounds in offset_ranges)
    super().__init__()
    self.offset_ranges = offset_ranges
    # The number of ranges is stored as the number of spatial dimensions.
    self.ndim = len(offset_ranges)
    self.n_samples = n_samples
    self.delta = delta
Initialize internal Module state, shared by both nn.Module and ScriptModule.