torch_em.transform.label
from typing import Optional

import numpy as np
import skimage.measure
import skimage.segmentation
import vigra

from ..util import ensure_array, ensure_spatial_array

try:
    from affogato.affinities import compute_affinities
except ImportError:
    compute_affinities = None


def connected_components(labels, ndim=None, ensure_zero=False):
    """Apply connected components to a label image.

    Args:
        labels: The input label image.
        ndim: The expected spatial dimensionality. If given, the input is converted with
            `ensure_spatial_array`, otherwise with `ensure_array`.
        ensure_zero: If True and the result contains no 0 id, all ids are shifted down by one
            so that they start at 0.

    Returns:
        The connected components.
    """
    labels = ensure_array(labels) if ndim is None else ensure_spatial_array(labels, ndim)
    labels = skimage.measure.label(labels)
    if ensure_zero and 0 not in labels:
        labels -= 1
    return labels


def labels_to_binary(labels, background_label=0):
    """Return a foreground mask (1 where labels != background_label) with the dtype of `labels`."""
    return (labels != background_label).astype(labels.dtype)


def label_consecutive(labels, with_background=True):
    """Relabel a segmentation so that its ids are consecutive.

    Args:
        labels: The input segmentation. It is not modified.
        with_background: If True, the id 0 is kept and the other ids start at 1.
            If False, all ids (including 0, if present) are mapped to consecutive ids starting at 0.

    Returns:
        The relabeled segmentation.
    """
    if with_background:
        return skimage.segmentation.relabel_sequential(labels)[0]

    # relabel_sequential keeps 0 fixed, so shift all ids up by one if 0 is present.
    # The addition is out-of-place so that the caller's array is not modified.
    if 0 in labels:
        labels = labels + 1
    seg = skimage.segmentation.relabel_sequential(labels)[0]
    assert seg.min() == 1
    seg -= 1
    return seg


# TODO smoothing
class BoundaryTransform:
    """Compute boundary targets from a label image.

    Args:
        mode: The boundary mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a foreground (label != 0) channel.
        ndim: The expected spatial dimensionality of the labels (optional).
    """
    def __init__(self, mode="thick", add_binary_target=False, ndim=None):
        self.mode, self.add_binary_target, self.ndim = mode, add_binary_target, ndim

    def __call__(self, labels):
        """Return the target with a leading channel axis: [boundaries] or [foreground, boundaries]."""
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
        if not self.add_binary_target:
            return boundary_channel
        foreground_channel = labels_to_binary(labels)[None].astype(boundary_channel.dtype)
        return np.concatenate([foreground_channel, boundary_channel], axis=0)


# TODO smoothing
class NoToBackgroundBoundaryTransform:
    """Compute boundary targets where the boundaries towards the background are masked.

    Pixels on the boundary between foreground and background are set to `mask_label`.

    Args:
        bg_label: The background label id.
        mask_label: The value written to masked target pixels. In the binary channel,
            pixels whose label equals `mask_label` are also set to this value.
        mode: The boundary mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a foreground channel.
        ndim: The expected spatial dimensionality of the labels (optional).
    """
    def __init__(self, bg_label=0, mask_label=-1, mode="thick", add_binary_target=False, ndim=None):
        self.bg_label = bg_label
        self.mask_label = mask_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels):
        """Return the target with a leading channel axis: [boundaries] or [foreground, boundaries]."""
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)

        # Boundaries between all labels; int8 so that a negative mask value can be stored.
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)

        # Boundaries of the binary foreground mask, i.e. the boundaries towards the background.
        foreground = labels != self.bg_label
        background_boundaries = skimage.segmentation.find_boundaries(foreground, mode=self.mode)[None]
        boundary_channel[background_boundaries] = self.mask_label

        if not self.add_binary_target:
            return boundary_channel

        binary = labels_to_binary(labels, self.bg_label).astype(boundary_channel.dtype)
        binary[labels == self.mask_label] = self.mask_label
        return np.concatenate([binary[None], boundary_channel], axis=0)


# TODO smoothing
class BoundaryTransformWithIgnoreLabel:
    """Compute boundary targets where the boundaries of ignored regions are masked.

    Args:
        ignore_label: The label id of ignored regions. Pixels on the boundary of these regions
            (and, in the binary channel, the regions themselves) are set to this value.
        mode: The boundary mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a foreground (label != 0) channel.
        ndim: The expected spatial dimensionality of the labels (optional).
    """
    def __init__(self, ignore_label=-1, mode="thick", add_binary_target=False, ndim=None):
        self.ignore_label = ignore_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels):
        """Return the target with a leading channel axis: [boundaries] or [foreground, boundaries]."""
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)

        # Boundaries between all labels; int8 so that a negative ignore value can be stored.
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)

        # Mask the boundaries of the ignored regions.
        ignore_region = labels == self.ignore_label
        ignore_boundaries = skimage.segmentation.find_boundaries(ignore_region, mode=self.mode)[None]
        boundary_channel[ignore_boundaries] = self.ignore_label

        if not self.add_binary_target:
            return boundary_channel

        binary = labels_to_binary(labels).astype(boundary_channel.dtype)
        binary[ignore_region] = self.ignore_label
        return np.concatenate([binary[None], boundary_channel], axis=0)


# TODO affinity smoothing
class AffinityTransform:
    """Compute affinity targets from a label image with affogato's `compute_affinities`.

    The affinities are returned in the "disaffinity" convention: 1 means repulsive
    (transition between labels) and 0 means attractive.

    Args:
        offsets: The affinity offsets. The length of the first offset (2 or 3) determines
            the spatial dimensionality.
        ignore_label: Optional label id that is ignored in the affinity computation.
        add_binary_target: Whether to prepend a foreground (label != 0) channel.
        add_mask: Whether to append the mask channels returned by `compute_affinities`.
        include_ignore_transitions: Whether to mark transitions into the ignore label as
            repulsive and valid (only used if `ignore_label` is given).
    """
    def __init__(self, offsets,
                 ignore_label=None,
                 add_binary_target=False,
                 add_mask=False,
                 include_ignore_transitions=False):
        # compute_affinities is None if affogato could not be imported.
        assert compute_affinities is not None, "AffinityTransform requires affogato, which could not be imported."
        self.offsets = offsets
        self.ndim = len(self.offsets[0])
        assert self.ndim in (2, 3), f"Expected 2d or 3d offsets, got offsets of length {self.ndim}."

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """Mark transitions between the ignore label and other labels as repulsive and valid.

        Args:
            affs: The affinities (disaffinity convention), modified in place.
            mask: The affinity mask, modified in place.
            labels: The label image.

        Returns:
            The updated affinities and mask.
        """
        ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
        raw_transitions, valid = compute_affinities(ignore_seg, self.offsets)
        # NOTE affinity convention returned by affogato: transitions are marked by 0.
        transitions = raw_transitions == 0
        # The second return value is the validity mask (it is used as `mask` in __call__);
        # do not touch positions it marks as invalid.
        transitions[np.logical_not(valid)] = 0
        affs[transitions] = 1
        mask[transitions] = 1
        return affs, mask

    def __call__(self, labels):
        """Compute the affinity target for `labels`.

        Returns:
            An array with the channels [binary (optional), affinities, masks (optional)],
            where the mask channels correspond one-to-one to the preceding channels.
        """
        # Signed integer labels are converted to int64, everything else to uint64.
        signed_dtypes = (np.dtype("int16"), np.dtype("int32"), np.dtype("int64"))
        dtype = "int64" if np.dtype(labels.dtype) in signed_dtypes else "uint64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)

        has_ignore_label = self.ignore_label is not None
        affs, mask = compute_affinities(
            labels, self.offsets,
            have_ignore_label=has_ignore_label,
            ignore_label=self.ignore_label if has_ignore_label else 0,
        )
        # We use the "disaffinity" convention for training, i.e. 1 means repulsive, 0 attractive.
        affs = 1. - affs

        # Add the transitions to the ignore label as repulsive and valid (mask = 1).
        if has_ignore_label and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if self.add_mask:
            if self.add_binary_target:
                # The binary channel is valid everywhere except inside the ignore label.
                if has_ignore_label:
                    mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
                else:
                    mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
                assert mask.ndim == mask_for_bin.ndim
                mask = np.concatenate([mask_for_bin, mask], axis=0)
            assert affs.shape == mask.shape
            affs = np.concatenate([affs, mask.astype(affs.dtype)], axis=0)

        return affs


class OneHotTransform:
    """One-hot encode a label image.

    Args:
        class_ids: The class ids to encode, in channel order. An integer `n` is interpreted
            as the ids `0, ..., n - 1`. If None, the sorted unique values of each input are used.
    """
    def __init__(self, class_ids=None):
        if isinstance(class_ids, int):
            class_ids = list(range(class_ids))
        self.class_ids = class_ids

    def __call__(self, labels):
        """Return a float32 array of shape (n_classes,) + labels.shape."""
        if self.class_ids is None:
            class_ids = np.unique(labels).tolist()
        else:
            class_ids = self.class_ids
        one_hot = np.zeros((len(class_ids),) + labels.shape, dtype="float32")
        for channel, class_id in enumerate(class_ids):
            one_hot[channel][labels == class_id] = 1.0
        return one_hot


class DistanceTransform:
    """Compute distances to foreground in the labels.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which the distances are clipped.
        foreground_id: Label id to which the distances are computed.
        invert: Whether to invert the distances.
        func: Optional function applied to the distances after normalization and inversion.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id=1,
        invert=False,
        func=None
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        """Compute the Euclidean norm over the vector axis (axis 0) and post-process it."""
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        """Post-process the directed distances (vector axis first, then the spatial axes).

        Normalization and inversion are applied per vector component over all spatial axes.
        """
        # Reduce over all spatial axes, so that 2d and 3d inputs are handled alike.
        spatial_axes = tuple(range(1, directed_distances.ndim))
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        if self.normalize:
            directed_distances /= (np.abs(directed_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=spatial_axes, keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        """Return constant directed distances for inputs without any foreground.

        Every vector component gets the same value, chosen such that the vector norm equals
        the length of the image diagonal (0 if the distances are inverted).
        """
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.linalg.norm(shape) / np.sqrt(labels.ndim)
        # float32 to match the dtype of the vector distance transform used for non-empty inputs.
        return np.full((labels.ndim,) + shape, fill_value, dtype="float32")

    def __call__(self, labels):
        """Compute the distance target for `labels`.

        Returns:
            The undirected distances (spatial shape), the directed distances (vector axis first),
            or both concatenated along a leading channel axis (undirected distances first).
        """
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # The distances are not computed correctly if the mask is all zero,
        # so this case needs to be handled separately.
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            ndim = distance_mask.ndim
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances


class PerObjectDistanceTransform:
    """Compute normalized distances per object in a segmentation.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to return the (relabeled) instance segmentation as the first channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.
    """
    eps = 1e-7

    def __init__(
        self,
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=False,
        apply_label=True,
        correct_centers=True,
        min_size=0,
        distance_fill_value=1.0,
    ):
        if sum([distances, directed_distances, boundary_distances]) == 0:
            raise ValueError(
                "At least one of distances, boundary distances or directed distances has to be passed."
            )
        self.distances = distances
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        self.correct_centers = correct_centers
        self.min_size = min_size
        self.distance_fill_value = distance_fill_value

    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
        """Compute the normalized distances for one object and write them into `distances`.

        Args:
            mask: Binary mask of the object (full image).
            boundaries: The inner boundaries of all objects (full image).
            bb: The bounding box of the object as a tuple of slices.
            center: The object center (rounded centroid) in global coordinates.
            distances: The global result (spatial shape + channel axis), updated in place.

        Returns:
            The updated `distances`.
        """
        # Crop the mask and express the center in the coordinates of the crop.
        cropped_mask = mask[bb]
        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

        # The centroid might not be inside of the object (e.g. for non-convex objects).
        # If 'correct_centers' is set, we then use the point with the maximal distance to the boundary.
        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
        correct_center = self.correct_centers and not cropped_mask[cropped_center]

        # Compute the boundary distances if necessary.
        # (Either if we need to correct the center, or compute the boundary distances anyways.)
        if correct_center or self.boundary_distances:
            # Crop the boundary mask and compute the distances to the boundary pixels.
            cropped_boundary_mask = boundaries[bb]
            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
            boundary_distances[~cropped_mask] = 0
            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

        # Set the crop center to the max dist point.
        if correct_center:
            cropped_center = max_dist_point

        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
        cropped_center_mask[cropped_center] = 1

        # Compute the directed distances to the center (vector components on the last axis).
        if self.distances or self.directed_distances:
            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
        else:
            this_distances = None

        # Keep only the requested distances.
        if self.distances and self.directed_distances:
            # Prepend the undirected distances, derived from the directed distances.
            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
            this_distances = np.concatenate([undir, this_distances], axis=-1)
        elif self.distances:
            # Keep only the undirected distances.
            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)
        # If only the directed distances are requested, they are already computed.

        # Append a channel for the boundary distances if requested.
        # They are inverted so that they are 0 at the point with the maximal boundary distance.
        if self.boundary_distances:
            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
            if this_distances is None:
                this_distances = boundary_distances
            else:
                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

        # Set distances outside of the object to zero.
        this_distances[~cropped_mask] = 0

        # Normalize each channel by its maximal absolute value.
        spatial_axes = tuple(range(mask.ndim))
        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

        # Write the values for this object into the global result (distances[bb] is a view).
        distances[bb][cropped_mask] = this_distances[cropped_mask]

        return distances

    def __call__(self, labels):
        """Compute the per-object distance target.

        Returns:
            A channel-first array with the channels (each included as configured): instance labels,
            foreground, undirected distances, directed distances (ndim channels), boundary distances.
        """
        # Apply label (connected components) if specified, otherwise just relabel consecutively.
        if self.apply_label:
            labels = skimage.measure.label(labels).astype("uint32")
        else:
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Filter out small objects if min_size is specified.
        if self.min_size > 0:
            ids, sizes = np.unique(labels, return_counts=True)
            discard_ids = ids[sizes < self.min_size]
            labels[np.isin(labels, discard_ids)] = 0
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Compute the boundaries. They are used to determine the most central point,
        # and if 'self.boundary_distances' is True to add the boundary distances.
        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

        # Compute region properties to derive bounding boxes and centers.
        ndim = labels.ndim
        props = skimage.measure.regionprops(labels)
        bounding_boxes = {
            prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim))
            for prop in props
        }

        # Compute the object centers from the centroids.
        centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

        # Compute how many distance channels we have.
        n_channels = 0
        if self.distances:  # One channel for the undirected distances.
            n_channels += 1
        if self.boundary_distances:  # One channel for the boundary distances.
            n_channels += 1
        if self.directed_distances:  # And ndim channels for the directed distances.
            n_channels += ndim

        # Compute the per object distances.
        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
        for prop in props:
            label_id = prop.label
            mask = labels == label_id
            distances = self.compute_normalized_object_distances(
                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
            )

        # Bring the distance channel to the first dimension.
        to_channel_first = (ndim,) + tuple(range(ndim))
        distances = distances.transpose(to_channel_first)

        # Add the foreground mask as first channel if specified.
        if self.foreground:
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances
def
connected_components(labels, ndim=None, ensure_zero=False):
def
labels_to_binary(labels, background_label=0):
def
label_consecutive(labels, with_background=True):
class
BoundaryTransform:
class BoundaryTransform:
    """Compute boundary targets from a label image.

    Args:
        mode: The boundary mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a foreground (label != 0) channel.
        ndim: The expected spatial dimensionality of the labels (optional).
    """
    def __init__(self, mode="thick", add_binary_target=False, ndim=None):
        self.mode, self.add_binary_target, self.ndim = mode, add_binary_target, ndim

    def __call__(self, labels):
        """Return the target with a leading channel axis: [boundaries] or [foreground, boundaries]."""
        if self.ndim is None:
            labels = ensure_array(labels)
        else:
            labels = ensure_spatial_array(labels, self.ndim)
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
        if not self.add_binary_target:
            return boundary_channel
        foreground_channel = labels_to_binary(labels)[None].astype(boundary_channel.dtype)
        return np.concatenate([foreground_channel, boundary_channel], axis=0)
class
NoToBackgroundBoundaryTransform:
class NoToBackgroundBoundaryTransform:
    """Compute boundary targets where the boundaries towards the background are masked.

    Pixels on the boundary between foreground and background are set to `mask_label`.

    Args:
        bg_label: The background label id.
        mask_label: The value written to masked target pixels. In the binary channel,
            pixels whose label equals `mask_label` are also set to this value.
        mode: The boundary mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a foreground channel.
        ndim: The expected spatial dimensionality of the labels (optional).
    """
    def __init__(self, bg_label=0, mask_label=-1, mode="thick", add_binary_target=False, ndim=None):
        self.bg_label = bg_label
        self.mask_label = mask_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels):
        """Return the target with a leading channel axis: [boundaries] or [foreground, boundaries]."""
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)

        # Boundaries between all labels; int8 so that a negative mask value can be stored.
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)

        # Boundaries of the binary foreground mask, i.e. the boundaries towards the background.
        foreground = labels != self.bg_label
        background_boundaries = skimage.segmentation.find_boundaries(foreground, mode=self.mode)[None]
        boundary_channel[background_boundaries] = self.mask_label

        if not self.add_binary_target:
            return boundary_channel

        binary = labels_to_binary(labels, self.bg_label).astype(boundary_channel.dtype)
        binary[labels == self.mask_label] = self.mask_label
        return np.concatenate([binary[None], boundary_channel], axis=0)
class
BoundaryTransformWithIgnoreLabel:
class BoundaryTransformWithIgnoreLabel:
    """Compute boundary targets where the boundaries of ignored regions are masked.

    Args:
        ignore_label: The label id of ignored regions. Pixels on the boundary of these regions
            (and, in the binary channel, the regions themselves) are set to this value.
        mode: The boundary mode passed to `skimage.segmentation.find_boundaries`.
        add_binary_target: Whether to prepend a foreground (label != 0) channel.
        ndim: The expected spatial dimensionality of the labels (optional).
    """
    def __init__(self, ignore_label=-1, mode="thick", add_binary_target=False, ndim=None):
        self.ignore_label = ignore_label
        self.mode = mode
        self.ndim = ndim
        self.add_binary_target = add_binary_target

    def __call__(self, labels):
        """Return the target with a leading channel axis: [boundaries] or [foreground, boundaries]."""
        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)

        # Boundaries between all labels; int8 so that a negative ignore value can be stored.
        boundary_channel = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None].astype(np.int8)

        # Mask the boundaries of the ignored regions.
        ignore_region = labels == self.ignore_label
        ignore_boundaries = skimage.segmentation.find_boundaries(ignore_region, mode=self.mode)[None]
        boundary_channel[ignore_boundaries] = self.ignore_label

        if not self.add_binary_target:
            return boundary_channel

        binary = labels_to_binary(labels).astype(boundary_channel.dtype)
        binary[ignore_region] = self.ignore_label
        return np.concatenate([binary[None], boundary_channel], axis=0)
class
AffinityTransform:
class AffinityTransform:
    """Compute affinity targets from a label image with affogato's `compute_affinities`.

    The affinities are returned in the "disaffinity" convention: 1 means repulsive
    (transition between labels) and 0 means attractive.

    Args:
        offsets: The affinity offsets. The length of the first offset (2 or 3) determines
            the spatial dimensionality.
        ignore_label: Optional label id that is ignored in the affinity computation.
        add_binary_target: Whether to prepend a foreground (label != 0) channel.
        add_mask: Whether to append the mask channels returned by `compute_affinities`.
        include_ignore_transitions: Whether to mark transitions into the ignore label as
            repulsive and valid (only used if `ignore_label` is given).
    """
    def __init__(self, offsets,
                 ignore_label=None,
                 add_binary_target=False,
                 add_mask=False,
                 include_ignore_transitions=False):
        # compute_affinities is None if affogato could not be imported.
        assert compute_affinities is not None, "AffinityTransform requires affogato, which could not be imported."
        self.offsets = offsets
        self.ndim = len(self.offsets[0])
        assert self.ndim in (2, 3), f"Expected 2d or 3d offsets, got offsets of length {self.ndim}."

        self.ignore_label = ignore_label
        self.add_binary_target = add_binary_target
        self.add_mask = add_mask
        self.include_ignore_transitions = include_ignore_transitions

    def add_ignore_transitions(self, affs, mask, labels):
        """Mark transitions between the ignore label and other labels as repulsive and valid.

        Args:
            affs: The affinities (disaffinity convention), modified in place.
            mask: The affinity mask, modified in place.
            labels: The label image.

        Returns:
            The updated affinities and mask.
        """
        ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
        raw_transitions, valid = compute_affinities(ignore_seg, self.offsets)
        # NOTE affinity convention returned by affogato: transitions are marked by 0.
        transitions = raw_transitions == 0
        # The second return value is the validity mask (it is used as `mask` in __call__);
        # do not touch positions it marks as invalid.
        transitions[np.logical_not(valid)] = 0
        affs[transitions] = 1
        mask[transitions] = 1
        return affs, mask

    def __call__(self, labels):
        """Compute the affinity target for `labels`.

        Returns:
            An array with the channels [binary (optional), affinities, masks (optional)],
            where the mask channels correspond one-to-one to the preceding channels.
        """
        # Signed integer labels are converted to int64, everything else to uint64.
        signed_dtypes = (np.dtype("int16"), np.dtype("int32"), np.dtype("int64"))
        dtype = "int64" if np.dtype(labels.dtype) in signed_dtypes else "uint64"
        labels = ensure_spatial_array(labels, self.ndim, dtype=dtype)

        has_ignore_label = self.ignore_label is not None
        affs, mask = compute_affinities(
            labels, self.offsets,
            have_ignore_label=has_ignore_label,
            ignore_label=self.ignore_label if has_ignore_label else 0,
        )
        # We use the "disaffinity" convention for training, i.e. 1 means repulsive, 0 attractive.
        affs = 1. - affs

        # Add the transitions to the ignore label as repulsive and valid (mask = 1).
        if has_ignore_label and self.include_ignore_transitions:
            affs, mask = self.add_ignore_transitions(affs, mask, labels)

        if self.add_binary_target:
            binary = labels_to_binary(labels)[None].astype(affs.dtype)
            assert binary.ndim == affs.ndim
            affs = np.concatenate([binary, affs], axis=0)

        if self.add_mask:
            if self.add_binary_target:
                # The binary channel is valid everywhere except inside the ignore label.
                if has_ignore_label:
                    mask_for_bin = (labels != self.ignore_label)[None].astype(mask.dtype)
                else:
                    mask_for_bin = np.ones((1,) + labels.shape, dtype=mask.dtype)
                assert mask.ndim == mask_for_bin.ndim
                mask = np.concatenate([mask_for_bin, mask], axis=0)
            assert affs.shape == mask.shape
            affs = np.concatenate([affs, mask.astype(affs.dtype)], axis=0)

        return affs
AffinityTransform( offsets, ignore_label=None, add_binary_target=False, add_mask=False, include_ignore_transitions=False)
def __init__(self, offsets,
             ignore_label=None,
             add_binary_target=False,
             add_mask=False,
             include_ignore_transitions=False):
    """Validate the offsets and store the affinity options.

    Args:
        offsets: The affinity offsets. The length of the first offset (2 or 3) determines
            the spatial dimensionality.
        ignore_label: Optional label id that is ignored in the affinity computation.
        add_binary_target: Whether to prepend a foreground channel.
        add_mask: Whether to append the affinity mask channels.
        include_ignore_transitions: Whether to mark transitions into the ignore label.
    """
    # compute_affinities is None if affogato could not be imported.
    assert compute_affinities is not None, "AffinityTransform requires affogato, which could not be imported."
    self.offsets = offsets
    self.ndim = len(self.offsets[0])
    assert self.ndim in (2, 3), f"Expected 2d or 3d offsets, got offsets of length {self.ndim}."

    self.ignore_label = ignore_label
    self.add_binary_target = add_binary_target
    self.add_mask = add_mask
    self.include_ignore_transitions = include_ignore_transitions
def
add_ignore_transitions(self, affs, mask, labels):
def add_ignore_transitions(self, affs, mask, labels):
    """Mark transitions between the ignore label and other labels as repulsive and valid.

    Args:
        affs: The affinities (disaffinity convention), modified in place.
        mask: The affinity mask, modified in place.
        labels: The label image.

    Returns:
        The updated affinities and mask.
    """
    ignore_seg = (labels == self.ignore_label).astype(labels.dtype)
    raw_transitions, valid = compute_affinities(ignore_seg, self.offsets)
    # NOTE affinity convention returned by affogato: transitions are marked by 0.
    transitions = raw_transitions == 0
    # Do not touch positions that the second return value (a validity mask) marks as invalid.
    transitions[np.logical_not(valid)] = 0
    affs[transitions] = 1
    mask[transitions] = 1
    return affs, mask
class
OneHotTransform:
class OneHotTransform:
    """One-hot encode a label image.

    Args:
        class_ids: The class ids to encode, in channel order. An integer `n` is interpreted
            as the ids `0, ..., n - 1`. If None, the sorted unique values of each input are used.
    """
    def __init__(self, class_ids=None):
        if isinstance(class_ids, int):
            class_ids = list(range(class_ids))
        self.class_ids = class_ids

    def __call__(self, labels):
        """Return a float32 array of shape (n_classes,) + labels.shape."""
        if self.class_ids is None:
            class_ids = np.unique(labels).tolist()
        else:
            class_ids = self.class_ids
        one_hot = np.zeros((len(class_ids),) + labels.shape, dtype="float32")
        for channel, class_id in enumerate(class_ids):
            one_hot[channel][labels == class_id] = 1.0
        return one_hot
class
DistanceTransform:
class DistanceTransform:
    """Compute distances to foreground in the labels.

    Args:
        distances: Whether to compute the absolute distances.
        directed_distances: Whether to compute the directed distances (vector distances).
        normalize: Whether to normalize the computed distances.
        max_distance: Maximal distance at which the distances are clipped.
        foreground_id: Label id to which the distances are computed.
        invert: Whether to invert the distances.
        func: Optional function applied to the distances after normalization and inversion.
    """
    eps = 1e-7

    def __init__(
        self,
        distances: bool = True,
        directed_distances: bool = False,
        normalize: bool = True,
        max_distance: Optional[float] = None,
        foreground_id=1,
        invert=False,
        func=None
    ):
        if sum((distances, directed_distances)) == 0:
            raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
        self.directed_distances = directed_distances
        self.distances = distances
        self.normalize = normalize
        self.max_distance = max_distance
        self.foreground_id = foreground_id
        self.invert = invert
        self.func = func

    def _compute_distances(self, directed_distances):
        """Compute the Euclidean norm over the vector axis (axis 0) and post-process it."""
        distances = np.linalg.norm(directed_distances, axis=0)
        if self.max_distance is not None:
            distances = np.clip(distances, 0, self.max_distance)
        if self.normalize:
            distances /= (distances.max() + self.eps)
        if self.invert:
            distances = distances.max() - distances
        if self.func is not None:
            distances = self.func(distances)
        return distances

    def _compute_directed_distances(self, directed_distances):
        """Post-process the directed distances (vector axis first, then the spatial axes).

        Normalization and inversion are applied per vector component over all spatial axes.
        """
        # Reduce over all spatial axes, so that 2d and 3d inputs are handled alike.
        spatial_axes = tuple(range(1, directed_distances.ndim))
        if self.max_distance is not None:
            directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
        if self.normalize:
            directed_distances /= (np.abs(directed_distances).max(axis=spatial_axes, keepdims=True) + self.eps)
        if self.invert:
            directed_distances = directed_distances.max(axis=spatial_axes, keepdims=True) - directed_distances
        if self.func is not None:
            directed_distances = self.func(directed_distances)
        return directed_distances

    def _get_distances_for_empty_labels(self, labels):
        """Return constant directed distances for inputs without any foreground.

        Every vector component gets the same value, chosen such that the vector norm equals
        the length of the image diagonal (0 if the distances are inverted).
        """
        shape = labels.shape
        fill_value = 0.0 if self.invert else np.linalg.norm(shape) / np.sqrt(labels.ndim)
        # float32 to match the dtype of the vector distance transform used for non-empty inputs.
        return np.full((labels.ndim,) + shape, fill_value, dtype="float32")

    def __call__(self, labels):
        """Compute the distance target for `labels`.

        Returns:
            The undirected distances (spatial shape), the directed distances (vector axis first),
            or both concatenated along a leading channel axis (undirected distances first).
        """
        distance_mask = (labels == self.foreground_id).astype("uint32")
        # The distances are not computed correctly if the mask is all zero,
        # so this case needs to be handled separately.
        if distance_mask.sum() == 0:
            directed_distances = self._get_distances_for_empty_labels(labels)
        else:
            ndim = distance_mask.ndim
            to_channel_first = (ndim,) + tuple(range(ndim))
            directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

        if self.distances:
            distances = self._compute_distances(directed_distances)

        if self.directed_distances:
            directed_distances = self._compute_directed_distances(directed_distances)

        if self.distances and self.directed_distances:
            return np.concatenate((distances[None], directed_distances), axis=0)
        if self.distances:
            return distances
        if self.directed_distances:
            return directed_distances
Compute distances to foreground in the labels.
Arguments:
- distances: Whether to compute the absolute distances.
- directed_distances: Whether to compute the directed distances (vector distances).
- normalize: Whether to normalize the computed distances.
- max_distance: Maximal distance at which to threshold the distances.
- foreground_id: Label id to which the distances are computed.
- invert: Whether to invert the distances.
- func: Optional function applied to the distances after normalization and inversion.
DistanceTransform( distances: bool = True, directed_distances: bool = False, normalize: bool = True, max_distance: Optional[float] = None, foreground_id=1, invert=False, func=None)
def __init__(
    self,
    distances: bool = True,
    directed_distances: bool = False,
    normalize: bool = True,
    max_distance: Optional[float] = None,
    foreground_id=1,
    invert=False,
    func=None
):
    """Validate and store the distance transform options.

    Raises:
        ValueError: If neither `distances` nor `directed_distances` is requested.
    """
    requested = [distances, directed_distances]
    if sum(requested) == 0:
        raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
    # Which outputs to produce.
    self.distances = distances
    self.directed_distances = directed_distances
    # Post-processing options.
    self.normalize = normalize
    self.max_distance = max_distance
    self.invert = invert
    self.func = func
    # The label id that the distances refer to.
    self.foreground_id = foreground_id
class
PerObjectDistanceTransform:
class PerObjectDistanceTransform:
    """Compute normalized distances per object in a segmentation.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to return the (relabeled) instance segmentation as the first channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.
    """
    eps = 1e-7

    def __init__(
        self,
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=False,
        apply_label=True,
        correct_centers=True,
        min_size=0,
        distance_fill_value=1.0,
    ):
        if sum([distances, directed_distances, boundary_distances]) == 0:
            raise ValueError(
                "At least one of distances, boundary distances or directed distances has to be passed."
            )
        self.distances = distances
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        self.correct_centers = correct_centers
        self.min_size = min_size
        self.distance_fill_value = distance_fill_value

    def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
        """Compute the normalized distances for one object and write them into `distances`.

        Args:
            mask: Binary mask of the object (full image).
            boundaries: The inner boundaries of all objects (full image).
            bb: The bounding box of the object as a tuple of slices.
            center: The object center (rounded centroid) in global coordinates.
            distances: The global result (spatial shape + channel axis), updated in place.

        Returns:
            The updated `distances`.
        """
        # Crop the mask and express the center in the coordinates of the crop.
        cropped_mask = mask[bb]
        cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

        # The centroid might not be inside of the object (e.g. for non-convex objects).
        # If 'correct_centers' is set, we then use the point with the maximal distance to the boundary.
        # Note: the centroid is still the best estimate for the center, as long as it's in the object.
        correct_center = self.correct_centers and not cropped_mask[cropped_center]

        # Compute the boundary distances if necessary.
        # (Either if we need to correct the center, or compute the boundary distances anyways.)
        if correct_center or self.boundary_distances:
            # Crop the boundary mask and compute the distances to the boundary pixels.
            cropped_boundary_mask = boundaries[bb]
            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
            boundary_distances[~cropped_mask] = 0
            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

        # Set the crop center to the max dist point.
        if correct_center:
            cropped_center = max_dist_point

        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
        cropped_center_mask[cropped_center] = 1

        # Compute the directed distances to the center (vector components on the last axis).
        if self.distances or self.directed_distances:
            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
        else:
            this_distances = None

        # Keep only the requested distances.
        if self.distances and self.directed_distances:
            # Prepend the undirected distances, derived from the directed distances.
            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
            this_distances = np.concatenate([undir, this_distances], axis=-1)
        elif self.distances:
            # Keep only the undirected distances.
            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)
        # If only the directed distances are requested, they are already computed.

        # Append a channel for the boundary distances if requested.
        # They are inverted so that they are 0 at the point with the maximal boundary distance.
        if self.boundary_distances:
            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
            if this_distances is None:
                this_distances = boundary_distances
            else:
                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

        # Set distances outside of the object to zero.
        this_distances[~cropped_mask] = 0

        # Normalize each channel by its maximal absolute value.
        spatial_axes = tuple(range(mask.ndim))
        this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

        # Write the values for this object into the global result (distances[bb] is a view).
        distances[bb][cropped_mask] = this_distances[cropped_mask]

        return distances

    def __call__(self, labels):
        """Compute the per-object distance target.

        Returns:
            A channel-first array with the channels (each included as configured): instance labels,
            foreground, undirected distances, directed distances (ndim channels), boundary distances.
        """
        # Apply label (connected components) if specified, otherwise just relabel consecutively.
        if self.apply_label:
            labels = skimage.measure.label(labels).astype("uint32")
        else:
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Filter out small objects if min_size is specified.
        if self.min_size > 0:
            ids, sizes = np.unique(labels, return_counts=True)
            discard_ids = ids[sizes < self.min_size]
            labels[np.isin(labels, discard_ids)] = 0
            labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

        # Compute the boundaries. They are used to determine the most central point,
        # and if 'self.boundary_distances' is True to add the boundary distances.
        boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

        # Compute region properties to derive bounding boxes and centers.
        ndim = labels.ndim
        props = skimage.measure.regionprops(labels)
        bounding_boxes = {
            prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim))
            for prop in props
        }

        # Compute the object centers from the centroids.
        centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

        # Compute how many distance channels we have.
        n_channels = 0
        if self.distances:  # One channel for the undirected distances.
            n_channels += 1
        if self.boundary_distances:  # One channel for the boundary distances.
            n_channels += 1
        if self.directed_distances:  # And ndim channels for the directed distances.
            n_channels += ndim

        # Compute the per object distances.
        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
        for prop in props:
            label_id = prop.label
            mask = labels == label_id
            distances = self.compute_normalized_object_distances(
                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
            )

        # Bring the distance channel to the first dimension.
        to_channel_first = (ndim,) + tuple(range(ndim))
        distances = distances.transpose(to_channel_first)

        # Add the foreground mask as first channel if specified.
        if self.foreground:
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances
Compute normalized distances per object in a segmentation.
Arguments:
- distances: Whether to compute the undirected distances.
- boundary_distances: Whether to compute the distances to the object boundaries.
- directed_distances: Whether to compute the directed distances (vector distances).
- foreground: Whether to return a foreground channel.
- instances: Whether to prepend the (relabeled) instance labels as an extra first channel.
- apply_label: Whether to apply connected components to the labels before computing distances.
- correct_centers: Whether to correct centers that are not in the objects.
- min_size: Minimal size of objects for distance calculation.
- distance_fill_value: Fill value for the distances outside of objects.
PerObjectDistanceTransform( distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=False, apply_label=True, correct_centers=True, min_size=0, distance_fill_value=1.0)
def __init__(
    self,
    distances=True,
    boundary_distances=True,
    directed_distances=False,
    foreground=True,
    instances=False,
    apply_label=True,
    correct_centers=True,
    min_size=0,
    distance_fill_value=1.0,
):
    """Store the configuration of the per-object distance transform.

    Args:
        distances: Whether to compute the undirected distances.
        boundary_distances: Whether to compute the distances to the object boundaries.
        directed_distances: Whether to compute the directed distances (vector distances).
        foreground: Whether to return a foreground channel.
        instances: Whether to prepend the (relabeled) instance labels as an extra channel.
        apply_label: Whether to apply connected components to the labels before computing distances.
        correct_centers: Whether to correct centers that are not in the objects.
        min_size: Minimal size of objects for distance calculation.
        distance_fill_value: Fill value for the distances outside of objects.

    Raises:
        ValueError: If none of `distances`, `boundary_distances` and `directed_distances` is enabled.
    """
    # At least one distance channel must be produced, otherwise there is nothing to compute.
    if not any((distances, directed_distances, boundary_distances)):
        raise ValueError(
            "At least one of distances, boundary_distances or directed_distances has to be passed."
        )
    self.distances = distances
    self.boundary_distances = boundary_distances
    self.directed_distances = directed_distances
    self.foreground = foreground
    self.instances = instances

    self.apply_label = apply_label
    self.correct_centers = correct_centers
    self.min_size = min_size
    self.distance_fill_value = distance_fill_value
def
compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
    """Compute the normalized distance channels of one object and write them into `distances`.

    Args:
        mask: Boolean mask of the object in the full label image.
        boundaries: Boundary mask of the full label image (nonzero on boundary pixels).
        bb: Bounding box of the object as a tuple of slices.
        center: Rounded centroid of the object in global coordinates.
        distances: Output array of shape `labels.shape + (n_channels,)`, updated in place.

    Returns:
        The updated `distances` array.
    """
    # Crop the mask and express the center in crop coordinates.
    cropped_mask = mask[bb]
    cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

    # The centroid might not be inside of the object. If `correct_centers` is set, we then
    # use the point with the maximal distance to the boundary as the center instead.
    # Note: the centroid is still the best estimate for the center, as long as it's in the object.
    correct_center = self.correct_centers and not cropped_mask[cropped_center]

    # Compute the boundary distances if necessary
    # (either to correct the center, or because the boundary distance channel is requested).
    if correct_center or self.boundary_distances:
        cropped_boundary_mask = boundaries[bb]
        boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
        boundary_distances[~cropped_mask] = 0
        max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

    # Use the point with maximal distance from the boundaries as the center.
    if correct_center:
        cropped_center = max_dist_point

    cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
    cropped_center_mask[cropped_center] = 1

    # Compute the directed (vector) distances to the center point.
    if self.distances or self.directed_distances:
        this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
    else:
        this_distances = None

    # Keep only the requested center distances.
    if self.distances and self.directed_distances:
        # Undirected distances (vector norm) first, followed by the directed components.
        undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
        this_distances = np.concatenate([undir, this_distances], axis=-1)
    elif self.distances:
        # Keep only the undirected distances.
        this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)
    # If only directed distances are requested, the vector distances are already in place.

    # Add an extra channel for the boundary distances if specified.
    if self.boundary_distances:
        boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
        if this_distances is None:
            this_distances = boundary_distances
        else:
            this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

    # Set distances outside of the object to zero so they don't affect the normalization.
    this_distances[~cropped_mask] = 0

    # Normalize each channel by its maximal absolute value over the spatial axes.
    spatial_axes = tuple(range(mask.ndim))
    this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

    # `distances[bb]` is a view (basic slicing), so this writes into the global result.
    distances[bb][cropped_mask] = this_distances[cropped_mask]

    return distances