torch_em.util.prediction
from concurrent import futures
from copy import deepcopy

import nifty.tools as nt
import numpy as np
import torch
from tqdm import tqdm

from ..transform.raw import standardize


def predict_with_padding(model, input_, min_divisible, device, with_channels=False, prediction_function=None):
    """Run prediction with padding for a model that can only deal with
    inputs divisible by specific factors.

    Each spatial axis of the input is padded on its right side (reflect mode) up to the next
    multiple of the corresponding factor in `min_divisible`. The padded input is passed through
    the model, and the prediction is cropped back to the original spatial shape.

    Arguments:
        model [torch.nn.Module]: the model
        input_ [np.ndarray]: the input data, optionally with a leading channel axis (see `with_channels`)
        min_divisible [tuple]: the divisible shape factors for the spatial axes
            (e.g. (16, 16) for a model that needs inputs divisible by at least 16 pixels)
        device [str, torch.device]: the device of the model
        with_channels [bool]: Whether the input data contains channels (default: False)
        prediction_function [callable] - A wrapper function for prediction to enable custom prediction procedures
            (default: None)
    Returns:
        np.ndarray: the output of the model (including the batch axis), cropped to the input's spatial shape
    """
    # accept any sequence (e.g. a list) for the divisibility factors
    min_divisible = tuple(min_divisible)
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # the channel axis is never padded
        min_divisible_ = (1,) + min_divisible
    else:
        assert len(min_divisible) == input_.ndim, f"{min_divisible}, {input_.ndim}"
        min_divisible_ = min_divisible

    spatial_shape = input_.shape[1:] if with_channels else input_.shape
    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        # Only crop the spatial axes: the model may return a different number of
        # channels than it received, so the input channel count must not be applied to the output.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    # the model expects a batch axis and (if not present already) a channel axis
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # leave all leading (batch / channel) axes untouched
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output


def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
    """Load a block of the input, enlarged by a halo on each side of every spatial axis.

    Parts of the enlarged block that lie outside of the input are filled with `np.pad`.

    Arguments:
        input_ [arraylike] - the data to load from (numpy array, hdf5/zarr/z5py dataset or similar)
        offset [sequence[int]] - start coordinate of the inner block along the spatial axes
        block_shape [sequence[int]] - shape of the inner block
        halo [sequence[int]] - amount of context loaded on both sides of each spatial axis
        padding_mode [str] - `np.pad` mode for regions outside of the input (default: "reflect")
        with_channels [bool] - whether `input_` has a leading channel axis, which is always loaded in full
            (default: False)
    Returns:
        np.ndarray - the data, spatial shape `block_shape + 2 * halo` (with a leading channel axis if with_channels)
        tuple[slice] - the spatial bounding box of the data in input coordinates; if padding was applied it
            extends beyond the input bounds (negative start or stop past the shape)
    """
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    starts = [off - ha for off, ha in zip(offset, halo)]
    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]

    # we pad the input volume if necessary
    pad_left = None
    pad_right = None

    # check for padding to the left
    if any(start < 0 for start in starts):
        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
        starts = [max(0, start) for start in starts]

    # check for padding to the right
    if any(stop > shape[i] for i, stop in enumerate(stops)):
        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]

    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    if with_channels:
        data = input_[(slice(None),) + bb]
    else:
        data = input_[bb]

    ndim = len(shape)
    # pad if necessary
    if pad_left is not None or pad_right is not None:
        pad_left = (0,) * ndim if pad_left is None else pad_left
        pad_right = (0,) * ndim if pad_right is None else pad_right
        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
        if with_channels:
            # never pad the channel axis
            pad_width = ((0, 0),) + pad_width
        data = np.pad(data, pad_width, mode=padding_mode)

        # extend the bounding box for downstream
        bb = tuple(
            slice(b.start - pl, b.stop + pr)
            for b, pl, pr in zip(bb, pad_left, pad_right)
        )

    return data, bb


# TODO half precision prediction
def predict_with_halo(
    input_,
    model,
    gpu_ids,
    block_shape,
    halo,
    output=None,
    preprocess=standardize,
    postprocess=None,
    with_channels=False,
    skip_block=None,
    mask=None,
    disable_tqdm=False,
    tqdm_desc="predict with halo",
    prediction_function=None,
    roi=None,
):
    """ Run block-wise network prediction with halo.

    The spatial extent of the input (or the `roi` within it) is divided into blocks of `block_shape`.
    Each block is loaded together with `halo` context on each side (reflect-padded at the border)
    and passed through the model. Only the inner, halo-free part of the prediction is written to `output`.
    Blocks are assigned round-robin to the devices in `gpu_ids`. They are processed by a thread pool
    with one thread per device.

    Arguments:
        input_ [arraylike] - the input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
        model [nn.Module] - the network
        gpu_ids [list[int or string]] - list of devices used for prediction, anything accepted by `torch.device`
        block_shape [tuple] - shape of inner block used for prediction
        halo [tuple] - shape of halo used for prediction
        output [arraylike or list[tuple[arraylike, slice]]] - output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and the corresponding channels.
            (default: None)
        preprocess [callable] - function to preprocess input data before passing it to the network.
            (default: standardize)
        postprocess [callable] - function to postprocess the network predictions (default: None)
        with_channels [bool] - whether the input has a channel axis (default: False)
        skip_block [callable] - function to evaluate whether a given input block should be skipped (default: None)
        mask [arraylike] - elements outside the mask will be ignored in the prediction.
            Non-zero values count as inside the mask; blocks without any element inside the mask are skipped.
            (default: None)
        disable_tqdm [bool] - flag that allows to disable tqdm output (e.g. if function is called multiple times)
        tqdm_desc [str] - description shown by the tqdm output
        prediction_function [callable] - a wrapper function for prediction to enable custom prediction procedures.
            (default: None)
        roi [tuple[slice]] - a region of interest for which to run prediction. (default: None)
    Returns:
        the output that was passed in, or the newly allocated float32 array of shape
        (model.out_channels,) + spatial shape
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # reuse the model on the device it already lives on, copy it to every other device
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]

    n_workers = len(gpu_ids)
    shape = input_.shape
    if with_channels:
        shape = shape[1:]
    ndim = len(shape)
    assert len(block_shape) == len(halo) == ndim

    if roi is None:
        blocking = nt.blocking([0] * ndim, shape, block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + shape, dtype="float32")

    def predict_block(block_id):
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # position of the inner block inside the halo-enlarged data
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask is not None:
                mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
                # Cast to bool so that `~` below is a logical negation; for integer masks it
                # would be a bitwise NOT and produce integer (fancy) indices instead of a mask.
                mask_block = mask_block[inner_bb].astype(bool)
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            if prediction.ndim == ndim + 1:
                inner_bb = (slice(None),) + inner_bb
            prediction = prediction[inner_bb]

            if mask is not None:
                # broadcast the spatial mask over any leading channel axis of the prediction
                prediction[~np.broadcast_to(mask_block, prediction.shape)] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks, disable=disable_tqdm, desc=tqdm_desc))

    return output
def
predict_with_padding( model, input_, min_divisible, device, with_channels=False, prediction_function=None):
def predict_with_padding(model, input_, min_divisible, device, with_channels=False, prediction_function=None):
    """Run prediction with padding for a model that can only deal with
    inputs divisible by specific factors.

    Each spatial axis of the input is padded on its right side (reflect mode) up to the next
    multiple of the corresponding factor in `min_divisible`. The padded input is passed through
    the model, and the prediction is cropped back to the original spatial shape.

    Arguments:
        model [torch.nn.Module]: the model
        input_ [np.ndarray]: the input data, optionally with a leading channel axis (see `with_channels`)
        min_divisible [tuple]: the divisible shape factors for the spatial axes
            (e.g. (16, 16) for a model that needs inputs divisible by at least 16 pixels)
        device [str, torch.device]: the device of the model
        with_channels [bool]: Whether the input data contains channels (default: False)
        prediction_function [callable] - A wrapper function for prediction to enable custom prediction procedures
            (default: None)
    Returns:
        np.ndarray: the output of the model (including the batch axis), cropped to the input's spatial shape
    """
    # accept any sequence (e.g. a list) for the divisibility factors
    min_divisible = tuple(min_divisible)
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # the channel axis is never padded
        min_divisible_ = (1,) + min_divisible
    else:
        assert len(min_divisible) == input_.ndim, f"{min_divisible}, {input_.ndim}"
        min_divisible_ = min_divisible

    spatial_shape = input_.shape[1:] if with_channels else input_.shape
    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        # Only crop the spatial axes: the model may return a different number of
        # channels than it received, so the input channel count must not be applied to the output.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    # the model expects a batch axis and (if not present already) a channel axis
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # leave all leading (batch / channel) axes untouched
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output
Run prediction with padding for a model that can only deal with inputs divisible by specific factors.
Arguments:
- model [torch.nn.Module]: the model
- input_ [np.ndarray]: the input data, optionally with a leading channel axis (see with_channels)
- min_divisible [tuple]: the divisible shape factors (e.g. (16, 16) for a model that needs inputs divisible by at least 16 pixels)
- device [str, torch.device]: the device of the model
- with_channels [bool]: Whether the input data contains channels (default: False)
- prediction_function [callable] - A wrapper function for prediction to enable custom prediction procedures (default: None)
Returns:
np.ndarray: the output of the model
def
predict_with_halo( input_, model, gpu_ids, block_shape, halo, output=None, preprocess=<function standardize>, postprocess=None, with_channels=False, skip_block=None, mask=None, disable_tqdm=False, tqdm_desc='predict with halo', prediction_function=None, roi=None):
def predict_with_halo(
    input_,
    model,
    gpu_ids,
    block_shape,
    halo,
    output=None,
    preprocess=standardize,
    postprocess=None,
    with_channels=False,
    skip_block=None,
    mask=None,
    disable_tqdm=False,
    tqdm_desc="predict with halo",
    prediction_function=None,
    roi=None,
):
    """ Run block-wise network prediction with halo.

    The spatial extent of the input (or the `roi` within it) is divided into blocks of `block_shape`.
    Each block is loaded together with `halo` context on each side (reflect-padded at the border)
    and passed through the model. Only the inner, halo-free part of the prediction is written to `output`.
    Blocks are assigned round-robin to the devices in `gpu_ids`. They are processed by a thread pool
    with one thread per device.

    Arguments:
        input_ [arraylike] - the input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
        model [nn.Module] - the network
        gpu_ids [list[int or string]] - list of devices used for prediction, anything accepted by `torch.device`
        block_shape [tuple] - shape of inner block used for prediction
        halo [tuple] - shape of halo used for prediction
        output [arraylike or list[tuple[arraylike, slice]]] - output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and the corresponding channels.
            (default: None)
        preprocess [callable] - function to preprocess input data before passing it to the network.
            (default: standardize)
        postprocess [callable] - function to postprocess the network predictions (default: None)
        with_channels [bool] - whether the input has a channel axis (default: False)
        skip_block [callable] - function to evaluate whether a given input block should be skipped (default: None)
        mask [arraylike] - elements outside the mask will be ignored in the prediction.
            Non-zero values count as inside the mask; blocks without any element inside the mask are skipped.
            (default: None)
        disable_tqdm [bool] - flag that allows to disable tqdm output (e.g. if function is called multiple times)
        tqdm_desc [str] - description shown by the tqdm output
        prediction_function [callable] - a wrapper function for prediction to enable custom prediction procedures.
            (default: None)
        roi [tuple[slice]] - a region of interest for which to run prediction. (default: None)
    Returns:
        the output that was passed in, or the newly allocated float32 array of shape
        (model.out_channels,) + spatial shape
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # reuse the model on the device it already lives on, copy it to every other device
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]

    n_workers = len(gpu_ids)
    shape = input_.shape
    if with_channels:
        shape = shape[1:]
    ndim = len(shape)
    assert len(block_shape) == len(halo) == ndim

    if roi is None:
        blocking = nt.blocking([0] * ndim, shape, block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + shape, dtype="float32")

    def predict_block(block_id):
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # position of the inner block inside the halo-enlarged data
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask is not None:
                mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
                # Cast to bool so that `~` below is a logical negation; for integer masks it
                # would be a bitwise NOT and produce integer (fancy) indices instead of a mask.
                mask_block = mask_block[inner_bb].astype(bool)
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            if prediction.ndim == ndim + 1:
                inner_bb = (slice(None),) + inner_bb
            prediction = prediction[inner_bb]

            if mask is not None:
                # broadcast the spatial mask over any leading channel axis of the prediction
                prediction[~np.broadcast_to(mask_block, prediction.shape)] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks, disable=disable_tqdm, desc=tqdm_desc))

    return output
Run block-wise network prediction with halo.
Arguments:
- input_ [arraylike] - the input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
- model [nn.Module] - the network
- gpu_ids [list[int or string]] - list of GPU ids used for prediction
- block_shape [tuple] - shape of inner block used for prediction
- halo [tuple] - shape of halo used for prediction
- output [arraylike or list[tuple[arraylike, slice]]] - output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and the corresponding channels. (default: None)
- preprocess [callable] - function to preprocess input data before passing it to the network. (default: standardize)
- postprocess [callable] - function to postprocess the network predictions (default: None)
- with_channels [bool] - whether the input has a channel axis (default: False)
- skip_block [callable] - function to evaluate whether a given input block should be skipped (default: None)
- mask [arraylike] - elements outside the mask will be ignored in the prediction (default: None)
- disable_tqdm [bool] - flag that allows to disable tqdm output (e.g. if function is called multiple times)
- tqdm_desc [str] - description shown by the tqdm output
- prediction_function [callable] - a wrapper function for prediction to enable custom prediction procedures. (default: None)
- roi [tuple[slice]] - a region of interest for which to run prediction. (default: None)