torch_em.util.prediction
from copy import deepcopy
from concurrent import futures
from typing import Tuple, Union, Callable, Any, List, Optional

import numpy as np
import nifty.tools as nt
import torch
from numpy.typing import ArrayLike

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from ..transform.raw import standardize


def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Callable[[Any], Any] = None
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    The input is padded at the end of each axis (reflect mode) until its shape is divisible
    by `min_divisible`. The padded region is cropped from the trailing (spatial) axes of the output.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model.
    """
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # The channel axis is never padded.
        min_divisible_ = (1,) + min_divisible
        spatial_shape = input_.shape[1:]
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = min_divisible
        spatial_shape = input_.shape

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        # Only the spatial axes are cropped afterwards: slicing the channel axis with the
        # number of input channels would drop output channels if the model has more of them.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    if device is None:
        device = next(model.parameters()).device

    # Add the batch axis, and the channel axis if the input does not have one.
    expand_dim = (None,) if with_channels else (None, None)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # The crop applies to the trailing axes; all leading axes are kept in full.
        output = output[(Ellipsis,) + crop_padding]

    return output


def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
    """Load a block with halo from the input and pad it where the halo reaches beyond the data.

    Args:
        input_: The data to read from. It must provide `shape` and support indexing with a tuple of slices.
        offset: The start coordinate of the inner block for each spatial axis.
        block_shape: The shape of the inner block.
        halo: The halo that is added on both sides of the block for each spatial axis.
        padding_mode: The mode passed to `np.pad` for the parts of the block outside of the data.
        with_channels: Whether the input has a leading channel axis. It is loaded in full and not padded.

    Returns:
        The loaded data, padded if necessary.
        The spatial bounding box of the block with halo. If padding was applied it extends
        beyond the data, i.e. it can have negative starts or stops larger than the shape.
    """
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    starts = [off - ha for off, ha in zip(offset, halo)]
    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]

    # we pad the input volume if necessary
    pad_left = None
    pad_right = None

    # check for padding to the left
    if any(start < 0 for start in starts):
        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
        starts = [max(0, start) for start in starts]

    # check for padding to the right
    if any(stop > shape[i] for i, stop in enumerate(stops)):
        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]

    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    if with_channels:
        data = input_[(slice(None),) + bb]
    else:
        data = input_[bb]

    ndim = len(shape)
    # pad if necessary
    if pad_left is not None or pad_right is not None:
        pad_left = (0,) * ndim if pad_left is None else pad_left
        pad_right = (0,) * ndim if pad_right is None else pad_right
        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
        if with_channels:
            pad_width = ((0, 0),) + pad_width
        data = np.pad(data, pad_width, mode=padding_mode)

        # extend the bounding box for downstream
        bb = tuple(
            slice(b.start - pl, b.stop + pr)
            for b, pl, pr in zip(bb, pad_left, pad_right)
        )

    return data, bb


def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.

    Returns:
        The model output.
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # One (model, device) pair per entry in gpu_ids; the model is copied unless it already is on that device.
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]

    n_workers = len(gpu_ids)
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    ndim = len(shape)
    assert len(block_shape) == len(halo) == ndim

    if roi is None:
        blocking = nt.blocking([0] * ndim, shape, block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    if output is None:
        # The number of output channels is read from the model's `out_channels` attribute.
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + shape, dtype="float32")

    def predict_block(block_id):
        # Blocks are distributed over the models in round-robin order.
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = [beg for beg in block.begin]
            # Location of the inner block within the loaded block (which includes the halo).
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask is not None:
                mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                # Nothing to predict if the inner block is fully outside of the mask.
                if mask_block.sum() == 0:
                    return

            inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # Remove the halo, keeping the channel axis if the prediction has one.
            if prediction.ndim == ndim + 1:
                inner_bb = (slice(None),) + inner_bb
            prediction = prediction[inner_bb]

            if mask is not None:
                if prediction.ndim == ndim + 1:
                    mask_block = np.concatenate(prediction.shape[0] * [mask_block[None]], axis=0)
                prediction[~mask_block] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]

            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks, disable=disable_tqdm, desc=tqdm_desc))

    return output
def
predict_with_padding( model: torch.nn.modules.module.Module, input_: numpy.ndarray, min_divisible: Tuple[int, ...], device: Union[torch.device, str, NoneType] = None, with_channels: bool = False, prediction_function: Callable[[Any], Any] = None) -> numpy.ndarray:
def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Callable[[Any], Any] = None
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    The input is padded at the end of each axis (reflect mode) until its shape is divisible
    by `min_divisible`. The padded region is cropped from the trailing (spatial) axes of the output.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model.
    """
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # The channel axis is never padded.
        min_divisible_ = (1,) + min_divisible
        spatial_shape = input_.shape[1:]
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = min_divisible
        spatial_shape = input_.shape

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        # Only the spatial axes are cropped afterwards: slicing the channel axis with the
        # number of input channels would drop output channels if the model has more of them.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    if device is None:
        device = next(model.parameters()).device

    # Add the batch axis, and the channel axis if the input does not have one.
    expand_dim = (None,) if with_channels else (None, None)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # The crop applies to the trailing axes; all leading axes are kept in full.
        output = output[(Ellipsis,) + crop_padding]

    return output
Run prediction with padding for a model that can only deal with inputs divisible by specific factors.
Arguments:
- model: The model.
- input_: The input for prediction.
- min_divisible: The minimal factors the input shape must be divisible by. For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
- device: The device of the model. If not given, will be derived from the model parameters.
- with_channels: Whether the input data contains channels.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:
np.ndarray: The output of the model.
def
predict_with_halo( input_: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], model: torch.nn.modules.module.Module, gpu_ids: List[Union[str, int]], block_shape: Tuple[int, ...], halo: Tuple[int, ...], output: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]], List[Tuple[Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], slice]], NoneType] = None, preprocess: Callable[[Union[torch.Tensor, numpy.ndarray]], Union[torch.Tensor, numpy.ndarray]] = <function standardize>, postprocess: Callable[[numpy.ndarray], numpy.ndarray] = None, with_channels: bool = False, skip_block: Callable[[Any], bool] = None, mask: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]], NoneType] = None, disable_tqdm: bool = False, tqdm_desc: str = 'predict with halo', prediction_function: Optional[Callable] = None, roi: Optional[Tuple[slice]] = None) -> Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], 
numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]]:
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
) -> ArrayLike:
    """Predict block by block, giving the network extra context (a halo) around each block.

    Each block is loaded together with its halo, passed through the network, and only
    the prediction for the inner block is written to the output.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.

    Returns:
        The model output.
    """
    # One (model, device) pair per entry in gpu_ids; the model is copied unless it already is on that device.
    models = []
    for device in [torch.device(gpu) for gpu in gpu_ids]:
        already_on_device = next(model.parameters()).device == device
        models.append((model if already_on_device else deepcopy(model).to(device), device))

    n_workers = len(gpu_ids)
    shape = input_.shape[1:] if with_channels else input_.shape
    ndim = len(shape)
    assert len(block_shape) == len(halo) == ndim

    if roi is not None:
        assert len(roi) == ndim
        # Open slice ends default to the full extent of the data.
        roi_begin = [0 if ro.start is None else ro.start for ro in roi]
        roi_end = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape)]
        blocking = nt.blocking(roi_begin, roi_end, block_shape)
    else:
        blocking = nt.blocking([0] * ndim, shape, block_shape)

    if output is None:
        # The number of output channels is read from the model's `out_channels` attribute.
        output = np.zeros((models[0][0].out_channels,) + shape, dtype="float32")

    def predict_block(block_id):
        # Blocks are distributed over the models in round-robin order.
        net, device = models[block_id % n_workers]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # Location of the inner block within the loaded block (which includes the halo).
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            use_mask = mask is not None
            if use_mask:
                mask_block = _load_block(mask, offset, block_shape, halo, with_channels=False)[0]
                mask_block = mask_block[inner_bb].astype("bool")
                # Nothing to predict if the inner block is fully outside of the mask.
                if not mask_block.any():
                    return

            inp = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)[0]

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # Add the batch axis, and the channel axis if the input does not have one.
            inp = inp[None] if with_channels else inp[None, None]
            inp = torch.from_numpy(inp).to(device)

            if prediction_function is None:
                prediction = net(inp)
            else:
                prediction = prediction_function(net, inp)

            # The network may return a list of tensors; in that case the first one is used.
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0].cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # Remove the halo, keeping the channel axis if the prediction has one.
            has_channel_axis = prediction.ndim == ndim + 1
            if has_channel_axis:
                inner_bb = (slice(None),) + inner_bb
            prediction = prediction[inner_bb]

            if use_mask:
                if has_channel_axis:
                    # Replicate the mask for each channel of the prediction.
                    mask_block = np.concatenate([mask_block[None]] * prediction.shape[0], axis=0)
                prediction[~mask_block] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):
                # Multiple outputs: each one receives its own channel slice of the prediction.
                for out, channel_slice in output:
                    out_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[out_bb] = prediction[channel_slice]
            else:
                # Single output array, with or without a channel axis.
                out_bb = (slice(None),) + bb if output.ndim == ndim + 1 else bb
                output[out_bb] = prediction

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_workers) as tp:
        block_results = tp.map(predict_block, range(n_blocks))
        # Exhaust the iterator so that all blocks are processed and the progress bar is advanced.
        list(tqdm(block_results, total=n_blocks, disable=disable_tqdm, desc=tqdm_desc))

    return output
Run block-wise network prediction with a halo.
Arguments:
- input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
- model: The network.
- gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
- block_shape: The shape of the inner block to use for prediction.
- halo: The shape of the halo to use for prediction
- output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
- preprocess: Function to preprocess input data before passing it to the network.
- postprocess: Function to postprocess the network predictions.
- with_channels: Whether the input has a channel axis.
- skip_block: Function to evaluate whether a given input block will be skipped.
- mask: Elements outside the mask will be ignored in the prediction.
- disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
- tqdm_desc: Description shown by the tqdm output.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
- roi: A region of interest of the input for which to run prediction.
Returns:
The model output.