torch_em.util.grid_search
import numpy as np
import torch.nn as nn
import xarray

import bioimageio.core

from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
from micro_sam.evaluation.instance_segmentation import (
    default_grid_search_values_instance_segmentation_with_decoder,
    evaluate_instance_segmentation_grid_search,
    run_instance_segmentation_grid_search,
    _get_range_of_search_values,
)

from ..transform.raw import standardize
from .prediction import predict_with_padding, predict_with_halo
from .segmentation import watershed_from_components


def default_grid_search_values_boundary_based_instance_segmentation(
    threshold1_values=None,
    threshold2_values=None,
    min_size_values=None,
):
    """Return the default parameter grid for `BoundaryBasedInstanceSegmentation`.

    Every argument left as None is replaced by its default candidates.

    Args:
        threshold1_values: Candidate values for `threshold1`. Default: [0.5, 0.55, 0.6].
        threshold2_values: Candidate values for `threshold2`. Default: the values returned by
            micro_sam's `_get_range_of_search_values([0.5, 0.9], step=0.1)`.
        min_size_values: Candidate values for `min_size`. Default: [25, 50, 75, 100, 200].

    Returns:
        A dict mapping the keyword arguments of `BoundaryBasedInstanceSegmentation.generate`
        to their candidate values.
    """
    if threshold1_values is None:
        threshold1_values = [0.5, 0.55, 0.6]
    if threshold2_values is None:
        threshold2_values = _get_range_of_search_values(
            [0.5, 0.9], step=0.1
        )
    if min_size_values is None:
        min_size_values = [25, 50, 75, 100, 200]

    return {
        "min_size": min_size_values,
        "threshold1": threshold1_values,
        "threshold2": threshold2_values,
    }


class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
    """Shared setup for running a U-Net inside micro_sam's instance segmentation interface.

    Args:
        model: A `torch.nn.Module`, or a bioimage.io model that is passed to
            `bioimageio.core.create_prediction_pipeline`.
        preprocess: Preprocessing function applied to the input. Defaults to `standardize`.
        block_shape: Block shape for tiled prediction. Must be given together with `halo`.
        halo: Halo for tiled prediction. Must be given together with `block_shape`.
    """

    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        # NOTE: InstanceSegmentationWithDecoder.__init__ is deliberately not called: it expects
        # a SAM predictor and decoder, which are not used by the U-Net based segmenters.
        self._model = model
        self._preprocess = standardize if preprocess is None else preprocess

        # Tiled prediction needs both a block shape and a halo; full-image prediction needs neither.
        assert (block_shape is None) == (halo is None), \
            "block_shape and halo must either both be given or both be None"
        self._block_shape = block_shape
        self._halo = halo

        self._is_initialized = False

    def _initialize_torch(self, data):
        """Predict `data` with the torch model, in one padded pass or block-wise with a halo."""
        device = next(iter(self._model.parameters())).device

        if self._block_shape is None:
            input_ = self._preprocess(data)
            if hasattr(self._model, "scale_factors"):
                # NOTE(review): the check is on a `scale_factors` attribute but the values are read
                # from `init_kwargs`; confirm both are always present together.
                scale_factors = self._model.init_kwargs["scale_factors"]
                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
            elif hasattr(self._model, "depth"):
                # One entry per input axis, so 3D inputs work as well as 2D ones
                # (this was previously hard-coded to two axes).
                min_divisible = [2 ** self._model.depth] * input_.ndim
            else:
                raise RuntimeError(
                    f"Cannot determine the input divisibility for a model of type {type(self._model)}: "
                    "it has neither a 'scale_factors' nor a 'depth' attribute."
                )
            output = predict_with_padding(self._model, input_, min_divisible, device)
        else:
            output = predict_with_halo(
                data, self._model, [device], self._block_shape, self._halo,
                preprocess=self._preprocess,
            )
        return output

    def _initialize_modelzoo(self, data):
        """Predict `data` with a bioimage.io model. Tiled prediction is not supported."""
        if self._block_shape is None:
            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
                # Add leading batch and channel axes for the model input.
                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
                input_ = xarray.DataArray(data[None, None], dims=dims)
                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
                output = output.squeeze().values
        else:
            raise NotImplementedError
        return output


class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Instance segmentation from a U-Net that predicts foreground and boundaries.

    `initialize` runs the model and caches its two output channels (channel 0: foreground,
    channel 1: boundaries). `generate` turns them into instances with `watershed_from_components`.
    """

    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        super().__init__(
            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
        )

        self._foreground = None
        self._boundaries = None

    def initialize(self, data):
        """Run the model on `data` and cache the foreground and boundary predictions."""
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        assert output.shape[0] == 2

        self._foreground = output[0]
        self._boundaries = output[1]

        self._is_initialized = True

    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
        """Segment the currently initialized image.

        Args:
            min_size: Minimal object size, forwarded to `watershed_from_components`.
            threshold1: Forwarded to `watershed_from_components` as `threshold1`.
            threshold2: Forwarded to `watershed_from_components` as `threshold2`.
            output_mode: Mask format passed to `_to_masks`. Pass None to return the
                instance segmentation directly.

        Returns:
            The instance segmentation, converted by `_to_masks` unless `output_mode` is None.

        Raises:
            RuntimeError: If `initialize` has not been called yet.
        """
        # Fail clearly instead of passing None predictions to the watershed.
        if not self._is_initialized:
            raise RuntimeError(
                "BoundaryBasedInstanceSegmentation has not been initialized. Call initialize first."
            )
        segmentation = watershed_from_components(
            self._boundaries, self._foreground,
            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
        )
        if output_mode is not None:
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation


class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Override micro_sam functionality so that it works for distance-based
    segmentation with a U-Net.

    `initialize` caches the three model output channels (foreground, center distances,
    boundary distances). `generate` is inherited from micro_sam's
    `InstanceSegmentationWithDecoder`.
    """
    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        super().__init__(
            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
        )

        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None

    def initialize(self, data):
        """Run the model on `data` and cache the foreground and distance predictions."""
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        assert output.shape[0] == 3
        self._foreground = output[0]
        self._center_distances = output[1]
        self._boundary_distances = output[2]

        self._is_initialized = True


def instance_segmentation_grid_search(
    segmenter, image_paths, gt_paths, result_dir,
    grid_search_values=None, rois=None,
    image_key=None, gt_key=None,
):
    """Grid-search the segmentation parameters of `segmenter` and return the best setting.

    Args:
        segmenter: The segmenter to evaluate. If `grid_search_values` is None it must be a
            `DistanceBasedInstanceSegmentation` or `BoundaryBasedInstanceSegmentation`.
        image_paths: Paths to the input images.
        gt_paths: Paths to the ground-truth segmentations.
        result_dir: Directory passed to micro_sam's grid search run and evaluation functions.
        grid_search_values: Dict of parameter names to candidate values. Derived from the
            segmenter type if None.
        rois: Optional regions of interest, forwarded to the grid search.
        image_key: Optional key for loading the images, forwarded to the grid search.
        gt_key: Optional key for loading the ground truth, forwarded to the grid search.

    Returns:
        The tuple (best_kwargs, best_score) from `evaluate_instance_segmentation_grid_search`.

    Raises:
        ValueError: If no default grid can be derived for the segmenter type.
    """
    if grid_search_values is None:
        if isinstance(segmenter, DistanceBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()
        elif isinstance(segmenter, BoundaryBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_boundary_based_instance_segmentation()
        else:
            raise ValueError(f"Could not derive default grid search values for segmenter of type {type(segmenter)}")

    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, image_paths, gt_paths, result_dir,
        embedding_dir=None, verbose_gs=True,
        image_key=image_key, gt_key=gt_key, rois=rois,
    )
    best_kwargs, best_score = evaluate_instance_segmentation_grid_search(
        result_dir, list(grid_search_values.keys())
    )
    return best_kwargs, best_score
def
default_grid_search_values_boundary_based_instance_segmentation(threshold1_values=None, threshold2_values=None, min_size_values=None):
def default_grid_search_values_boundary_based_instance_segmentation(
    threshold1_values=None,
    threshold2_values=None,
    min_size_values=None,
):
    """Build the default search grid for boundary-based instance segmentation.

    Arguments that are None fall back to their default candidates.

    Args:
        threshold1_values: Candidates for `threshold1`; default [0.5, 0.55, 0.6].
        threshold2_values: Candidates for `threshold2`; default is
            `_get_range_of_search_values([0.5, 0.9], step=0.1)`.
        min_size_values: Candidates for `min_size`; default [25, 50, 75, 100, 200].

    Returns:
        Dict with the keys "min_size", "threshold1" and "threshold2", in that order.
    """
    first = [0.5, 0.55, 0.6] if threshold1_values is None else threshold1_values
    if threshold2_values is None:
        second = _get_range_of_search_values([0.5, 0.9], step=0.1)
    else:
        second = threshold2_values
    sizes = [25, 50, 75, 100, 200] if min_size_values is None else min_size_values
    return {"min_size": sizes, "threshold1": first, "threshold2": second}
class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Instance segmentation from a U-Net that predicts foreground and boundaries.

    `initialize` runs the model and caches its two output channels (channel 0: foreground,
    channel 1: boundaries). `generate` turns them into instances with `watershed_from_components`.
    """

    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        super().__init__(
            model=model, preprocess=preprocess, block_shape=block_shape, halo=halo
        )

        # Cached predictions, filled by initialize().
        self._foreground = None
        self._boundaries = None

    def initialize(self, data):
        """Run the model on `data` and cache the foreground and boundary predictions."""
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        assert output.shape[0] == 2

        self._foreground = output[0]
        self._boundaries = output[1]

        self._is_initialized = True

    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
        """Segment the currently initialized image.

        Args:
            min_size: Minimal object size, forwarded to `watershed_from_components`.
            threshold1: Forwarded to `watershed_from_components` as `threshold1`.
            threshold2: Forwarded to `watershed_from_components` as `threshold2`.
            output_mode: Mask format passed to `_to_masks`. Pass None to return the
                instance segmentation directly.

        Returns:
            The instance segmentation, converted by `_to_masks` unless `output_mode` is None.

        Raises:
            RuntimeError: If `initialize` has not been called yet.
        """
        # Fail clearly instead of passing None predictions to the watershed.
        if not self._is_initialized:
            raise RuntimeError(
                "BoundaryBasedInstanceSegmentation has not been initialized. Call initialize first."
            )
        segmentation = watershed_from_components(
            self._boundaries, self._foreground,
            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
        )
        if output_mode is not None:
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation
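Generates an instance segmentation without prompts, using a decoder. (This docstring is inherited from micro_sam's `InstanceSegmentationWithDecoder`; in this class the predictions come from a U-Net with a foreground and a boundary output channel.)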
Implements the same interface as `AutomaticMaskGenerator`.
Use this class as follows:
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image) # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation.
Arguments:
- predictor: The segment anything predictor.
- decoder: The decoder to predict intermediate representations for instance segmentation.
def
initialize(self, data):
def initialize(self, data):
    """Predict `data` and store the two output channels for later segmentation.

    Channel 0 is kept as the foreground and channel 1 as the boundaries. The torch
    path is used for `nn.Module` models, the bioimage.io path otherwise.
    """
    predict = self._initialize_torch if isinstance(self._model, nn.Module) else self._initialize_modelzoo
    prediction = predict(data)

    assert prediction.shape[0] == 2

    self._foreground, self._boundaries = prediction[0], prediction[1]
    self._is_initialized = True
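Initialize image embeddings and decoder predictions for an image. (Inherited micro_sam docstring: this override takes a single `data` argument, runs the U-Net on it and caches the foreground and boundary predictions; the arguments listed below do not apply to it.)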
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
def
generate( self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode='binary_mask'):
def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
    """Segment the currently initialized image.

    Args:
        min_size: Minimal object size, forwarded to `watershed_from_components`.
        threshold1: Forwarded to `watershed_from_components` as `threshold1`.
        threshold2: Forwarded to `watershed_from_components` as `threshold2`.
        output_mode: Mask format passed to `_to_masks`. Pass None to return the
            instance segmentation directly.

    Returns:
        The instance segmentation, converted by `_to_masks` unless `output_mode` is None.

    Raises:
        RuntimeError: If `initialize` has not been called yet.
    """
    # Fail clearly instead of passing None predictions to the watershed.
    if not self._is_initialized:
        raise RuntimeError(
            "BoundaryBasedInstanceSegmentation has not been initialized. Call initialize first."
        )
    segmentation = watershed_from_components(
        self._boundaries, self._foreground,
        min_size=min_size, threshold1=threshold1, threshold2=threshold2,
    )
    if output_mode is not None:
        segmentation = self._to_masks(segmentation, output_mode)
    return segmentation
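Generate instance segmentation for the currently initialized image. (Inherited micro_sam docstring: this override accepts only `min_size`, `threshold1`, `threshold2` and `output_mode`, and runs `watershed_from_components` on the cached boundary and foreground predictions; the distance and smoothing arguments listed below do not apply to it.)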
Arguments:
- center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
- boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
- foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
- foreground_threshold: Foreground predictions above this value will be used as foreground mask.
- distance_smoothing: Sigma value for smoothing the distance predictions.
- min_size: Minimal object size in the segmentation result.
- output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
Returns:
The instance segmentation masks.
Inherited Members
- micro_sam.instance_segmentation.InstanceSegmentationWithDecoder
- is_initialized
- get_state
- set_state
- clear_state
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Adapt micro_sam's decoder-based segmentation to a distance-predicting U-Net.

    The model must produce three channels (foreground, center distances, boundary
    distances). These are cached by `initialize`; `generate` comes from micro_sam's
    `InstanceSegmentationWithDecoder`.
    """
    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
        # Per-channel predictions, populated by initialize().
        self._foreground = self._center_distances = self._boundary_distances = None

    def initialize(self, data):
        """Predict `data` and cache the foreground and distance channels."""
        run = self._initialize_torch if isinstance(self._model, nn.Module) else self._initialize_modelzoo
        prediction = run(data)

        assert prediction.shape[0] == 3
        self._foreground, self._center_distances, self._boundary_distances = (
            prediction[0], prediction[1], prediction[2]
        )

        self._is_initialized = True
Overrides micro_sam functionality so that it works for distance-based segmentation with a U-Net.
def
initialize(self, data):
def initialize(self, data):
    """Predict `data` and cache the three output channels.

    Channel 0 becomes the foreground, channel 1 the center distances and channel 2
    the boundary distances. `nn.Module` models take the torch path; anything else
    goes through the bioimage.io path.
    """
    run = self._initialize_torch if isinstance(self._model, nn.Module) else self._initialize_modelzoo
    prediction = run(data)

    assert prediction.shape[0] == 3
    self._foreground, self._center_distances, self._boundary_distances = (
        prediction[0], prediction[1], prediction[2]
    )

    self._is_initialized = True
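Initialize image embeddings and decoder predictions for an image. (Inherited micro_sam docstring: this override takes a single `data` argument, runs the U-Net on it and caches the foreground, center-distance and boundary-distance predictions; the arguments listed below do not apply to it.)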
Arguments:
- image: The input image, volume or timeseries.
- image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- verbose: Whether to be verbose.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
Inherited Members
- micro_sam.instance_segmentation.InstanceSegmentationWithDecoder
- is_initialized
- generate
- get_state
- set_state
- clear_state
def
instance_segmentation_grid_search( segmenter, image_paths, gt_paths, result_dir, grid_search_values=None, rois=None, image_key=None, gt_key=None):
def instance_segmentation_grid_search(
    segmenter, image_paths, gt_paths, result_dir,
    grid_search_values=None, rois=None,
    image_key=None, gt_key=None,
):
    """Search segmentation parameters for `segmenter` and report the best combination.

    Args:
        segmenter: Segmenter to evaluate. When `grid_search_values` is None it must be a
            `DistanceBasedInstanceSegmentation` or `BoundaryBasedInstanceSegmentation`,
            whose type selects the default grid.
        image_paths: Input image paths.
        gt_paths: Ground-truth segmentation paths.
        result_dir: Directory handed to micro_sam's grid search run and evaluation.
        grid_search_values: Optional dict of parameter names to candidate values.
        rois: Optional regions of interest, forwarded to the grid search run.
        image_key: Optional image key, forwarded to the grid search run.
        gt_key: Optional ground-truth key, forwarded to the grid search run.

    Returns:
        (best_kwargs, best_score) as returned by `evaluate_instance_segmentation_grid_search`.

    Raises:
        ValueError: If no default grid is known for the segmenter's type.
    """
    if grid_search_values is None:
        default_factories = (
            (DistanceBasedInstanceSegmentation, default_grid_search_values_instance_segmentation_with_decoder),
            (BoundaryBasedInstanceSegmentation, default_grid_search_values_boundary_based_instance_segmentation),
        )
        for segmenter_type, make_grid in default_factories:
            if isinstance(segmenter, segmenter_type):
                grid_search_values = make_grid()
                break
        else:
            raise ValueError(f"Could not derive default grid search values for segmenter of type {type(segmenter)}")

    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, image_paths, gt_paths, result_dir,
        embedding_dir=None, verbose_gs=True,
        image_key=image_key, gt_key=gt_key, rois=rois,
    )
    parameter_names = list(grid_search_values.keys())
    best_kwargs, best_score = evaluate_instance_segmentation_grid_search(result_dir, parameter_names)
    return best_kwargs, best_score