torch_em.util.prediction

from copy import deepcopy
from concurrent import futures
from typing import Tuple, Union, Callable, Any, List, Optional

import numpy as np
import nifty.tools as nt
import torch
from numpy.typing import ArrayLike

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from ..transform.raw import standardize


def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Callable[[Any], Any] = None
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that requires 2D inputs with shapes divisible by 16.
        device: The device of the model. If not given, it will be derived from the model parameters.
        with_channels: Whether the input data contains channels.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.

    Returns:
        np.ndarray: The output of the model.
    """
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        min_divisible_ = (1,) + min_divisible
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = min_divisible

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        crop_padding = tuple(slice(0, sh) for sh in input_.shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    if device is None:
        device = next(model.parameters()).device

    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output


def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    starts = [off - ha for off, ha in zip(offset, halo)]
    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]

    # we pad the input volume if necessary
    pad_left = None
    pad_right = None

    # check for padding to the left
    if any(start < 0 for start in starts):
        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
        starts = [max(0, start) for start in starts]

    # check for padding to the right
    if any(stop > shape[i] for i, stop in enumerate(stops)):
        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]

    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    if with_channels:
        data = input_[(slice(None),) + bb]
    else:
        data = input_[bb]

    ndim = len(shape)
    # pad if necessary
    if pad_left is not None or pad_right is not None:
        pad_left = (0,) * ndim if pad_left is None else pad_left
        pad_right = (0,) * ndim if pad_right is None else pad_right
        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
        if with_channels:
            pad_width = ((0, 0),) + pad_width
        data = np.pad(data, pad_width, mode=padding_mode)

        # extend the bounding box for downstream
        bb = tuple(
            slice(b.start - pl, b.stop + pr)
            for b, pl, pr in zip(bb, pad_left, pad_right)
        )

    return data, bb


def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    Args:
        input_: The input data; can be a numpy array, an hdf5/zarr/z5py dataset, or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag to disable the tqdm output (e.g. if the function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.

    Returns:
        The model output.
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]

    n_workers = len(gpu_ids)
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    ndim = len(shape)
    assert len(block_shape) == len(halo) == ndim

    if roi is None:
        blocking = nt.blocking([0] * ndim, shape, block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + shape, dtype="float32")

    def predict_block(block_id):
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = [beg for beg in block.begin]
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask is not None:
                mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                if mask_block.sum() == 0:
                    return

            inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            if prediction.ndim == ndim + 1:
                inner_bb = (slice(None),) + inner_bb
            prediction = prediction[inner_bb]

            if mask is not None:
                if prediction.ndim == ndim + 1:
                    mask_block = np.concatenate(prediction.shape[0] * [mask_block[None]], axis=0)
                prediction[~mask_block] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]

            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks, disable=disable_tqdm, desc=tqdm_desc))

    return output
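
The internal helper _load_block above extracts a block plus its halo and reflection-pads wherever the haloed region leaves the volume. A small sketch of this behavior (the array and shapes are illustrative assumptions):

import numpy as np
from torch_em.util.prediction import _load_block

data = np.arange(100, dtype="float32").reshape(10, 10)

# A block in the volume corner: the halo extends past the boundary at the
# top and left, and those regions are filled by reflection padding.
block, bb = _load_block(data, offset=(0, 0), block_shape=(5, 5), halo=(2, 2))
print(block.shape)  # (9, 9), i.e. block_shape + 2 * halo per dimension
print(bb)           # (slice(-2, 7), slice(-2, 7)): the bounding box extended into the padded region
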
def predict_with_padding(model: torch.nn.Module, input_: np.ndarray, min_divisible: Tuple[int, ...], device: Optional[Union[torch.device, str]] = None, with_channels: bool = False, prediction_function: Callable[[Any], Any] = None) -> np.ndarray:

Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

Arguments:
  • model: The model.
  • input_: The input for prediction.
  • min_divisible: The minimal factors the input shape must be divisible by. For example, (16, 16) for a model that requires 2D inputs with shapes divisible by 16.
  • device: The device of the model. If not given, it will be derived from the model parameters.
  • with_channels: Whether the input data contains channels.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:

np.ndarray: The output of the model.
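
A minimal usage sketch (the UNet2d model and the input shape are illustrative assumptions, not part of this function):

import numpy as np
from torch_em.model import UNet2d
from torch_em.util.prediction import predict_with_padding

# Example model: a 2D U-Net whose downsampling path requires input shapes divisible by 16.
model = UNet2d(in_channels=1, out_channels=2)
model.eval()

# (250, 250) is not divisible by 16, so the input is reflection-padded to
# (256, 256) and the prediction is cropped back to the original shape.
image = np.random.rand(250, 250).astype("float32")
prediction = predict_with_padding(model, image, min_divisible=(16, 16), device="cpu")
print(prediction.shape)  # (1, 2, 250, 250): the batch and channel axes are kept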

def predict_with_halo(input_: ArrayLike, model: torch.nn.Module, gpu_ids: List[Union[str, int]], block_shape: Tuple[int, ...], halo: Tuple[int, ...], output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None, preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize, postprocess: Callable[[np.ndarray], np.ndarray] = None, with_channels: bool = False, skip_block: Callable[[Any], bool] = None, mask: Optional[ArrayLike] = None, disable_tqdm: bool = False, tqdm_desc: str = "predict with halo", prediction_function: Optional[Callable] = None, roi: Optional[Tuple[slice]] = None) -> ArrayLike:

Run block-wise network prediction with a halo.

Arguments:
  • input_: The input data; can be a numpy array, an hdf5/zarr/z5py dataset, or similar.
  • model: The network.
  • gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass ["cpu"].
  • block_shape: The shape of the inner block to use for prediction.
  • halo: The shape of the halo to use for prediction.
  • output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
  • preprocess: Function to preprocess input data before passing it to the network.
  • postprocess: Function to postprocess the network predictions.
  • with_channels: Whether the input has a channel axis.
  • skip_block: Function to evaluate whether a given input block will be skipped.
  • mask: Elements outside the mask will be ignored in the prediction.
  • disable_tqdm: Flag to disable the tqdm output (e.g. if the function is called multiple times).
  • tqdm_desc: Description shown by the tqdm output.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.
  • roi: A region of interest of the input for which to run prediction.
Returns:

The model output.
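
A minimal usage sketch (the UNet2d model and all shapes are illustrative assumptions, not part of this function):

import numpy as np
from torch_em.model import UNet2d
from torch_em.util.prediction import predict_with_halo

model = UNet2d(in_channels=1, out_channels=2)
model.eval()

# Predict a large 2D image in 256 x 256 blocks, loading an extra halo of
# 32 pixels on every side so that boundary artifacts fall outside the inner block.
image = np.random.rand(1024, 1024).astype("float32")
prediction = predict_with_halo(
    image, model, gpu_ids=["cpu"],  # pass e.g. [0, 1] instead to parallelize over two GPUs
    block_shape=(256, 256), halo=(32, 32),
)
print(prediction.shape)  # (2, 1024, 1024): channel axis plus the spatial shape

To write prediction channels into separate pre-allocated arrays instead, pass for example output=[(foreground, np.s_[0]), (boundaries, np.s_[1])], where each slice selects the channel(s) written to the corresponding array.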