torch_em.util.prediction

  1import queue
  2import threading
  3import warnings
  4from copy import deepcopy
  5from concurrent import futures
  6from typing import Tuple, Union, Callable, Any, List, Optional
  7
  8import numpy as np
  9import bioimage_cpp as bic
 10import torch
 11from numpy.typing import ArrayLike
 12
 13try:
 14    from napari.utils import progress as tqdm
 15except ImportError:
 16    from tqdm import tqdm
 17
 18from ..transform.raw import standardize
 19
 20
 21def predict_with_padding(
 22    model: torch.nn.Module,
 23    input_: np.ndarray,
 24    min_divisible: Tuple[int, ...],
 25    device: Optional[Union[torch.device, str]] = None,
 26    with_channels: bool = False,
 27    prediction_function: Callable[[Any], Any] = None
 28) -> np.ndarray:
 29    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.
 30
 31    Args:
 32        model: The model.
 33        input_: The input for prediction.
 34        min_divisible: The minimal factors the input shape must be divisible by.
 35            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
 36        device: The device of the model. If not given, will be derived from the model parameters.
 37        with_channels: Whether the input data contains channels.
 38        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
 39
 40    Returns:
 41        np.ndarray: The ouptut of the model.
 42    """
 43    if with_channels:
 44        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
 45        min_divisible_ = (1,) + min_divisible
 46    else:
 47        assert len(min_divisible) == input_.ndim
 48        min_divisible_ = min_divisible
 49
 50    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
 51        pad_width = tuple(
 52            (0, 0 if sh % md == 0 else md - sh % md)
 53            for sh, md in zip(input_.shape, min_divisible_)
 54        )
 55        crop_padding = tuple(slice(0, sh) for sh in input_.shape)
 56        input_ = np.pad(input_, pad_width, mode="reflect")
 57    else:
 58        crop_padding = None
 59
 60    ndim = input_.ndim
 61    ndim_model = 1 + ndim if with_channels else 2 + ndim
 62
 63    if device is None:
 64        device = next(model.parameters()).device
 65
 66    expand_dim = (None,) * (ndim_model - ndim)
 67    with torch.no_grad():
 68        model_input = torch.from_numpy(input_[expand_dim]).to(device)
 69        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
 70        output = output.cpu().numpy()
 71
 72    if crop_padding is not None:
 73        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
 74        output = output[crop_padding]
 75
 76    return output
 77
 78
 79def _pad_for_shift_left(arr, pad_vox, with_channels, mode="constant", constant_values=0.0):
 80    pad_left = tuple(pad_vox)
 81    pad_right = tuple(0 for _ in pad_vox)
 82
 83    pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
 84    if with_channels:
 85        pad_width = ((0, 0),) + pad_width
 86
 87    arr_pad = np.pad(arr, pad_width, mode=mode, constant_values=constant_values)
 88    return arr_pad, pad_left
 89
 90
 91def _crop_after_shift_left(arr, pad_left, with_channels, original_shape_spatial):
 92    starts = pad_left
 93    stops = tuple(st + sh for st, sh in zip(starts, original_shape_spatial))
 94    spatial_slices = tuple(slice(st, sp) for st, sp in zip(starts, stops))
 95    return arr[(slice(None),) + spatial_slices] if with_channels else arr[spatial_slices]
 96
 97
 98def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
 99    shape = input_.shape
100    if with_channels:
101        shape = shape[1:]
102
103    starts = [off - ha for off, ha in zip(offset, halo)]
104    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]
105
106    # we pad the input volume if necessary
107    pad_left = None
108    pad_right = None
109
110    # check for padding to the left
111    if any(start < 0 for start in starts):
112        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
113        starts = [max(0, start) for start in starts]
114
115    # check for padding to the right
116    if any(stop > shape[i] for i, stop in enumerate(stops)):
117        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
118        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]
119
120    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
121    if with_channels:
122        data = input_[(slice(None),) + bb]
123    else:
124        data = input_[bb]
125
126    ndim = len(shape)
127    # pad if necessary
128    if pad_left is not None or pad_right is not None:
129        pad_left = (0,) * ndim if pad_left is None else pad_left
130        pad_right = (0,) * ndim if pad_right is None else pad_right
131        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
132        if with_channels:
133            pad_width = ((0, 0),) + pad_width
134        data = np.pad(data, pad_width, mode=padding_mode)
135
136        # extend the bounding box for downstream
137        bb = tuple(
138            slice(b.start - pl, b.stop + pr)
139            for b, pl, pr in zip(bb, pad_left, pad_right)
140        )
141
142    return data, bb
143
144
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    The spatial extent (or `roi`) is split into blocks of `block_shape`. Each block is loaded together
    with `halo` on both sides, predicted, cropped back to the inner block and written to `output`.
    Blocks are processed by a thread pool with one thread per entry in `gpu_ids`.

    If `grid_shift` is given, the input (and mask) are zero-padded at the start of every spatial axis by
    `round(abs(shift) * block_shape)` elements. Prediction then runs on the padded data, and the padding
    is cropped from the returned output again.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.
            NOTE(review): when `grid_shift` is used, the roi is applied to the *padded* coordinates
            (it is not offset by the shift padding).
        iter_list: Optional list of block ids to iterate over.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            Only the magnitude is used; negative values shift the same way as positive ones.
    Returns:
        The model output.

    Raises:
        TypeError: If `grid_shift` is given and `input_` or `mask` is not a numpy array.
        ValueError: If `grid_shift` is given together with a user-provided `output`.
    """
    # one model replica per device; the passed model is reused if it already lives on that device
    devices = [torch.device(gpu) for gpu in gpu_ids]
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(gpu_ids)

    # original shape (spatial only)
    shape0 = input_.shape
    shape_spatial0 = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # apply grid_shift via padding+cropping (zero padding)
    input_eff = input_
    mask_eff = mask

    if grid_shift is not None:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"
        # number of elements to prepend per axis; abs() makes negative shifts equivalent to positive ones
        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))

        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")

        input_eff, pad_left = _pad_for_shift_left(input_eff, pad_vox, with_channels=with_channels, mode="constant",
                                                  constant_values=0)

        if mask_eff is not None:
            if not isinstance(mask_eff, np.ndarray):
                raise TypeError("grid_shift padding currently requires mask to be a numpy array")
            mask_eff, _ = _pad_for_shift_left(mask_eff, pad_vox, with_channels=False, mode="constant",
                                              constant_values=0)
    else:
        pad_left = (0,) * ndim

    # shapes after shift-padding
    shape_eff = input_eff.shape
    shape_spatial_eff = shape_eff[1:] if with_channels else shape_eff

    # blocking (on the padded input)
    if roi is None:
        blocking = bic.utils.Blocking([0] * ndim, list(shape_spatial_eff), block_shape)
    else:
        assert len(roi) == ndim
        # NOTE(review): roi bounds are used as-is on the padded shape, i.e. not shifted by pad_left
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = bic.utils.Blocking(blocking_start, blocking_stop, block_shape)

    # output allocation (for padded shape)
    if output is None:
        # assumes the model exposes an `out_channels` attribute
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")
    elif grid_shift:
        # NOTE: truthiness check; any non-empty grid_shift (even all zeros) triggers this error
        raise ValueError(
            "grid_shift is not supported together with a user-provided `output`, because "
            "grid_shift requires internal zero-padding and a final cropping step. "
            "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
            "Or pad the input manually beforehand."
        )

    def predict_block(block_id):
        # Select a model replica round-robin by block id.
        # NOTE(review): the thread pool does not pin threads to replicas, so two threads can use the
        # same replica at the same time; presumably fine for inference under no_grad — confirm.
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.get_block(block_id)
            offset = [beg for beg in block.begin]
            # The crop to the inner block uses block.shape, while loading below always uses the full
            # block_shape; presumably block.shape is smaller for border blocks.
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask_eff is not None:
                mask_block, _ = _load_block(mask_eff, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                # nothing inside the mask: skip the block without writing to the output
                if mask_block.sum() == 0:
                    return

            # load block_shape + halo on both sides (padded where it leaves the data)
            inp, _ = _load_block(input_eff, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                # a list/tuple of outputs: only the first one is used
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # crop the halo away (keeping the channel axis if present)
            if prediction.ndim == ndim + 1:
                inner_bb_pred = (slice(None),) + inner_bb
            else:
                inner_bb_pred = inner_bb
            prediction = prediction[inner_bb_pred]

            # zero out predictions outside of the mask
            if mask_eff is not None:
                if prediction.ndim == ndim + 1:
                    mb = np.broadcast_to(mask_block[None], prediction.shape)
                else:
                    mb = mask_block
                prediction[~mb] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.number_of_blocks
    iteration_ids = range(n_blocks) if iter_list is None else np.array(iter_list)

    # list(...) drains the lazy map so that all blocks run and worker exceptions propagate here
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # crop away the shift padding so the returned output matches original shape
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim+1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output
325
326
# Sentinel returned by _prepare_block_input when a block should be skipped, i.e. the mask is empty
# inside the inner block or `skip_block` returned True. Compared by identity (`is`).
_SKIP = object()
# Sentinel pushed onto the pipeline queues to signal end-of-stream. Each consumer / writer thread
# stops after receiving one, so one sentinel is pushed per downstream thread. Compared by identity.
_STOP = object()
331
332
class _Aborted(Exception):
    """@private
    Raised inside worker threads to unwind cleanly once `stop_event` is set.

    `_safe_get` and `_safe_put` raise it when the shared stop event is set. The pipeline's worker
    threads catch it and exit silently; any error that triggered the stop is recorded separately.
    """
336
337
class _AtomicCounter:
    """@private
    Integer counter whose decrement is guarded by a lock.

    Used for sentinel reference counting: the thread that brings the count to zero knows it is
    the last one of its group to finish.
    """

    def __init__(self, value: int):
        self._lock = threading.Lock()
        self._value = value

    def decrement(self) -> int:
        """Subtract one under the lock and return the resulting value."""
        self._lock.acquire()
        try:
            self._value -= 1
            remaining = self._value
        finally:
            self._lock.release()
        return remaining
350
351
class _BlockJob:
    """@private
    Per-block record handed from the producers to the consumers and on to the writers.

    Holds the block, its inner bounding box, its (optional) mask, the input tensor and,
    once predicted, the numpy prediction.
    """

    __slots__ = ("block", "inner_bb", "mask_block", "tensor", "prediction")

    def __init__(self, block, inner_bb, mask_block, tensor):
        self.block, self.inner_bb, self.mask_block = block, inner_bb, mask_block
        # CPU input tensor with a leading batch axis; the consumer resets it to None after prediction
        self.tensor = tensor
        # set by the consumer once the block has been predicted
        self.prediction = None
364
365
def _safe_get(q, stop_event, timeout=0.2):
    """@private
    Take an item from `q`, polling every `timeout` seconds so a set `stop_event` is noticed.

    Raises `_Aborted` instead of blocking forever once `stop_event` is set.
    """
    while True:
        if stop_event.is_set():
            raise _Aborted()
        try:
            return q.get(timeout=timeout)
        except queue.Empty:
            pass  # nothing yet: re-check the stop flag and wait again
375
376
def _safe_put(q, item, stop_event, timeout=0.2):
    """@private
    Put `item` on `q`, retrying every `timeout` seconds while the queue is full.

    Raises `_Aborted` instead of blocking forever once `stop_event` is set.
    """
    while True:
        if stop_event.is_set():
            raise _Aborted()
        try:
            q.put(item, timeout=timeout)
        except queue.Full:
            continue  # still full: re-check the stop flag and retry
        else:
            return
387
388
def _prepare_block_input(input_, mask, block, block_shape, halo, with_channels, skip_block, preprocess):
    """@private
    Producer-side work for one block: load it, apply the mask / skip checks and preprocess it.

    Returns:
        The `_SKIP` sentinel if the block is empty inside the mask or `skip_block` rejects it.
        Otherwise a tuple `(cpu_tensor, mask_block, inner_bb)`: `cpu_tensor` has a leading batch
        axis (plus a singleton channel axis when the input has none), `mask_block` is the boolean
        mask of the inner block (or None without a mask), and `inner_bb` selects the inner block
        inside the halo-extended data.
    """
    offset = list(block.begin)
    inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

    mask_block = None
    if mask is not None:
        loaded_mask, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
        mask_block = loaded_mask[inner_bb].astype("bool")
        if mask_block.sum() == 0:
            return _SKIP

    data, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)
    if skip_block is not None and skip_block(data):
        return _SKIP

    if preprocess is not None:
        data = preprocess(data)

    # prepend the batch axis (and a singleton channel axis if the input has none)
    leading_axes = np.s_[None] if with_channels else np.s_[None, None]
    return torch.from_numpy(data[leading_axes]), mask_block, inner_bb
418
419
def _write_prediction(prediction, block, output, ndim, mask_block, inner_bb, postprocess):
    """@private
    Writer-side work for one block: postprocess, crop the halo, zero out masked voxels and store.

    Args:
        prediction: The (halo-extended) prediction for one block, with or without a channel axis.
        block: The block; its `begin`/`end` give the target region in `output`.
        output: A single output array, or a list of `(array, channel_selection)` pairs.
        ndim: Number of spatial dimensions.
        mask_block: Boolean mask of the inner block, or None.
        inner_bb: Slices selecting the inner block inside the halo-extended prediction.
        postprocess: Optional function applied to the full prediction first.
    """
    if postprocess is not None:
        prediction = postprocess(prediction)

    has_channel_axis = prediction.ndim == ndim + 1
    crop = ((slice(None),) + inner_bb) if has_channel_axis else inner_bb
    prediction = prediction[crop]

    if mask_block is not None:
        keep = np.broadcast_to(mask_block[None], prediction.shape) if has_channel_axis else mask_block
        prediction[~keep] = 0

    target = tuple(slice(lo, hi) for lo, hi in zip(block.begin, block.end))

    if not isinstance(output, list):
        # single output array
        output[(slice(None),) + target if output.ndim == ndim + 1 else target] = prediction
        return

    # multiple outputs, each receiving its own channel selection
    for out, channel_slice in output:
        out[target if out.ndim == ndim else (slice(None),) + target] = prediction[channel_slice]
448
449
def _concurrent_write_safe(arr, block_shape, start):
    """@private
    Decide whether several threads may write blocks of the given grid into `arr` at once.

    - numpy arrays: always safe, because writers touch disjoint memory regions.
    - hdf5 datasets (h5py): never safe, because concurrent writes are not supported.
    - chunked stores (zarr / n5): safe only if block size and grid start are multiples of the
      atomic write unit (the shard when `.shards` is set, otherwise the chunk). Only the
      spatial axes are compared, since every block writes the full channel range.
    - anything without `.chunks`: treated as unsafe.
    """
    if isinstance(arr, np.ndarray):
        return True
    if type(arr).__module__.startswith("h5py"):
        return False

    chunks = getattr(arr, "chunks", None)
    if chunks is None:
        return False

    shards = getattr(arr, "shards", None)
    write_unit = chunks if shards is None else shards
    spatial_unit = tuple(write_unit[-len(block_shape):])

    return (
        all(bs % u == 0 for bs, u in zip(block_shape, spatial_unit))
        and all(s % u == 0 for s, u in zip(start, spatial_unit))
    )
485
486
def predict_with_halo_pipelined(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo (pipelined)",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    batch_size: int = 1,
    num_prefetch_workers: int = 4,
    queue_size: Optional[int] = None,
    num_write_workers: int = 1,
    write_queue_size: Optional[int] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo, pipelined for higher GPU throughput.

    This is an alternate implementation of `predict_with_halo` that decouples block
    loading, GPU prediction and output writing into a producer-consumer pipeline
    connected by queues:

        producers (CPU threads: load + preprocess) -> input queue
          -> consumer(s), one per GPU (stack a batch, predict, unstack) -> output queue
          -> writer thread(s) (postprocess + write).

    While the GPU works on one batch, the prefetch workers load and preprocess the
    next blocks and the writer drains finished predictions, keeping the GPU fed.
    Blocks can additionally be stacked into batches for one forward pass via `batch_size`.

    The pipeline is thread-based (not multiprocessing) so that lazy hdf5/zarr/n5 inputs
    (whose file handles are not fork/pickle-safe) work, and so that writers can share the
    output array directly. Note that heavy *Python-level* `preprocess`/`postprocess`
    callbacks will not parallelize across prefetch workers due to the GIL; the default
    `standardize` is numpy-vectorized and releases the GIL.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
            One prediction consumer thread (with its own model replica) is run per device.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It must operate on the leading batch axis; with the default `batch_size=1` it does not need changes.
        roi: A region of interest of the input for which to run prediction.
        iter_list: Optional list of block ids to iterate over.
        batch_size: The number of blocks stacked into a single forward pass. Trades GPU memory for throughput.
        num_prefetch_workers: The number of CPU threads used to load and preprocess blocks.
        queue_size: The maximum size of the input (prefetch) queue. Provides backpressure to bound memory use.
            If None, a value derived from the number of devices and `batch_size` is used.
        num_write_workers: The number of threads used to write predictions. Values > 1 are safe for in-memory
            numpy outputs, and for zarr/n5 outputs whose chunks (and shards, for zarr v3) are aligned with
            block_shape. For hdf5, misaligned zarr/n5, or other outputs this is automatically clamped to 1.
        write_queue_size: The maximum size of the output (write) queue. If None, a default is used.
        grid_shift: Not supported by this function; raises NotImplementedError if passed. Use `predict_with_halo`.

    Returns:
        The model output.

    Raises:
        NotImplementedError: If `grid_shift` is passed.
        Exception: The first exception raised in any producer, consumer or writer thread is re-raised
            in the calling thread after the pipeline has been shut down.
    """
    if grid_shift is not None:
        raise NotImplementedError(
            "grid_shift is not supported by predict_with_halo_pipelined. "
            "Use predict_with_halo for grid_shift, or pre-pad the input and use roi."
        )

    batch_size = max(1, int(batch_size))
    num_prefetch_workers = max(1, int(num_prefetch_workers))
    num_write_workers = max(1, int(num_write_workers))

    # one model replica per device; the passed model is reused if it already lives on that device
    devices = [torch.device(gpu) for gpu in gpu_ids]
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_consumers = len(devices)

    shape0 = input_.shape
    shape_spatial = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial)
    assert len(block_shape) == len(halo) == ndim

    # blocking
    if roi is None:
        block_start = [0] * ndim
        blocking = bic.utils.Blocking(block_start, list(shape_spatial), block_shape)
    else:
        assert len(roi) == ndim
        block_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial)]
        blocking = bic.utils.Blocking(block_start, blocking_stop, block_shape)

    # output allocation
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial), dtype="float32")

    # guard against unsafe concurrent writes: numpy is always safe (disjoint regions),
    # zarr/n5 are safe when their chunks/shards are aligned with the blocks, hdf5 is not.
    if num_write_workers > 1:
        out_arrays = [o for o, _ in output] if isinstance(output, list) else [output]
        if any(not _concurrent_write_safe(o, block_shape, block_start) for o in out_arrays):
            warnings.warn(
                "num_write_workers > 1 requires either an in-memory numpy output or a zarr/n5 "
                "output whose chunks (and shards, for zarr v3) are aligned with block_shape; "
                "falling back to a single writer. HDF5 outputs are never safe for concurrent writes."
            )
            num_write_workers = 1

    # queue sizes
    if queue_size is None:
        queue_size = max(2 * n_consumers * batch_size, 2 * batch_size)
    queue_size = max(queue_size, batch_size)
    if write_queue_size is None:
        write_queue_size = max(2 * n_consumers, 4)

    n_blocks = blocking.number_of_blocks
    iteration_ids = list(range(n_blocks)) if iter_list is None else list(iter_list)
    total = len(iteration_ids)

    # pre-fill the block-id queue with all ids followed by one STOP per producer
    id_queue = queue.Queue()
    for bid in iteration_ids:
        id_queue.put(bid)
    for _ in range(num_prefetch_workers):
        id_queue.put(_STOP)

    input_queue = queue.Queue(maxsize=queue_size)
    output_queue = queue.Queue(maxsize=write_queue_size)

    stop_event = threading.Event()
    error_box = []
    error_lock = threading.Lock()
    progress_lock = threading.Lock()
    pbar = tqdm(total=total, disable=disable_tqdm, desc=tqdm_desc)

    remaining_producers = _AtomicCounter(num_prefetch_workers)
    remaining_consumers = _AtomicCounter(n_consumers)

    def record_error(exc):
        # keep only the first error, then make every other thread unwind
        with error_lock:
            if not error_box:
                error_box.append(exc)
        stop_event.set()

    def broadcast_stop(q, count):
        # Push `count` STOP sentinels onto the bounded queue `q`. This must use the abortable put:
        # with a plain blocking `q.put`, the thread would hang forever if the downstream threads
        # die (setting `stop_event`) while `q` is full, because nobody drains `q` anymore.
        try:
            for _ in range(count):
                _safe_put(q, _STOP, stop_event)
        except _Aborted:
            pass

    def producer():
        try:
            while True:
                bid = id_queue.get()  # never blocks: the queue is pre-filled including STOPs
                if bid is _STOP or stop_event.is_set():
                    break
                block = blocking.get_block(bid)
                result = _prepare_block_input(
                    input_, mask, block, block_shape, halo, with_channels, skip_block, preprocess
                )
                if result is _SKIP:
                    with progress_lock:
                        pbar.update(1)
                    continue
                tensor, mask_block, inner_bb = result
                _safe_put(input_queue, _BlockJob(block, inner_bb, mask_block, tensor), stop_event)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # The last producer to finish signals the consumers. This is skipped on the abort path,
            # where consumers unwind via _safe_get instead. broadcast_stop also gives up if an abort
            # happens while it is still pushing.
            if remaining_producers.decrement() == 0 and not stop_event.is_set():
                broadcast_stop(input_queue, n_consumers)

    def consumer(worker_id):
        net, device = models[worker_id]
        try:
            while True:
                jobs = []
                got_stop = False
                while len(jobs) < batch_size:
                    item = _safe_get(input_queue, stop_event)
                    if item is _STOP:
                        got_stop = True
                        break
                    jobs.append(item)

                if jobs:  # run (possibly partial) batch
                    batch = torch.cat([job.tensor for job in jobs], dim=0).to(device)
                    with torch.no_grad():
                        prediction = net(batch) if prediction_function is None \
                            else prediction_function(net, batch)
                    if not torch.is_tensor(prediction):  # list/tuple of outputs -> take the first
                        prediction = prediction[0]
                    prediction = prediction.cpu().numpy()
                    for i, job in enumerate(jobs):
                        # copy so that each job owns its prediction independently of the batch array
                        job.prediction = np.array(prediction[i])
                        job.tensor = None
                        _safe_put(output_queue, job, stop_event)

                if got_stop:
                    break
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # the last consumer to finish signals the writers (abortable, see broadcast_stop)
            if remaining_consumers.decrement() == 0 and not stop_event.is_set():
                broadcast_stop(output_queue, num_write_workers)

    def writer():
        try:
            while True:
                job = _safe_get(output_queue, stop_event)
                if job is _STOP:
                    break
                _write_prediction(
                    job.prediction, job.block, output, ndim, job.mask_block, job.inner_bb, postprocess
                )
                with progress_lock:
                    pbar.update(1)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)

    writers = [threading.Thread(target=writer, name=f"predict-writer-{i}") for i in range(num_write_workers)]
    consumers = [threading.Thread(target=consumer, args=(i,), name=f"predict-consumer-{i}")
                 for i in range(n_consumers)]
    producers = [threading.Thread(target=producer, name=f"predict-producer-{i}")
                 for i in range(num_prefetch_workers)]
    threads = writers + consumers + producers

    try:
        # start downstream stages first so that nothing blocks on a queue without a reader
        for t in writers:
            t.start()
        for t in consumers:
            t.start()
        for t in producers:
            t.start()

        for t in producers:
            t.join()
        for t in consumers:
            t.join()
        for t in writers:
            t.join()
    finally:
        # on any exit path (including KeyboardInterrupt) make all threads unwind and wait for them
        stop_event.set()
        for t in threads:
            # a thread whose start() never ran (e.g. a failed start) cannot be joined
            if t.ident is not None:
                t.join()
        pbar.close()

    if error_box:
        raise error_box[0]

    return output
def predict_with_padding( model: torch.nn.modules.module.Module, input_: numpy.ndarray, min_divisible: Tuple[int, ...], device: Union[torch.device, str, NoneType] = None, with_channels: bool = False, prediction_function: Callable[[Any], Any] = None) -> numpy.ndarray:
def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Callable[[Any], Any] = None
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    Spatial axes whose size is not divisible by the corresponding factor are reflect-padded at
    their end, the padded input is passed through the model, and the spatial axes of the output
    are cropped back to the original size. All leading (non-spatial) output axes, i.e. the batch
    axis and the output channel axis, are kept in full.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model. It is not squeezed, so it keeps the batch axis.
    """
    # accept any sequence (e.g. a list) for the divisibility factors
    min_divisible = tuple(min_divisible)
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        # the channel axis is never padded
        min_divisible_ = (1,) + min_divisible
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = min_divisible

    # Only the spatial axes may be cropped after prediction: the number of output channels
    # generally differs from the number of input channels.
    spatial_shape = input_.shape[1:] if with_channels else input_.shape

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    # add a batch axis (and a channel axis if the input has none)
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    if device is None:
        device = next(model.parameters()).device

    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # keep all leading (batch / channel) axes, crop the trailing spatial axes
        n_leading = output.ndim - len(crop_padding)
        output = output[(slice(None),) * n_leading + crop_padding]

    return output

Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

Arguments:
  • model: The model.
  • input_: The input for prediction.
  • min_divisible: The minimal factors the input shape must be divisible by. For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
  • device: The device of the model. If not given, will be derived from the model parameters.
  • with_channels: Whether the input data contains channels.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:

np.ndarray: The output of the model.

def predict_with_halo( input_: Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], model: torch.nn.modules.module.Module, gpu_ids: List[Union[str, int]], block_shape: Tuple[int, ...], halo: Tuple[int, ...], output: Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], List[Tuple[Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], slice]], NoneType] = None, preprocess: Callable[[Union[torch.Tensor, numpy.ndarray]], Union[torch.Tensor, numpy.ndarray]] = <function standardize>, postprocess: Callable[[numpy.ndarray], numpy.ndarray] = None, with_channels: bool = False, skip_block: Callable[[Any], bool] = None, mask: Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, disable_tqdm: bool = False, tqdm_desc: str = 'predict with halo', prediction_function: Optional[Callable] = None, roi: Optional[Tuple[slice]] = None, iter_list: Optional[List[int]] = None, grid_shift: Optional[Tuple[float, ...]] = None) -> Union[numpy._typing._array_like._Buffer, 
numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]]:
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.
        iter_list: Optional list of block ids to iterate over.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            The input (and mask) are zero-padded by round(|shift| * block size) voxels per axis
            (see `_pad_for_shift_left`) and the padding is cropped from the output again.
            Requires numpy `input_` / `mask` and `output=None`.
    Returns:
        The model output.

    Raises:
        ValueError: If `gpu_ids` is empty, or if `grid_shift` is combined with a user-provided `output`.
        TypeError: If `grid_shift` is used with a non-numpy `input_` or `mask`.
    """
    if len(gpu_ids) == 0:
        raise ValueError("gpu_ids must contain at least one device, e.g. ['cpu'].")

    devices = [torch.device(gpu) for gpu in gpu_ids]
    # one model replica per device; the passed model is reused when it already lives on that device
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(gpu_ids)

    # original shape (spatial only)
    shape0 = input_.shape
    shape_spatial0 = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # apply grid_shift via padding+cropping (zero padding)
    input_eff = input_
    mask_eff = mask

    if grid_shift is not None:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"
        # Validate before padding: padding a (potentially large) volume only to reject the call
        # afterwards is wasteful.
        if output is not None:
            raise ValueError(
                "grid_shift is not supported together with a user-provided `output`, because "
                "grid_shift requires internal zero-padding and a final cropping step. "
                "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
                "Or pad the input manually beforehand."
            )

        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))

        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")

        input_eff, pad_left = _pad_for_shift_left(input_eff, pad_vox, with_channels=with_channels, mode="constant",
                                                  constant_values=0)

        if mask_eff is not None:
            if not isinstance(mask_eff, np.ndarray):
                raise TypeError("grid_shift padding currently requires mask to be a numpy array")
            mask_eff, _ = _pad_for_shift_left(mask_eff, pad_vox, with_channels=False, mode="constant",
                                              constant_values=0)
    else:
        pad_left = (0,) * ndim

    # shapes after shift-padding
    shape_eff = input_eff.shape
    shape_spatial_eff = shape_eff[1:] if with_channels else shape_eff

    # blocking (on the padded input)
    if roi is None:
        blocking = bic.utils.Blocking([0] * ndim, list(shape_spatial_eff), block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = bic.utils.Blocking(blocking_start, blocking_stop, block_shape)

    # output allocation (for padded shape); the grid_shift + user output conflict was rejected above
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")

    def predict_block(block_id):
        # distribute blocks over the model replicas round-robin by block id
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.get_block(block_id)
            offset = list(block.begin)
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask_eff is not None:
                mask_block, _ = _load_block(mask_eff, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                # nothing to predict inside the mask
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_eff, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for a list/tuple of tensors: only the first output is used
            if not torch.is_tensor(prediction):
                prediction = prediction[0]
            prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # crop the halo away
            if prediction.ndim == ndim + 1:
                inner_bb_pred = (slice(None),) + inner_bb
            else:
                inner_bb_pred = inner_bb
            prediction = prediction[inner_bb_pred]

            if mask_eff is not None:
                if prediction.ndim == ndim + 1:
                    mb = np.broadcast_to(mask_block[None], prediction.shape)
                else:
                    mb = mask_block
                prediction[~mb] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.number_of_blocks
    iteration_ids = range(n_blocks) if iter_list is None else np.array(iter_list)

    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # crop away the shift padding so the returned output matches original shape
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim+1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output

Run block-wise network prediction with a halo.

Arguments:
  • input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
  • model: The network.
  • gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass ["cpu"].
  • block_shape: The shape of the inner block to use for prediction.
  • halo: The shape of the halo to use for prediction
  • output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
  • preprocess: Function to preprocess input data before passing it to the network.
  • postprocess: Function to postprocess the network predictions.
  • with_channels: Whether the input has a channel axis.
  • skip_block: Function to evaluate whether a given input block will be skipped.
  • mask: Elements outside the mask will be ignored in the prediction.
  • disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
  • tqdm_desc: Description shown by the tqdm output.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures.
  • roi: A region of interest of the input for which to run prediction.
  • iter_list: Optional list of block ids to iterate over.
  • grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
Returns:

The model output.

def predict_with_halo_pipelined( input_: Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], model: torch.nn.modules.module.Module, gpu_ids: List[Union[str, int]], block_shape: Tuple[int, ...], halo: Tuple[int, ...], output: Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], List[Tuple[Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], slice]], NoneType] = None, preprocess: Callable[[Union[torch.Tensor, numpy.ndarray]], Union[torch.Tensor, numpy.ndarray]] = <function standardize>, postprocess: Callable[[numpy.ndarray], numpy.ndarray] = None, with_channels: bool = False, skip_block: Callable[[Any], bool] = None, mask: Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, disable_tqdm: bool = False, tqdm_desc: str = 'predict with halo (pipelined)', prediction_function: Optional[Callable] = None, roi: Optional[Tuple[slice]] = None, iter_list: Optional[List[int]] = None, batch_size: int = 1, num_prefetch_workers: int = 4, queue_size: Optional[int] = None, 
num_write_workers: int = 1, write_queue_size: Optional[int] = None, grid_shift: Optional[Tuple[float, ...]] = None) -> Union[numpy._typing._array_like._Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]]:
def predict_with_halo_pipelined(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo (pipelined)",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    batch_size: int = 1,
    num_prefetch_workers: int = 4,
    queue_size: Optional[int] = None,
    num_write_workers: int = 1,
    write_queue_size: Optional[int] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo, pipelined for higher GPU throughput.

    This is an alternative implementation of `predict_with_halo` that decouples block
    loading, GPU prediction and output writing into a producer-consumer pipeline
    connected by queues:

        producers (CPU threads: load + preprocess) -> input queue
          -> consumer(s), one per GPU (stack a batch, predict, unstack) -> output queue
          -> writer thread(s) (postprocess + write).

    While the GPU works on one batch, the prefetch workers load and preprocess the
    next blocks and the writer drains finished predictions, keeping the GPU fed.
    Blocks can additionally be stacked into batches for one forward pass via `batch_size`.

    The pipeline is thread-based (not multiprocessing) so that lazy hdf5/zarr/n5 inputs
    (whose file handles are not fork/pickle-safe) work, and so that writers can share the
    output array directly. Note that heavy *Python-level* `preprocess`/`postprocess`
    callbacks will not parallelize across prefetch workers due to the GIL; the default
    `standardize` is numpy-vectorized and releases the GIL.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
            One prediction consumer thread (with its own model replica) is run per device.
            Must not be empty.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It must operate on the leading batch axis; with the default `batch_size=1` it does not need changes.
        roi: A region of interest of the input for which to run prediction.
        iter_list: Optional list of block ids to iterate over.
        batch_size: The number of blocks stacked into a single forward pass. Trades GPU memory for throughput.
        num_prefetch_workers: The number of CPU threads used to load and preprocess blocks.
        queue_size: The maximum size of the input (prefetch) queue. Provides backpressure to bound memory use.
            If None, a value derived from the number of devices and `batch_size` is used.
        num_write_workers: The number of threads used to write predictions. Values > 1 are safe for in-memory
            numpy outputs, and for zarr/n5 outputs whose chunks (and shards, for zarr v3) are aligned with
            block_shape. For hdf5, misaligned zarr/n5, or other outputs this is automatically clamped to 1.
        write_queue_size: The maximum size of the output (write) queue. If None, a default is used.
        grid_shift: Not supported by this function; raises NotImplementedError if passed. Use `predict_with_halo`.

    Returns:
        The model output.

    Raises:
        NotImplementedError: If `grid_shift` is passed.
        ValueError: If `gpu_ids` is empty.
    """
    if grid_shift is not None:
        raise NotImplementedError(
            "grid_shift is not supported by predict_with_halo_pipelined. "
            "Use predict_with_halo for grid_shift, or pre-pad the input and use roi."
        )
    if len(gpu_ids) == 0:
        # without any consumer thread nothing would drain the input queue
        raise ValueError("gpu_ids must contain at least one device, e.g. ['cpu'].")

    batch_size = max(1, int(batch_size))
    num_prefetch_workers = max(1, int(num_prefetch_workers))
    num_write_workers = max(1, int(num_write_workers))

    devices = [torch.device(gpu) for gpu in gpu_ids]
    # one model replica per device; the passed model is reused when it already lives on that device
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_consumers = len(devices)

    shape0 = input_.shape
    shape_spatial = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial)
    assert len(block_shape) == len(halo) == ndim

    # blocking
    if roi is None:
        block_start = [0] * ndim
        blocking = bic.utils.Blocking(block_start, list(shape_spatial), block_shape)
    else:
        assert len(roi) == ndim
        block_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial)]
        blocking = bic.utils.Blocking(block_start, blocking_stop, block_shape)

    # output allocation
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial), dtype="float32")

    # guard against unsafe concurrent writes: numpy is always safe (disjoint regions),
    # zarr/n5 are safe when their chunks/shards are aligned with the blocks, hdf5 is not.
    if num_write_workers > 1:
        out_arrays = [o for o, _ in output] if isinstance(output, list) else [output]
        if any(not _concurrent_write_safe(o, block_shape, block_start) for o in out_arrays):
            warnings.warn(
                "num_write_workers > 1 requires either an in-memory numpy output or a zarr/n5 "
                "output whose chunks (and shards, for zarr v3) are aligned with block_shape; "
                "falling back to a single writer. HDF5 outputs are never safe for concurrent writes."
            )
            num_write_workers = 1

    # queue sizes
    if queue_size is None:
        queue_size = max(2 * n_consumers * batch_size, 2 * batch_size)
    queue_size = max(queue_size, batch_size)
    if write_queue_size is None:
        write_queue_size = max(2 * n_consumers, 4)

    n_blocks = blocking.number_of_blocks
    iteration_ids = list(range(n_blocks)) if iter_list is None else list(iter_list)
    total = len(iteration_ids)

    # pre-fill the block-id queue with all ids followed by one STOP per producer
    id_queue = queue.Queue()
    for bid in iteration_ids:
        id_queue.put(bid)
    for _ in range(num_prefetch_workers):
        id_queue.put(_STOP)

    input_queue = queue.Queue(maxsize=queue_size)
    output_queue = queue.Queue(maxsize=write_queue_size)

    stop_event = threading.Event()
    error_box = []
    error_lock = threading.Lock()
    progress_lock = threading.Lock()
    pbar = tqdm(total=total, disable=disable_tqdm, desc=tqdm_desc)

    remaining_producers = _AtomicCounter(num_prefetch_workers)
    remaining_consumers = _AtomicCounter(n_consumers)

    def record_error(exc):
        # keep only the first error and ask all stages to shut down
        with error_lock:
            if not error_box:
                error_box.append(exc)
        stop_event.set()

    def forward_stop(target_queue, n_stops):
        # Send `n_stops` end-of-stream markers downstream. The queues are bounded, so a plain
        # blocking `put` could hang forever if an error sets `stop_event` after the caller's check
        # and the downstream threads exit without draining the queue (the main thread would then
        # block in `join`). `_safe_put` is given `stop_event`, and the except clauses elsewhere in
        # this function indicate that it aborts via `_Aborted`.
        try:
            for _ in range(n_stops):
                _safe_put(target_queue, _STOP, stop_event)
        except _Aborted:
            pass

    def producer():
        try:
            while True:
                bid = id_queue.get()
                if bid is _STOP or stop_event.is_set():
                    break
                block = blocking.get_block(bid)
                result = _prepare_block_input(
                    input_, mask, block, block_shape, halo, with_channels, skip_block, preprocess
                )
                if result is _SKIP:
                    with progress_lock:
                        pbar.update(1)
                    continue
                tensor, mask_block, inner_bb = result
                _safe_put(input_queue, _BlockJob(block, inner_bb, mask_block, tensor), stop_event)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # the last producer to finish signals the consumers (skipped on the abort path,
            # where consumers unwind via _safe_get instead)
            if remaining_producers.decrement() == 0 and not stop_event.is_set():
                forward_stop(input_queue, n_consumers)

    def consumer(worker_id):
        net, device = models[worker_id]
        try:
            while True:
                jobs = []
                got_stop = False
                while len(jobs) < batch_size:
                    item = _safe_get(input_queue, stop_event)
                    if item is _STOP:
                        got_stop = True
                        break
                    jobs.append(item)

                if jobs:  # run (possibly partial) batch
                    batch = torch.cat([job.tensor for job in jobs], dim=0).to(device)
                    with torch.no_grad():
                        prediction = net(batch) if prediction_function is None \
                            else prediction_function(net, batch)
                    if not torch.is_tensor(prediction):  # list/tuple of outputs -> take the first
                        prediction = prediction[0]
                    prediction = prediction.cpu().numpy()
                    for i, job in enumerate(jobs):
                        # copy so the job does not keep the whole batch array alive
                        job.prediction = np.array(prediction[i])
                        job.tensor = None
                        _safe_put(output_queue, job, stop_event)

                if got_stop:
                    break
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # the last consumer to finish signals the writers (skipped on the abort path)
            if remaining_consumers.decrement() == 0 and not stop_event.is_set():
                forward_stop(output_queue, num_write_workers)

    def writer():
        try:
            while True:
                job = _safe_get(output_queue, stop_event)
                if job is _STOP:
                    break
                _write_prediction(
                    job.prediction, job.block, output, ndim, job.mask_block, job.inner_bb, postprocess
                )
                with progress_lock:
                    pbar.update(1)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)

    writers = [threading.Thread(target=writer, name=f"predict-writer-{i}") for i in range(num_write_workers)]
    consumers = [threading.Thread(target=consumer, args=(i,), name=f"predict-consumer-{i}")
                 for i in range(n_consumers)]
    producers = [threading.Thread(target=producer, name=f"predict-producer-{i}")
                 for i in range(num_prefetch_workers)]
    threads = writers + consumers + producers

    try:
        # start downstream stages first so that they are ready when work arrives
        for t in writers:
            t.start()
        for t in consumers:
            t.start()
        for t in producers:
            t.start()

        for t in producers:
            t.join()
        for t in consumers:
            t.join()
        for t in writers:
            t.join()
    finally:
        # also reached on KeyboardInterrupt: tell every stage to unwind, then wait for it
        stop_event.set()
        for t in threads:
            t.join()
        pbar.close()

    if error_box:
        raise error_box[0]

    return output

Run block-wise network prediction with a halo, pipelined for higher GPU throughput.

This is an alternate implementation of predict_with_halo that decouples block loading, GPU prediction and output writing into a producer-consumer pipeline connected by queues:

producers (CPU threads: load + preprocess) -> input queue
  -> consumer(s), one per GPU (stack a batch, predict, unstack) -> output queue
  -> writer thread(s) (postprocess + write).

While the GPU works on one batch, the prefetch workers load and preprocess the next blocks and the writer drains finished predictions, keeping the GPU fed. Blocks can additionally be stacked into batches for one forward pass via batch_size.

The pipeline is thread-based (not multiprocessing) so that lazy hdf5/zarr/n5 inputs (whose file handles are not fork/pickle-safe) work, and so that writers can share the output array directly. Note that heavy Python-level preprocess/postprocess callbacks will not parallelize across prefetch workers due to the GIL; the default standardize is numpy-vectorized and releases the GIL.

Arguments:
  • input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
  • model: The network.
  • gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass ["cpu"]. One prediction consumer thread (with its own model replica) is run per device.
  • block_shape: The shape of the inner block to use for prediction.
  • halo: The shape of the halo to use for prediction.
  • output: The output data; it will be allocated if None is passed. Instead of a single output, this can also be a list of outputs, each paired with a slice that selects the corresponding channel(s) of the prediction.
  • preprocess: Function to preprocess input data before passing it to the network.
  • postprocess: Function to postprocess the network predictions.
  • with_channels: Whether the input has a channel axis.
  • skip_block: Function to evaluate whether a given input block will be skipped.
  • mask: Elements outside the mask will be ignored in the prediction.
  • disable_tqdm: Whether to disable the tqdm progress bar (e.g. if the function is called multiple times).
  • tqdm_desc: Description shown by the tqdm output.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures. It must operate along the leading batch axis of its input; with the default batch_size=1, existing prediction functions work without changes.
  • roi: A region of interest of the input for which to run prediction.
  • iter_list: Optional list of block ids to iterate over.
  • batch_size: The number of blocks stacked into a single forward pass. Larger values use more GPU memory in exchange for higher throughput.
  • num_prefetch_workers: The number of CPU threads used to load and preprocess blocks.
  • queue_size: The maximum size of the input (prefetch) queue. Provides backpressure to bound memory use. If None, a value derived from the number of devices and batch_size is used.
  • num_write_workers: The number of threads used to write predictions. Values > 1 are safe for in-memory numpy outputs, and for zarr/n5 outputs whose chunks (and shards, for zarr v3) are aligned with block_shape. For hdf5, misaligned zarr/n5, or other outputs this is automatically clamped to 1.
  • write_queue_size: The maximum size of the output (write) queue. If None, a default is used.
  • grid_shift: Not supported by this function; raises NotImplementedError if passed. Use predict_with_halo instead.

Returns:
  The model output.