torch_em.util.prediction
from copy import deepcopy
from concurrent import futures
from typing import Tuple, Union, Callable, Any, List, Optional

import numpy as np
import nifty.tools as nt
import torch
from numpy.typing import ArrayLike

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from ..transform.raw import standardize


def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Callable[[Any], Any] = None
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    The input is padded at the end of each spatial axis (numpy "reflect" mode) until every extent is
    divisible by the corresponding factor, the model is applied once, and the padding is cropped from
    the spatial axes of the output.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
            Any sequence of ints (tuple or list) is accepted.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels. If True, the first axis of `input_`
            is the channel axis; it is never padded.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model. Only its trailing spatial axes are cropped back to the
            input's spatial shape; leading (batch / channel) axes are returned unchanged.
    """
    # Spatial shape before padding; used to crop the output afterwards.
    spatial_shape = input_.shape[1:] if with_channels else input_.shape

    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # The channel axis is never padded.
        min_divisible_ = (1,) + tuple(min_divisible)
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = tuple(min_divisible)

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        # Pad only at the end of each axis so the original data stays anchored at index 0.
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        # Crop only spatial axes: the model may produce a different number of channels than the input has.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    # The model input gets a batch axis, plus a channel axis if the data has none.
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    if device is None:
        device = next(model.parameters()).device

    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
    output = output.cpu().numpy()

    if crop_padding is not None:
        # Leave all leading (batch / channel) axes of the output untouched.
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output


def _pad_for_shift_left(arr, pad_vox, with_channels, mode="constant", constant_values=0.0):
    """Pad `arr` at the start (left) of every spatial axis by `pad_vox` voxels.

    Args:
        arr: The array to pad.
        pad_vox: Number of voxels to prepend per spatial axis.
        with_channels: Whether the first axis of `arr` is a channel axis (it is not padded).
        mode: The `np.pad` mode.
        constant_values: Fill value; only used for `mode="constant"`.

    Returns:
        The padded array and the per-axis left padding as a tuple.
    """
    pad_left = tuple(pad_vox)
    pad_width = tuple((pl, 0) for pl in pad_left)
    if with_channels:
        pad_width = ((0, 0),) + pad_width

    # np.pad rejects `constant_values` for every mode other than "constant".
    pad_kwargs = {"constant_values": constant_values} if mode == "constant" else {}
    arr_pad = np.pad(arr, pad_width, mode=mode, **pad_kwargs)
    return arr_pad, pad_left


def _crop_after_shift_left(arr, pad_left, with_channels, original_shape_spatial):
    """Undo `_pad_for_shift_left`: crop `original_shape_spatial` starting at `pad_left`.

    The channel axis (if `with_channels`) is kept whole. Returns a view into `arr`.
    """
    starts = pad_left
    stops = tuple(st + sh for st, sh in zip(starts, original_shape_spatial))
    spatial_slices = tuple(slice(st, sp) for st, sp in zip(starts, stops))
    return arr[(slice(None),) + spatial_slices] if with_channels else arr[spatial_slices]


def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
    """Load the block at `offset` of size `block_shape`, extended by `halo` on each side.

    Parts of the halo that reach outside the data are filled with `np.pad(..., mode=padding_mode)`,
    so the returned data always has spatial shape `block_shape + 2 * halo`.

    Returns:
        The (padded) data and the bounding box it corresponds to in input coordinates. If padding was
        applied, the bounding box is extended by the padding and may start below 0 or end past the shape.
    """
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    starts = [off - ha for off, ha in zip(offset, halo)]
    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]

    # we pad the input volume if necessary
    pad_left = None
    pad_right = None

    # check for padding to the left
    if any(start < 0 for start in starts):
        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
        starts = [max(0, start) for start in starts]

    # check for padding to the right
    if any(stop > shape[i] for i, stop in enumerate(stops)):
        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]

    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    if with_channels:
        data = input_[(slice(None),) + bb]
    else:
        data = input_[bb]

    ndim = len(shape)
    # pad if necessary
    if pad_left is not None or pad_right is not None:
        pad_left = (0,) * ndim if pad_left is None else pad_left
        pad_right = (0,) * ndim if pad_right is None else pad_right
        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
        if with_channels:
            pad_width = ((0, 0),) + pad_width
        data = np.pad(data, pad_width, mode=padding_mode)

        # extend the bounding box for downstream
        bb = tuple(
            slice(b.start - pl, b.stop + pr)
            for b, pl, pr in zip(bb, pad_left, pad_right)
        )

    return data, bb


def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    The spatial domain is divided into blocks of `block_shape`. Each block is loaded with `halo` context
    on every side (padded where it reaches past the data) and passed through the network. Only the
    inner, halo-free part of the prediction is written to `output`. Blocks are assigned round-robin to
    the devices in `gpu_ids` and processed by a thread pool with one thread per device.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed (this requires the model to
            have an `out_channels` attribute).
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
            It receives the loaded input block including the halo, before preprocessing.
        mask: Elements outside the mask will be ignored in the prediction (their values are set to zero).
            Blocks whose inner region contains no foreground in the mask are not predicted at all.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(net, input_tensor)`.
        roi: A region of interest of the input for which to run prediction.
            If `grid_shift` is given, the roi is applied to the internally shift-padded input.
        iter_list: Optional list of block ids to process. By default, all blocks are processed.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            Input and mask are zero-padded at the start of each axis by `np.rint(abs(shift) * block_shape)`
            voxels, and the padding is cropped from the returned output. This requires numpy inputs and
            `output=None`.

    Returns:
        The model output.

    Raises:
        ValueError: If `grid_shift` is combined with a user-provided `output`.
        TypeError: If `grid_shift` is given and `input_` or `mask` is not a numpy array.
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # Use the model itself on the device it already lives on, and a copy on every other device.
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(gpu_ids)

    # ---- original shape (spatial only) ----
    shape_spatial0 = input_.shape[1:] if with_channels else input_.shape
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # grid_shift pads internally and crops at the end, which only works for an output allocated here.
    # Validate this up front, before any padding work is done.
    if grid_shift is not None and output is not None:
        raise ValueError(
            "grid_shift is not supported together with a user-provided `output`, because "
            "grid_shift requires internal zero-padding and a final cropping step. "
            "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
            "Or pad the input manually beforehand."
        )

    # ---- apply grid_shift via padding+cropping (zero padding) ----
    input_eff, mask_eff = input_, mask
    if grid_shift is None:
        pad_left = (0,) * ndim
    else:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"
        # NOTE(review): the sign of the shift is ignored (abs), so +s and -s pad identically.
        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))

        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")
        input_eff, pad_left = _pad_for_shift_left(
            input_eff, pad_vox, with_channels=with_channels, mode="constant", constant_values=0
        )

        if mask_eff is not None:
            if not isinstance(mask_eff, np.ndarray):
                raise TypeError("grid_shift padding currently requires mask to be a numpy array")
            mask_eff, _ = _pad_for_shift_left(
                mask_eff, pad_vox, with_channels=False, mode="constant", constant_values=0
            )

    # shapes after shift-padding
    shape_spatial_eff = input_eff.shape[1:] if with_channels else input_eff.shape

    # ---- blocking (on the padded input) ----
    if roi is None:
        blocking = nt.blocking([0] * ndim, shape_spatial_eff, block_shape)
    else:
        assert len(roi) == ndim
        # NOTE(review): with grid_shift, roi is used in shift-padded coordinates (offset by pad_left
        # relative to the original input) — confirm this is intended.
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    # ---- output allocation (for padded shape) ----
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")

    def predict_block(block_id):
        # Round-robin assignment of blocks to devices.
        net, device = models[block_id % n_workers]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # Position of the block's inner region inside the loaded block + halo.
            # block.shape may be smaller than block_shape for blocks at the border.
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask_eff is not None:
                mask_block, _ = _load_block(mask_eff, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_eff, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list / tuple of tensors: only the first one is used
            if not torch.is_tensor(prediction):
                prediction = prediction[0]
            prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            inner_bb_pred = (slice(None),) + inner_bb if prediction.ndim == ndim + 1 else inner_bb
            prediction = prediction[inner_bb_pred]

            if mask_eff is not None:
                if prediction.ndim == ndim + 1:
                    mb = np.broadcast_to(mask_block[None], prediction.shape)
                else:
                    mb = mask_block
                prediction[~mb] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    iteration_ids = range(n_blocks) if iter_list is None else np.array(iter_list)

    # Consuming the map results with list() drives the work and re-raises worker exceptions here.
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # ---- crop away the shift padding so the returned output matches original shape ----
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim + 1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output
def
predict_with_padding( model: torch.nn.modules.module.Module, input_: numpy.ndarray, min_divisible: Tuple[int, ...], device: Union[torch.device, str, NoneType] = None, with_channels: bool = False, prediction_function: Callable[[Any], Any] = None) -> numpy.ndarray:
def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Callable[[Any], Any] = None
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    The input is padded at the end of each spatial axis (numpy "reflect" mode) until every extent is
    divisible by the corresponding factor, the model is applied once, and the padding is cropped from
    the spatial axes of the output.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
            Any sequence of ints (tuple or list) is accepted.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels. If True, the first axis of `input_`
            is the channel axis; it is never padded.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model. Only its trailing spatial axes are cropped back to the
            input's spatial shape; leading (batch / channel) axes are returned unchanged.
    """
    # Spatial shape before padding; used to crop the output afterwards.
    spatial_shape = input_.shape[1:] if with_channels else input_.shape

    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # The channel axis is never padded.
        min_divisible_ = (1,) + tuple(min_divisible)
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = tuple(min_divisible)

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        # Pad only at the end of each axis so the original data stays anchored at index 0.
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        # Crop only spatial axes: the model may produce a different number of channels than the input has.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    # The model input gets a batch axis, plus a channel axis if the data has none.
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    if device is None:
        device = next(model.parameters()).device

    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
    output = output.cpu().numpy()

    if crop_padding is not None:
        # Leave all leading (batch / channel) axes of the output untouched.
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output
Run prediction with padding for a model that can only deal with inputs divisible by specific factors.
Arguments:
- model: The model.
- input_: The input for prediction.
- min_divisible: The minimal factors the input shape must be divisible by. For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
- device: The device of the model. If not given, will be derived from the model parameters.
- with_channels: Whether the input data contains channels.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:
np.ndarray: The output of the model.
def
predict_with_halo( input_: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], model: torch.nn.modules.module.Module, gpu_ids: List[Union[str, int]], block_shape: Tuple[int, ...], halo: Tuple[int, ...], output: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], List[Tuple[Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], slice]], NoneType] = None, preprocess: Callable[[Union[torch.Tensor, numpy.ndarray]], Union[torch.Tensor, numpy.ndarray]] = <function standardize>, postprocess: Callable[[numpy.ndarray], numpy.ndarray] = None, with_channels: bool = False, skip_block: Callable[[Any], bool] = None, mask: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, disable_tqdm: bool = False, tqdm_desc: str = 'predict with halo', prediction_function: Optional[Callable] = None, roi: Optional[Tuple[slice]] = None, iter_list: Optional[List[int]] = None, grid_shift: Optional[Tuple[float, ...]] = None) -> Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], 
complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]]:
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    The spatial domain is divided into blocks of `block_shape`. Each block is loaded with `halo` context
    on every side (padded where it reaches past the data) and passed through the network. Only the
    inner, halo-free part of the prediction is written to `output`. Blocks are assigned round-robin to
    the devices in `gpu_ids` and processed by a thread pool with one thread per device.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed (this requires the model to
            have an `out_channels` attribute).
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
            It receives the loaded input block including the halo, before preprocessing.
        mask: Elements outside the mask will be ignored in the prediction (their values are set to zero).
            Blocks whose inner region contains no foreground in the mask are not predicted at all.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(net, input_tensor)`.
        roi: A region of interest of the input for which to run prediction.
            If `grid_shift` is given, the roi is applied to the internally shift-padded input.
        iter_list: Optional list of block ids to process. By default, all blocks are processed.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            Input and mask are zero-padded at the start of each axis by `np.rint(abs(shift) * block_shape)`
            voxels, and the padding is cropped from the returned output. This requires numpy inputs and
            `output=None`.

    Returns:
        The model output.

    Raises:
        ValueError: If `grid_shift` is combined with a user-provided `output`.
        TypeError: If `grid_shift` is given and `input_` or `mask` is not a numpy array.
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # Use the model itself on the device it already lives on, and a copy on every other device.
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(gpu_ids)

    # ---- original shape (spatial only) ----
    shape_spatial0 = input_.shape[1:] if with_channels else input_.shape
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # grid_shift pads internally and crops at the end, which only works for an output allocated here.
    # Validate this up front, before any padding work is done.
    if grid_shift is not None and output is not None:
        raise ValueError(
            "grid_shift is not supported together with a user-provided `output`, because "
            "grid_shift requires internal zero-padding and a final cropping step. "
            "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
            "Or pad the input manually beforehand."
        )

    # ---- apply grid_shift via padding+cropping (zero padding) ----
    input_eff, mask_eff = input_, mask
    if grid_shift is None:
        pad_left = (0,) * ndim
    else:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"
        # NOTE(review): the sign of the shift is ignored (abs), so +s and -s pad identically.
        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))

        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")
        input_eff, pad_left = _pad_for_shift_left(
            input_eff, pad_vox, with_channels=with_channels, mode="constant", constant_values=0
        )

        if mask_eff is not None:
            if not isinstance(mask_eff, np.ndarray):
                raise TypeError("grid_shift padding currently requires mask to be a numpy array")
            mask_eff, _ = _pad_for_shift_left(
                mask_eff, pad_vox, with_channels=False, mode="constant", constant_values=0
            )

    # shapes after shift-padding
    shape_spatial_eff = input_eff.shape[1:] if with_channels else input_eff.shape

    # ---- blocking (on the padded input) ----
    if roi is None:
        blocking = nt.blocking([0] * ndim, shape_spatial_eff, block_shape)
    else:
        assert len(roi) == ndim
        # NOTE(review): with grid_shift, roi is used in shift-padded coordinates (offset by pad_left
        # relative to the original input) — confirm this is intended.
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    # ---- output allocation (for padded shape) ----
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")

    def predict_block(block_id):
        # Round-robin assignment of blocks to devices.
        net, device = models[block_id % n_workers]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # Position of the block's inner region inside the loaded block + halo.
            # block.shape may be smaller than block_shape for blocks at the border.
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask_eff is not None:
                mask_block, _ = _load_block(mask_eff, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_eff, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list / tuple of tensors: only the first one is used
            if not torch.is_tensor(prediction):
                prediction = prediction[0]
            prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            inner_bb_pred = (slice(None),) + inner_bb if prediction.ndim == ndim + 1 else inner_bb
            prediction = prediction[inner_bb_pred]

            if mask_eff is not None:
                if prediction.ndim == ndim + 1:
                    mb = np.broadcast_to(mask_block[None], prediction.shape)
                else:
                    mb = mask_block
                prediction[~mb] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    iteration_ids = range(n_blocks) if iter_list is None else np.array(iter_list)

    # Consuming the map results with list() drives the work and re-raises worker exceptions here.
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # ---- crop away the shift padding so the returned output matches original shape ----
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim + 1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output
Run block-wise network prediction with a halo.
Arguments:
- input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
- model: The network.
- gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass ["cpu"].
- block_shape: The shape of the inner block to use for prediction.
- halo: The shape of the halo to use for prediction
- output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
- preprocess: Function to preprocess input data before passing it to the network.
- postprocess: Function to postprocess the network predictions.
- with_channels: Whether the input has a channel axis.
- skip_block: Function to evaluate whether a given input block will be skipped.
- mask: Elements outside the mask will be ignored in the prediction.
- disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
- tqdm_desc: Description shown by the tqdm output.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
- roi: A region of interest of the input for which to run prediction.
- grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
Returns:
The model output.