
  1from copy import deepcopy
  2from concurrent import futures
  3from typing import Tuple, Union, Callable, Any, List, Optional
  5import numpy as np
  6import as nt
  7import torch
  8from numpy.typing import ArrayLike
 11    from napari.utils import progress as tqdm
 12except ImportError:
 13    from tqdm import tqdm
 15from ..transform.raw import standardize
 18def predict_with_padding(
 19    model: torch.nn.Module,
 20    input_: np.ndarray,
 21    min_divisible: Tuple[int, ...],
 22    device: Optional[Union[torch.device, str]] = None,
 23    with_channels: bool = False,
 24    prediction_function: Callable[[Any], Any] = None
 25) -> np.ndarray:
 26    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.
 28    Args:
 29        model: The model.
 30        input_: The input for prediction.
 31        min_divisible: The minimal factors the input shape must be divisible by.
 32            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
 33        device: The device of the model. If not given, will be derived from the model parameters.
 34        with_channels: Whether the input data contains channels.
 35        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
 37    Returns:
 38        np.ndarray: The ouptut of the model.
 39    """
 40    if with_channels:
 41        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
 42        min_divisible_ = (1,) + min_divisible
 43    else:
 44        assert len(min_divisible) == input_.ndim
 45        min_divisible_ = min_divisible
 47    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
 48        pad_width = tuple(
 49            (0, 0 if sh % md == 0 else md - sh % md)
 50            for sh, md in zip(input_.shape, min_divisible_)
 51        )
 52        crop_padding = tuple(slice(0, sh) for sh in input_.shape)
 53        input_ = np.pad(input_, pad_width, mode="reflect")
 54    else:
 55        crop_padding = None
 57    ndim = input_.ndim
 58    ndim_model = 1 + ndim if with_channels else 2 + ndim
 60    if device is None:
 61        device = next(model.parameters()).device
 63    expand_dim = (None,) * (ndim_model - ndim)
 64    with torch.no_grad():
 65        model_input = torch.from_numpy(input_[expand_dim]).to(device)
 66        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
 67        output = output.cpu().numpy()
 69    if crop_padding is not None:
 70        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
 71        output = output[crop_padding]
 73    return output
 76def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
 77    shape = input_.shape
 78    if with_channels:
 79        shape = shape[1:]
 81    starts = [off - ha for off, ha in zip(offset, halo)]
 82    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]
 84    # we pad the input volume if necessary
 85    pad_left = None
 86    pad_right = None
 88    # check for padding to the left
 89    if any(start < 0 for start in starts):
 90        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
 91        starts = [max(0, start) for start in starts]
 93    # check for padding to the right
 94    if any(stop > shape[i] for i, stop in enumerate(stops)):
 95        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
 96        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]
 98    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
 99    if with_channels:
100        data = input_[(slice(None),) + bb]
101    else:
102        data = input_[bb]
104    ndim = len(shape)
105    # pad if necessary
106    if pad_left is not None or pad_right is not None:
107        pad_left = (0,) * ndim if pad_left is None else pad_left
108        pad_right = (0,) * ndim if pad_right is None else pad_right
109        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
110        if with_channels:
111            pad_width = ((0, 0),) + pad_width
112        data = np.pad(data, pad_width, mode=padding_mode)
114        # extend the bounding box for downstream
115        bb = tuple(
116            slice(b.start - pl, b.stop + pr)
117            for b, pl, pr in zip(bb, pad_left, pad_right)
118        )
120    return data, bb
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    with_channels: bool = False,
    skip_block: Optional[Callable[[Any], bool]] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    Each block is loaded together with its halo (padded where it extends beyond the input),
    predicted, cropped back to the inner block and written into `output`. Blocks are processed
    by a thread pool with one worker per device, and blocks are assigned to devices round-robin.
    Blocks that are skipped (empty mask or `skip_block` returning True) leave `output` unchanged.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
            If allocated here, it is a float32 zero array with `model.out_channels` channels.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag to disable the tqdm output (e.g. if the function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(net, input_tensor)`.
        roi: A region of interest of the input for which to run prediction.

    Returns:
        The model output.

    Raises:
        ValueError: If `gpu_ids` is empty.
    """
    if len(gpu_ids) == 0:
        raise ValueError("predict_with_halo needs at least one device in gpu_ids, e.g. ['cpu'].")

    devices = [torch.device(gpu) for gpu in gpu_ids]
    # Reuse the model on the device it already lives on; copy it to every other device.
    model_device = next(model.parameters()).device
    models = [
        (model if model_device == device else deepcopy(model).to(device), device)
        for device in devices
    ]

    n_workers = len(gpu_ids)
    shape = input_.shape[1:] if with_channels else input_.shape
    ndim = len(shape)
    assert len(block_shape) == len(halo) == ndim

    if roi is None:
        blocking = nt.blocking([0] * ndim, shape, block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + shape, dtype="float32")

    # Index that adds the batch axis (and the channel axis if the input has none).
    expand_dims = np.s_[None] if with_channels else np.s_[None, None]

    def predict_block(block_id):
        # Blocks are distributed round-robin over the available devices.
        net, device = models[block_id % n_workers]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            # Position of the inner block within the loaded block + halo.
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask is not None:
                mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # `preprocess` may return a numpy array or a torch tensor.
            inp = inp[expand_dims]
            inp = inp.to(device) if torch.is_tensor(inp) else torch.from_numpy(inp).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # Allow for models returning a list of tensors; only the first one is used.
            if not torch.is_tensor(prediction):
                prediction = prediction[0]
            prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # Crop the halo; keep a leading channel axis untouched if there is one.
            if prediction.ndim == ndim + 1:
                inner_bb = (slice(None),) + inner_bb
            prediction = prediction[inner_bb]

            if mask is not None:
                # Zero the prediction outside of the mask, broadcasting over leading axes.
                prediction[..., ~mask_block] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]

            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_workers) as tp:
        # Consuming the iterator drives the progress bar and re-raises worker exceptions.
        list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks, disable=disable_tqdm, desc=tqdm_desc))

    return output
Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

  • model: The model.
  • input_: The input for prediction.
  • min_divisible: The minimal factors the input shape must be divisible by. For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
  • device: The device of the model. If not given, will be derived from the model parameters.
  • with_channels: Whether the input data contains channels.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.

np.ndarray: The output of the model.

Run block-wise network prediction with a halo.

  • input_: The input data; can be a numpy array, an hdf5/zarr/z5py dataset or a similar array-like object.
  • model: The network.
  • gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass ["cpu"].
  • block_shape: The shape of the inner block to use for prediction.
  • halo: The shape of the halo to use for prediction.
  • output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
  • preprocess: Function to preprocess input data before passing it to the network.
  • postprocess: Function to postprocess the network predictions.
  • with_channels: Whether the input has a channel axis.
  • skip_block: Function to evaluate whether a given input block will be skipped.
  • mask: Elements outside the mask will be ignored in the prediction.
  • disable_tqdm: Flag to disable the tqdm output (e.g. if the function is called multiple times).
  • tqdm_desc: Description shown by the tqdm output.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.
  • roi: A region of interest of the input for which to run prediction.

The model output.