torch_em.util.prediction

  1from copy import deepcopy
  2from concurrent import futures
  3from typing import Tuple, Union, Callable, Any, List, Optional
  4
  5import numpy as np
  6import nifty.tools as nt
  7import torch
  8from numpy.typing import ArrayLike
  9
 10try:
 11    from napari.utils import progress as tqdm
 12except ImportError:
 13    from tqdm import tqdm
 14
 15from ..transform.raw import standardize
 16
 17
 18def predict_with_padding(
 19    model: torch.nn.Module,
 20    input_: np.ndarray,
 21    min_divisible: Tuple[int, ...],
 22    device: Optional[Union[torch.device, str]] = None,
 23    with_channels: bool = False,
 24    prediction_function: Callable[[Any], Any] = None
 25) -> np.ndarray:
 26    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.
 27
 28    Args:
 29        model: The model.
 30        input_: The input for prediction.
 31        min_divisible: The minimal factors the input shape must be divisible by.
 32            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
 33        device: The device of the model. If not given, will be derived from the model parameters.
 34        with_channels: Whether the input data contains channels.
 35        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
 36
 37    Returns:
 38        np.ndarray: The ouptut of the model.
 39    """
 40    if with_channels:
 41        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
 42        min_divisible_ = (1,) + min_divisible
 43    else:
 44        assert len(min_divisible) == input_.ndim
 45        min_divisible_ = min_divisible
 46
 47    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
 48        pad_width = tuple(
 49            (0, 0 if sh % md == 0 else md - sh % md)
 50            for sh, md in zip(input_.shape, min_divisible_)
 51        )
 52        crop_padding = tuple(slice(0, sh) for sh in input_.shape)
 53        input_ = np.pad(input_, pad_width, mode="reflect")
 54    else:
 55        crop_padding = None
 56
 57    ndim = input_.ndim
 58    ndim_model = 1 + ndim if with_channels else 2 + ndim
 59
 60    if device is None:
 61        device = next(model.parameters()).device
 62
 63    expand_dim = (None,) * (ndim_model - ndim)
 64    with torch.no_grad():
 65        model_input = torch.from_numpy(input_[expand_dim]).to(device)
 66        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
 67        output = output.cpu().numpy()
 68
 69    if crop_padding is not None:
 70        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
 71        output = output[crop_padding]
 72
 73    return output
 74
 75
 76def _pad_for_shift_left(arr, pad_vox, with_channels, mode="constant", constant_values=0.0):
 77    pad_left = tuple(pad_vox)
 78    pad_right = tuple(0 for _ in pad_vox)
 79
 80    pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
 81    if with_channels:
 82        pad_width = ((0, 0),) + pad_width
 83
 84    arr_pad = np.pad(arr, pad_width, mode=mode, constant_values=constant_values)
 85    return arr_pad, pad_left
 86
 87
 88def _crop_after_shift_left(arr, pad_left, with_channels, original_shape_spatial):
 89    starts = pad_left
 90    stops = tuple(st + sh for st, sh in zip(starts, original_shape_spatial))
 91    spatial_slices = tuple(slice(st, sp) for st, sp in zip(starts, stops))
 92    return arr[(slice(None),) + spatial_slices] if with_channels else arr[spatial_slices]
 93
 94
 95def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
 96    shape = input_.shape
 97    if with_channels:
 98        shape = shape[1:]
 99
100    starts = [off - ha for off, ha in zip(offset, halo)]
101    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]
102
103    # we pad the input volume if necessary
104    pad_left = None
105    pad_right = None
106
107    # check for padding to the left
108    if any(start < 0 for start in starts):
109        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
110        starts = [max(0, start) for start in starts]
111
112    # check for padding to the right
113    if any(stop > shape[i] for i, stop in enumerate(stops)):
114        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
115        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]
116
117    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
118    if with_channels:
119        data = input_[(slice(None),) + bb]
120    else:
121        data = input_[bb]
122
123    ndim = len(shape)
124    # pad if necessary
125    if pad_left is not None or pad_right is not None:
126        pad_left = (0,) * ndim if pad_left is None else pad_left
127        pad_right = (0,) * ndim if pad_right is None else pad_right
128        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
129        if with_channels:
130            pad_width = ((0, 0),) + pad_width
131        data = np.pad(data, pad_width, mode=padding_mode)
132
133        # extend the bounding box for downstream
134        bb = tuple(
135            slice(b.start - pl, b.stop + pr)
136            for b, pl, pr in zip(bb, pad_left, pad_right)
137        )
138
139    return data, bb
140
141
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    The input is tiled into blocks of `block_shape`. Each block is loaded together with the
    surrounding `halo`, passed through the model, and only the inner block of the prediction
    is written to the output.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.
        iter_list: The ids of the blocks to process. If None, all blocks are processed.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            The shift is realized by zero-padding the input (and mask) at the start of each axis by
            `round(abs(shift) * block_size)` voxels and cropping the output afterwards.
            It requires numpy arrays for `input_` and `mask` and cannot be combined with `output`.

    Returns:
        The model output.

    Raises:
        ValueError: If `grid_shift` is passed together with a user-provided `output`.
        TypeError: If `grid_shift` is passed and `input_` or `mask` is not a numpy array.
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # One model per device; the model is only copied if it does not already live on that device.
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(gpu_ids)

    # ---- original shape (spatial only) ----
    shape0 = input_.shape
    shape_spatial0 = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # ---- apply grid_shift via padding+cropping (zero padding) ----
    input_eff = input_
    mask_eff = mask

    if grid_shift is not None:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"

        # Fail fast: check this before the (potentially large) input is copied by the padding.
        if output is not None:
            raise ValueError(
                "grid_shift is not supported together with a user-provided `output`, because "
                "grid_shift requires internal zero-padding and a final cropping step. "
                "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
                "Or pad the input manually beforehand."
            )

        # The sign of the shift is ignored; padding is always added on the left.
        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))

        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")

        input_eff, pad_left = _pad_for_shift_left(input_eff, pad_vox, with_channels=with_channels, mode="constant",
                                                  constant_values=0)

        if mask_eff is not None:
            if not isinstance(mask_eff, np.ndarray):
                raise TypeError("grid_shift padding currently requires mask to be a numpy array")
            mask_eff, _ = _pad_for_shift_left(mask_eff, pad_vox, with_channels=False, mode="constant",
                                              constant_values=0)
    else:
        pad_left = (0,) * ndim

    # shapes after shift-padding
    shape_eff = input_eff.shape
    shape_spatial_eff = shape_eff[1:] if with_channels else shape_eff

    # ---- blocking (on the padded input) ----
    if roi is None:
        blocking = nt.blocking([0] * ndim, shape_spatial_eff, block_shape)
    else:
        assert len(roi) == ndim
        # NOTE(review): the roi bounds are applied to the shift-padded array as given, i.e. they are
        # not offset by `pad_left`. Confirm this is intended when `roi` and `grid_shift` are combined.
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    # ---- output allocation (for padded shape) ----
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")

    def predict_block(block_id):
        # Blocks are distributed round-robin over the available devices.
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # The inner block within the halo-extended block; uses the actual block shape,
            # which may be smaller than `block_shape` at the border.
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask_eff is not None:
                mask_block, _ = _load_block(mask_eff, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                # Nothing to predict if the inner block is fully masked out.
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_eff, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors: only the first entry is used in that case
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # Crop away the halo, keeping a leading channel axis if the prediction has one.
            if prediction.ndim == ndim + 1:
                inner_bb_pred = (slice(None),) + inner_bb
            else:
                inner_bb_pred = inner_bb
            prediction = prediction[inner_bb_pred]

            if mask_eff is not None:
                if prediction.ndim == ndim + 1:
                    mb = np.broadcast_to(mask_block[None], prediction.shape)
                else:
                    mb = mask_block
                prediction[~mb] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    iteration_ids = range(n_blocks) if iter_list is None else np.array(iter_list)

    with futures.ThreadPoolExecutor(n_workers) as tp:
        # Consuming the iterator via list() waits for all blocks and re-raises worker exceptions.
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # ---- crop away the shift padding so the returned output matches original shape ----
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim+1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output
def predict_with_padding( model: torch.nn.modules.module.Module, input_: numpy.ndarray, min_divisible: Tuple[int, ...], device: Union[torch.device, str, NoneType] = None, with_channels: bool = False, prediction_function: Callable[[Any], Any] = None) -> numpy.ndarray:
def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Callable[[Any], Any] = None
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    The input is reflect-padded at the end of each spatial axis up to the next multiple of
    `min_divisible`, passed through the model, and the padding is cropped from the spatial
    axes of the output again. Leading output axes (batch, channels) are never cropped.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model.
    """
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # The channel axis never needs padding, hence the factor 1.
        min_divisible_ = (1,) + tuple(min_divisible)
        spatial_shape = input_.shape[1:]
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = tuple(min_divisible)
        spatial_shape = input_.shape

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        # Only the spatial axes are cropped afterwards. Including the input channel extent here
        # would truncate the output channel axis whenever the model returns more channels than
        # the input has.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    # The model input gets a batch axis, and additionally a channel axis if the input has none.
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    if device is None:
        device = next(model.parameters()).device

    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # Keep all leading (non-spatial) output axes untouched.
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output

Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

Arguments:
  • model: The model.
  • input_: The input for prediction.
  • min_divisible: The minimal factors the input shape must be divisible by. For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
  • device: The device of the model. If not given, will be derived from the model parameters.
  • with_channels: Whether the input data contains channels.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:

np.ndarray: The output of the model.

def predict_with_halo( input_: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], model: torch.nn.modules.module.Module, gpu_ids: List[Union[str, int]], block_shape: Tuple[int, ...], halo: Tuple[int, ...], output: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], List[Tuple[Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], slice]], NoneType] = None, preprocess: Callable[[Union[torch.Tensor, numpy.ndarray]], Union[torch.Tensor, numpy.ndarray]] = <function standardize>, postprocess: Callable[[numpy.ndarray], numpy.ndarray] = None, with_channels: bool = False, skip_block: Callable[[Any], bool] = None, mask: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, disable_tqdm: bool = False, tqdm_desc: str = 'predict with halo', prediction_function: Optional[Callable] = None, roi: Optional[Tuple[slice]] = None, iter_list: Optional[List[int]] = None, grid_shift: Optional[Tuple[float, ...]] = None) -> Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], 
complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]]:
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    The input is tiled into blocks of `block_shape`. Each block is loaded together with the
    surrounding `halo`, passed through the model, and only the inner block of the prediction
    is written to the output.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.
        iter_list: The ids of the blocks to process. If None, all blocks are processed.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            The shift is realized by zero-padding the input (and mask) at the start of each axis by
            `round(abs(shift) * block_size)` voxels and cropping the output afterwards.
            It requires numpy arrays for `input_` and `mask` and cannot be combined with `output`.

    Returns:
        The model output.

    Raises:
        ValueError: If `grid_shift` is passed together with a user-provided `output`.
        TypeError: If `grid_shift` is passed and `input_` or `mask` is not a numpy array.
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # One model per device; the model is only copied if it does not already live on that device.
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(gpu_ids)

    # ---- original shape (spatial only) ----
    shape0 = input_.shape
    shape_spatial0 = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # ---- apply grid_shift via padding+cropping (zero padding) ----
    input_eff = input_
    mask_eff = mask

    if grid_shift is not None:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"

        # Fail fast: check this before the (potentially large) input is copied by the padding.
        if output is not None:
            raise ValueError(
                "grid_shift is not supported together with a user-provided `output`, because "
                "grid_shift requires internal zero-padding and a final cropping step. "
                "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
                "Or pad the input manually beforehand."
            )

        # The sign of the shift is ignored; padding is always added on the left.
        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))

        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")

        input_eff, pad_left = _pad_for_shift_left(input_eff, pad_vox, with_channels=with_channels, mode="constant",
                                                  constant_values=0)

        if mask_eff is not None:
            if not isinstance(mask_eff, np.ndarray):
                raise TypeError("grid_shift padding currently requires mask to be a numpy array")
            mask_eff, _ = _pad_for_shift_left(mask_eff, pad_vox, with_channels=False, mode="constant",
                                              constant_values=0)
    else:
        pad_left = (0,) * ndim

    # shapes after shift-padding
    shape_eff = input_eff.shape
    shape_spatial_eff = shape_eff[1:] if with_channels else shape_eff

    # ---- blocking (on the padded input) ----
    if roi is None:
        blocking = nt.blocking([0] * ndim, shape_spatial_eff, block_shape)
    else:
        assert len(roi) == ndim
        # NOTE(review): the roi bounds are applied to the shift-padded array as given, i.e. they are
        # not offset by `pad_left`. Confirm this is intended when `roi` and `grid_shift` are combined.
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    # ---- output allocation (for padded shape) ----
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")

    def predict_block(block_id):
        # Blocks are distributed round-robin over the available devices.
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # The inner block within the halo-extended block; uses the actual block shape,
            # which may be smaller than `block_shape` at the border.
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask_eff is not None:
                mask_block, _ = _load_block(mask_eff, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                # Nothing to predict if the inner block is fully masked out.
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_eff, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors: only the first entry is used in that case
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # Crop away the halo, keeping a leading channel axis if the prediction has one.
            if prediction.ndim == ndim + 1:
                inner_bb_pred = (slice(None),) + inner_bb
            else:
                inner_bb_pred = inner_bb
            prediction = prediction[inner_bb_pred]

            if mask_eff is not None:
                if prediction.ndim == ndim + 1:
                    mb = np.broadcast_to(mask_block[None], prediction.shape)
                else:
                    mb = mask_block
                prediction[~mb] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    iteration_ids = range(n_blocks) if iter_list is None else np.array(iter_list)

    with futures.ThreadPoolExecutor(n_workers) as tp:
        # Consuming the iterator via list() waits for all blocks and re-raises worker exceptions.
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # ---- crop away the shift padding so the returned output matches original shape ----
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim+1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output

Run block-wise network prediction with a halo.

Arguments:
  • input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
  • model: The network.
  • gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass ["cpu"].
  • block_shape: The shape of the inner block to use for prediction.
  • halo: The shape of the halo to use for prediction
  • output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
  • preprocess: Function to preprocess input data before passing it to the network.
  • postprocess: Function to postprocess the network predictions.
  • with_channels: Whether the input has a channel axis.
  • skip_block: Function to evaluate whether a given input block will be skipped.
  • mask: Elements outside the mask will be ignored in the prediction.
  • disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
  • tqdm_desc: Description shown by the tqdm output.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.
  • roi: A region of interest of the input for which to run prediction.
  • grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
Returns:

The model output.