torch_em.util.grid_search
"""Grid search for the post-processing parameters of instance segmentation models."""
from typing import Callable, Dict, List, Optional, Tuple

import bioimageio.core
import numpy as np
import torch.nn as nn
import xarray

try:
    from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
    from micro_sam.evaluation.instance_segmentation import (
        default_grid_search_values_instance_segmentation_with_decoder,
        evaluate_instance_segmentation_grid_search,
        run_instance_segmentation_grid_search,
        _get_range_of_search_values,
    )

    HAVE_MICRO_SAM = True
except ImportError:
    # Placeholder base class so that the wrapper classes below can still be defined without micro_sam.
    class InstanceSegmentationWithDecoder:
        def __init__(self, *args, **kwargs):
            pass

    def _get_range_of_search_values(input_vals, step):
        """Fallback that keeps the default-value helper usable when micro_sam is not installed.

        Returns the values from `input_vals[0]` to `input_vals[1]` (inclusive) in increments of `step`,
        rounded to two decimals. An input that is not a list is returned wrapped in a list.

        NOTE(review): written to mirror the micro_sam helper of the same name;
        confirm against the micro_sam version in use.
        """
        if not isinstance(input_vals, list):
            return [input_vals]
        start, stop = input_vals
        # Use an integer step count to avoid floating point drift at the upper bound.
        n_steps = int(round((stop - start) / step))
        return [round(start + i * step, 2) for i in range(n_steps + 1)]

    HAVE_MICRO_SAM = False

from ..transform.raw import standardize
from .prediction import predict_with_padding, predict_with_halo
from .segmentation import watershed_from_components


def default_grid_search_values_boundary_based_instance_segmentation(
    threshold1_values=None, threshold2_values=None, min_size_values=None,
):
    """@private

    Return the default grid search values for boundary based instance segmentation.

    Args:
        threshold1_values: The values to test for `threshold1`. By default [0.5, 0.55, 0.6].
        threshold2_values: The values to test for `threshold2`. By default the range from 0.5 to 0.9 in steps of 0.1.
        min_size_values: The values to test for `min_size`. By default [25, 50, 75, 100, 200].

    Returns:
        Dictionary with the keys "min_size", "threshold1" and "threshold2", each mapping to the values to test.
    """
    if threshold1_values is None:
        threshold1_values = [0.5, 0.55, 0.6]
    if threshold2_values is None:
        threshold2_values = _get_range_of_search_values(
            [0.5, 0.9], step=0.1
        )
    if min_size_values is None:
        min_size_values = [25, 50, 75, 100, 200]

    return {"min_size": min_size_values, "threshold1": threshold1_values, "threshold2": threshold2_values}


class _InstanceSegmentationBase(InstanceSegmentationWithDecoder):
    """Over-write micro_sam functionality so that it works for distance based segmentation with a U-net.

    Provides the shared prediction logic for torch models and bioimageio models.
    Subclasses implement `initialize`, which stores the individual prediction channels.
    """
    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
        # NOTE: the parent constructor is not invoked here; presumably because it expects
        # micro_sam specific arguments that do not apply to a plain model.
        self._model = model
        self._preprocess = standardize if preprocess is None else preprocess

        # Tiled prediction needs both a block shape and a halo; en bloc prediction needs neither.
        assert (block_shape is None) == (halo is None), "block_shape and halo must be passed together."
        self._block_shape = block_shape
        self._halo = halo
        self._is_initialized = False

    def _initialize_torch(self, data):
        """Run prediction with a torch model and return the model output for `data`.

        Without a block shape the preprocessed input is predicted en bloc with padding,
        otherwise tiled prediction with the given block shape and halo is used.
        """
        device = next(iter(self._model.parameters())).device

        if self._block_shape is None:
            # Derive the factors the input shape must be divisible by from the model architecture.
            if hasattr(self._model, "scale_factors"):
                # NOTE(review): the attribute check is on `scale_factors`, but the values are read
                # from `init_kwargs`; confirm that both are always available together.
                scale_factors = self._model.init_kwargs["scale_factors"]
                # NOTE(review): assumes three axes per scale factor - TODO confirm.
                min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
            elif hasattr(self._model, "depth"):
                depth = self._model.depth
                # NOTE(review): only two axes are given, so this branch presumably expects 2D inputs.
                min_divisible = [2**depth, 2**depth]
            else:
                raise RuntimeError(
                    "Could not derive the required input divisibility: "
                    "the model has neither a 'scale_factors' nor a 'depth' attribute."
                )
            input_ = self._preprocess(data)
            output = predict_with_padding(self._model, input_, min_divisible, device)
        else:
            output = predict_with_halo(
                data, self._model, [device], self._block_shape, self._halo,
                preprocess=self._preprocess,
            )
        return output

    def _initialize_modelzoo(self, data):
        """Run prediction with a bioimageio model and return the model output for `data`.

        Only en bloc prediction is supported; passing a block shape raises a NotImplementedError.
        """
        if self._block_shape is None:
            with bioimageio.core.create_prediction_pipeline(self._model) as pp:
                # Batch and channel axes are added in front of the spatial axes.
                dims = tuple("bcyx") if data.ndim == 2 else tuple("bczyx")
                input_ = xarray.DataArray(data[None, None], dims=dims)
                output = bioimageio.core.prediction.predict_with_padding(pp, input_, padding=True)[0]
                output = output.squeeze().values
        else:
            raise NotImplementedError("Tiled prediction is not implemented for bioimageio models.")
        return output


class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Wrapper for boundary based instance segmentation.

    Instances of this class can be passed to `instance_segmentation_grid_search`.

    Args:
        model: The model to evaluate. It must predict two channels:
            The first channel for foreground probabilities and the second for boundary probabilities.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    def __init__(
        self,
        model: nn.Module,
        preprocess: Optional[Callable] = None,
        block_shape: Optional[Tuple[int, ...]] = None,
        halo: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
        # The predictions are set in `initialize`.
        self._foreground = None
        self._boundaries = None

    def initialize(self, data):
        """@private

        Predict `data` with the model and store the foreground and boundary channels.
        """
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)
        assert output.shape[0] == 2, f"Expected a prediction with 2 channels, got {output.shape[0]}."

        self._foreground = output[0]
        self._boundaries = output[1]
        self._is_initialized = True

    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
        """@private

        Compute the segmentation from the stored predictions via `watershed_from_components`.
        The result is converted with `_to_masks` unless `output_mode` is None.
        """
        # Fail with a clear message instead of passing empty predictions to the watershed.
        if not self._is_initialized:
            raise RuntimeError(
                "BoundaryBasedInstanceSegmentation has not been initialized. Call 'initialize' first."
            )
        segmentation = watershed_from_components(
            self._boundaries, self._foreground,
            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
        )
        if output_mode is not None:
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation


class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Wrapper for distance based instance segmentation.

    Instances of this class can be passed to `instance_segmentation_grid_search`.

    Args:
        model: The model to evaluate. It must predict three channels:
            The first channel for foreground probabilities, the second for center distances
            and the third for boundary distances.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    def __init__(
        self,
        model: nn.Module,
        preprocess: Optional[Callable] = None,
        block_shape: Optional[Tuple[int, ...]] = None,
        halo: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)

        # The predictions are set in `initialize`.
        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None

    def initialize(self, data):
        """@private

        Predict `data` with the model and store the foreground, center distance and boundary distance channels.
        """
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        assert output.shape[0] == 3, f"Expected a prediction with 3 channels, got {output.shape[0]}."
        self._foreground = output[0]
        self._center_distances = output[1]
        self._boundary_distances = output[2]
        self._is_initialized = True


def instance_segmentation_grid_search(
    segmenter,
    image_paths: List[str],
    gt_paths: List[str],
    result_dir: str,
    grid_search_values: Optional[Dict] = None,
    rois: Optional[List[Tuple[slice, ...]]] = None,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
) -> Tuple[Dict, float]:
    """Run grid search for instance segmentation.

    Args:
        segmenter: The segmentation wrapper. Needs to provide an 'initialize' and a 'generate' method.
            The class `DistanceBasedInstanceSegmentation` can be used for models predicting distances
            for instance segmentation, `BoundaryBasedInstanceSegmentation` for models predicting boundaries.
        image_paths: The paths to the images to use for the grid search.
        gt_paths: The paths to the labels to use for the grid search.
        result_dir: The directory for caching the grid search results.
        grid_search_values: The values to test in the grid search.
            If None, default values are derived from the type of the segmenter.
        rois: Regions of interest to use for the evaluation. If given, must have the same length as `image_paths`.
        image_key: The key to the internal dataset with the image data.
            Leave None if the images are in a regular image format such as tif.
        gt_key: The key to the internal dataset with the label data.
            Leave None if the labels are in a regular image format such as tif.

    Returns:
        The best parameters found by the grid search.
        The best score of the grid search.

    Raises:
        RuntimeError: If micro_sam is not installed.
        ValueError: If `grid_search_values` is None and no defaults are known for the segmenter type.
    """
    if not HAVE_MICRO_SAM:
        raise RuntimeError(
            "The gridsearch functionality requires micro_sam. Install it via `conda install -c conda-forge micro_sam`."
        )

    if grid_search_values is None:
        if isinstance(segmenter, DistanceBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()
        elif isinstance(segmenter, BoundaryBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_boundary_based_instance_segmentation()
        else:
            raise ValueError(f"Could not derive default grid search values for segmenter of type {type(segmenter)}")

    # The grid search results are written to `result_dir` and evaluated from there afterwards.
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, image_paths, gt_paths, result_dir,
        embedding_dir=None, verbose_gs=True,
        image_key=image_key, gt_key=gt_key, rois=rois,
    )
    best_kwargs, best_score = evaluate_instance_segmentation_grid_search(
        result_dir, list(grid_search_values.keys())
    )
    return best_kwargs, best_score
class BoundaryBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Wrapper for boundary based instance segmentation.

    Instances of this class can be passed to `instance_segmentation_grid_search`.

    Args:
        model: The model to evaluate. It must predict two channels:
            The first channel for foreground probabilities and the second for boundary probabilities.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    def __init__(
        self,
        model: nn.Module,
        preprocess: Optional[Callable] = None,
        block_shape: Optional[Tuple[int, ...]] = None,
        halo: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
        # The predictions are set in `initialize`.
        self._foreground = None
        self._boundaries = None

    def initialize(self, data):
        """@private

        Predict `data` with the model and store the foreground and boundary channels.
        """
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)
        assert output.shape[0] == 2, f"Expected a prediction with 2 channels, got {output.shape[0]}."

        self._foreground = output[0]
        self._boundaries = output[1]
        self._is_initialized = True

    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
        """@private

        Compute the segmentation from the stored predictions via `watershed_from_components`.
        The result is converted with `_to_masks` unless `output_mode` is None.
        """
        # Fail with a clear message instead of passing empty predictions to the watershed.
        if not self._is_initialized:
            raise RuntimeError(
                "BoundaryBasedInstanceSegmentation has not been initialized. Call 'initialize' first."
            )
        segmentation = watershed_from_components(
            self._boundaries, self._foreground,
            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
        )
        if output_mode is not None:
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation
Wrapper for boundary based instance segmentation.
Instances of this class can be passed to instance_segmentation_grid_search
.
Arguments:
- model: The model to evaluate. It must predict two channels: The first channel for foreground probabilities and the second for boundary probabilities.
- preprocess: Optional preprocessing function to apply to the model inputs.
- block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
- halo: Halo for tiled prediction.
BoundaryBasedInstanceSegmentation( model: torch.nn.modules.module.Module, preprocess: Optional[Callable] = None, block_shape: Tuple[int, ...] = None, halo: Tuple[int, ...] = None)
def __init__(
    self,
    model: nn.Module,
    preprocess: Optional[Callable] = None,
    block_shape: Optional[Tuple[int, ...]] = None,
    halo: Optional[Tuple[int, ...]] = None,
):
    """Set up the wrapper for boundary based instance segmentation.

    Args:
        model: The model to evaluate.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)
    # No predictions are available yet; both attributes start out empty.
    self._foreground = None
    self._boundaries = None
class DistanceBasedInstanceSegmentation(_InstanceSegmentationBase):
    """Wrapper for distance based instance segmentation.

    Instances of this class can be passed to `instance_segmentation_grid_search`.

    Args:
        model: The model to evaluate. It must predict three channels:
            The first channel for foreground probabilities, the second for center distances
            and the third for boundary distances.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    def __init__(
        self,
        model: nn.Module,
        preprocess: Optional[Callable] = None,
        block_shape: Optional[Tuple[int, ...]] = None,
        halo: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)

        # The predictions are set in `initialize`.
        self._foreground = None
        self._center_distances = None
        self._boundary_distances = None

    def initialize(self, data):
        """@private

        Predict `data` with the model and store the foreground, center distance and boundary distance channels.
        """
        if isinstance(self._model, nn.Module):
            output = self._initialize_torch(data)
        else:
            output = self._initialize_modelzoo(data)

        assert output.shape[0] == 3, f"Expected a prediction with 3 channels, got {output.shape[0]}."
        self._foreground = output[0]
        self._center_distances = output[1]
        self._boundary_distances = output[2]
        self._is_initialized = True
Wrapper for distance based instance segmentation.
Instances of this class can be passed to instance_segmentation_grid_search
.
Arguments:
- model: The model to evaluate. It must predict three channels: The first channel for foreground probabilities, the second for center distances and the third for boundary distances.
- preprocess: Optional preprocessing function to apply to the model inputs.
- block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
- halo: Halo for tiled prediction.
DistanceBasedInstanceSegmentation( model: torch.nn.modules.module.Module, preprocess: Optional[Callable] = None, block_shape: Tuple[int, ...] = None, halo: Tuple[int, ...] = None)
def __init__(
    self,
    model: nn.Module,
    preprocess: Optional[Callable] = None,
    block_shape: Optional[Tuple[int, ...]] = None,
    halo: Optional[Tuple[int, ...]] = None,
):
    """Set up the wrapper for distance based instance segmentation.

    Args:
        model: The model to evaluate.
        preprocess: Optional preprocessing function to apply to the model inputs.
        block_shape: Optional block shape for tiled prediction. If None, the inputs will be predicted en bloc.
        halo: Halo for tiled prediction.
    """
    super().__init__(model=model, preprocess=preprocess, block_shape=block_shape, halo=halo)

    # No predictions are available yet; all three attributes start out empty.
    self._foreground = None
    self._center_distances = None
    self._boundary_distances = None
def
instance_segmentation_grid_search( segmenter, image_paths: List[str], gt_paths: List[str], result_dir: str, grid_search_values: Optional[Dict] = None, rois: Optional[List[Tuple[slice, ...]]] = None, image_key: Optional[str] = None, gt_key: Optional[str] = None) -> Tuple[Dict, float]:
def instance_segmentation_grid_search(
    segmenter,
    image_paths: List[str],
    gt_paths: List[str],
    result_dir: str,
    grid_search_values: Optional[Dict] = None,
    rois: Optional[List[Tuple[slice, ...]]] = None,
    image_key: Optional[str] = None,
    gt_key: Optional[str] = None,
) -> Tuple[Dict, float]:
    """Run grid search for instance segmentation.

    Args:
        segmenter: The segmentation wrapper. Needs to provide an 'initialize' and a 'generate' method.
            The class `DistanceBasedInstanceSegmentation` can be used for models predicting distances
            for instance segmentation, `BoundaryBasedInstanceSegmentation` for models predicting boundaries.
        image_paths: The paths to the images to use for the grid search.
        gt_paths: The paths to the labels to use for the grid search.
        result_dir: The directory for caching the grid search results.
        grid_search_values: The values to test in the grid search.
            If None, default values are derived from the type of the segmenter.
        rois: Regions of interest to use for the evaluation. If given, must have the same length as `image_paths`.
        image_key: The key to the internal dataset with the image data.
            Leave None if the images are in a regular image format such as tif.
        gt_key: The key to the internal dataset with the label data.
            Leave None if the labels are in a regular image format such as tif.

    Returns:
        The best parameters found by the grid search.
        The best score of the grid search.

    Raises:
        RuntimeError: If micro_sam is not installed.
        ValueError: If `grid_search_values` is None and no defaults are known for the segmenter type.
    """
    if not HAVE_MICRO_SAM:
        raise RuntimeError(
            "The gridsearch functionality requires micro_sam. Install it via `conda install -c conda-forge micro_sam`."
        )

    if grid_search_values is None:
        if isinstance(segmenter, DistanceBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()
        elif isinstance(segmenter, BoundaryBasedInstanceSegmentation):
            grid_search_values = default_grid_search_values_boundary_based_instance_segmentation()
        else:
            raise ValueError(f"Could not derive default grid search values for segmenter of type {type(segmenter)}")

    # The grid search results are written to `result_dir` and evaluated from there afterwards.
    run_instance_segmentation_grid_search(
        segmenter, grid_search_values, image_paths, gt_paths, result_dir,
        embedding_dir=None, verbose_gs=True,
        image_key=image_key, gt_key=gt_key, rois=rois,
    )
    best_kwargs, best_score = evaluate_instance_segmentation_grid_search(
        result_dir, list(grid_search_values.keys())
    )
    return best_kwargs, best_score
Run grid search for instance segmentation.
Arguments:
- segmenter: The segmentation wrapper. Needs to provide an 'initialize' and a 'generate' method.
The class
DistanceBasedInstanceSegmentation
can be used for models predicting distances for instance segmentation, BoundaryBasedInstanceSegmentation
for models predicting boundaries. - image_paths: The paths to the images to use for the grid search.
- gt_paths: The paths to the labels to use for the grid search.
- result_dir: The directory for caching the grid search results.
- grid_search_values: The values to test in the grid search.
- rois: Regions of interest to use for the evaluation. If given, must have the same length as
image_paths
. - image_key: The key to the internal dataset with the image data. Leave None if the images are in a regular image format such as tif.
- gt_key: The key to the internal dataset with the label data. Leave None if the labels are in a regular image format such as tif.
Returns:
The best parameters found by the grid search. The best score of the grid search.
class
InstanceSegmentationWithDecoder: