torch_em.util.prediction
import queue
import threading
import warnings
from copy import deepcopy
from concurrent import futures
from typing import Tuple, Union, Callable, Any, List, Optional

import numpy as np
import bioimage_cpp as bic
import torch
from numpy.typing import ArrayLike

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from ..transform.raw import standardize


def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Optional[Callable[..., Any]] = None,
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    The input is reflect-padded at the end of each spatial axis up to the next multiple of
    `min_divisible`, and the padding is cropped from the spatial axes of the output again.
    The batch and channel axes of the output are never cropped.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model.
    """
    min_divisible = tuple(min_divisible)
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        spatial_shape = tuple(input_.shape[1:])
    else:
        assert len(min_divisible) == input_.ndim
        spatial_shape = tuple(input_.shape)

    # Number of voxels missing to the next multiple of the divisor, per spatial axis.
    pad_after = tuple((-sh) % md for sh, md in zip(spatial_shape, min_divisible))

    if any(pad_after):
        pad_width = tuple((0, pa) for pa in pad_after)
        if with_channels:  # the channel axis is never padded
            pad_width = ((0, 0),) + pad_width
        # Only the spatial axes are cropped afterwards: the number of output channels
        # is independent of the number of input channels.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    if device is None:
        device = next(model.parameters()).device

    # add the batch axis (and the channel axis if the input has none)
    expand_dim = (None,) if with_channels else (None, None)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # keep all leading (batch / channel) axes of the output untouched
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output


def _pad_for_shift_left(arr, pad_vox, with_channels, mode="constant", constant_values=0.0):
    """@private
    Pad `arr` by `pad_vox` voxels at the start of each spatial axis (never the channel axis).

    Returns the padded array and the per-axis left padding as a tuple.
    """
    pad_left = tuple(pad_vox)

    pad_width = tuple((pl, 0) for pl in pad_left)
    if with_channels:
        pad_width = ((0, 0),) + pad_width

    # np.pad only accepts `constant_values` for mode="constant" and raises for other modes.
    pad_kwargs = {"constant_values": constant_values} if mode == "constant" else {}
    arr_pad = np.pad(arr, pad_width, mode=mode, **pad_kwargs)
    return arr_pad, pad_left


def _crop_after_shift_left(arr, pad_left, with_channels, original_shape_spatial):
    """@private
    Undo `_pad_for_shift_left`: cut out the region of the original spatial shape starting at `pad_left`."""
    spatial_slices = tuple(slice(st, st + sh) for st, sh in zip(pad_left, original_shape_spatial))
    return arr[(slice(None),) + spatial_slices] if with_channels else arr[spatial_slices]


def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
    """@private
    Load the block at `offset` of size `block_shape`, enlarged by `halo` on both sides.

    Parts of the enlarged block that lie outside of the input are filled by padding with
    `padding_mode`, so that the returned data always has the shape `block_shape + 2 * halo`
    (plus the channel axis if `with_channels`) as long as the block itself overlaps the input.

    Returns the data and the (possibly out-of-bounds) spatial bounding box it corresponds to.
    """
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    starts = [off - ha for off, ha in zip(offset, halo)]
    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]

    # how far the enlarged block sticks out of the volume on either side
    pad_left = tuple(max(0, -start) for start in starts)
    pad_right = tuple(max(0, stop - sh) for stop, sh in zip(stops, shape))

    # clip the bounding box to the volume before reading
    starts = [max(0, start) for start in starts]
    stops = [min(sh, stop) for stop, sh in zip(stops, shape)]

    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    if with_channels:
        data = input_[(slice(None),) + bb]
    else:
        data = input_[bb]

    # pad if necessary
    if any(pad_left) or any(pad_right):
        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
        if with_channels:
            pad_width = ((0, 0),) + pad_width
        data = np.pad(data, pad_width, mode=padding_mode)

        # extend the bounding box for downstream
        bb = tuple(
            slice(b.start - pl, b.stop + pr)
            for b, pl, pr in zip(bb, pad_left, pad_right)
        )

    return data, bb


# Sentinel returned by _prepare_block_input when a block should be skipped.
_SKIP = object()
# Sentinel pushed onto the pipeline queues to signal end-of-stream.
_STOP = object()


def _prepare_block_input(input_, mask, block, block_shape, halo, with_channels, skip_block, preprocess):
    """@private
    Block preparation: load + (optional) mask/skip check + preprocess.

    Returns the `_SKIP` sentinel if the block should be skipped, otherwise a tuple
    `(cpu_tensor, mask_block, inner_bb)` where `cpu_tensor` has a leading batch axis.
    """
    offset = [beg for beg in block.begin]
    inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

    mask_block = None
    if mask is not None:
        mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
        mask_block = mask_block[inner_bb].astype("bool")
        if mask_block.sum() == 0:
            return _SKIP

    inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)

    if skip_block is not None and skip_block(inp):
        return _SKIP

    if preprocess is not None:
        inp = preprocess(inp)

    # add (channel) and batch axis -> [1, (C,) *spatial]
    expand_dims = np.s_[None] if with_channels else np.s_[None, None]
    tensor = torch.from_numpy(inp[expand_dims])
    return tensor, mask_block, inner_bb


def _write_prediction(prediction, block, output, ndim, mask_block, inner_bb, postprocess):
    """@private
    Writer-side logic: postprocess + inner crop + mask-zero + write to `output`."""
    if postprocess is not None:
        prediction = postprocess(prediction)

    if prediction.ndim == ndim + 1:
        inner_bb_pred = (slice(None),) + inner_bb
    else:
        inner_bb_pred = inner_bb
    prediction = prediction[inner_bb_pred]

    if mask_block is not None:
        if prediction.ndim == ndim + 1:
            mb = np.broadcast_to(mask_block[None], prediction.shape)
        else:
            mb = mask_block
        prediction[~mb] = 0

    bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
    if isinstance(output, list):  # multiple outputs: split the prediction channels
        for out, channel_slice in output:
            this_bb = bb if out.ndim == ndim else (slice(None),) + bb
            out[this_bb] = prediction[channel_slice]
    else:  # single output array
        if output.ndim == ndim + 1:
            bb = (slice(None),) + bb
        output[bb] = prediction


def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.
        iter_list: Optional list of block ids to iterate over.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            The shift is realized by zero-padding the input (and mask) at the start of each axis by
            `round(abs(shift) * block_size)` voxels and cropping the result back to the original shape.
            It requires numpy inputs and `output=None`.

    Returns:
        The model output.

    Raises:
        ValueError: If `gpu_ids` is empty, or if `grid_shift` is combined with a user-provided `output`.
        TypeError: If `grid_shift` is passed and `input_` or `mask` is not a numpy array.
    """
    if len(gpu_ids) == 0:
        raise ValueError("gpu_ids must contain at least one device, e.g. ['cpu'].")

    devices = [torch.device(gpu) for gpu in gpu_ids]
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(devices)

    # original shape (spatial only)
    shape0 = input_.shape
    shape_spatial0 = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # apply grid_shift via padding+cropping (zero padding)
    input_eff = input_
    mask_eff = mask

    if grid_shift is not None:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"

        # Validate everything before allocating the padded copies.
        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")
        if mask_eff is not None and not isinstance(mask_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires mask to be a numpy array")
        if output is not None:
            raise ValueError(
                "grid_shift is not supported together with a user-provided `output`, because "
                "grid_shift requires internal zero-padding and a final cropping step. "
                "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
                "Or pad the input manually beforehand."
            )

        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))
        input_eff, pad_left = _pad_for_shift_left(
            input_eff, pad_vox, with_channels=with_channels, mode="constant", constant_values=0
        )
        if mask_eff is not None:
            mask_eff, _ = _pad_for_shift_left(
                mask_eff, pad_vox, with_channels=False, mode="constant", constant_values=0
            )
    else:
        pad_left = (0,) * ndim

    # shapes after shift-padding
    shape_eff = input_eff.shape
    shape_spatial_eff = shape_eff[1:] if with_channels else shape_eff

    # blocking (on the padded input)
    # NOTE(review): `roi` is interpreted in the coordinates of the shift-padded input, so combining
    # `roi` with `grid_shift` moves the predicted region by the padding - confirm this is intended.
    if roi is None:
        blocking = bic.utils.Blocking([0] * ndim, list(shape_spatial_eff), block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = bic.utils.Blocking(blocking_start, blocking_stop, block_shape)

    # output allocation (for padded shape)
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")

    def predict_block(block_id):
        # blocks are distributed round-robin over the devices
        net, device = models[block_id % n_workers]

        with torch.no_grad():
            block = blocking.get_block(block_id)
            prepared = _prepare_block_input(
                input_eff, mask_eff, block, block_shape, halo, with_channels, skip_block, preprocess
            )
            if prepared is _SKIP:
                return
            inp, mask_block, inner_bb = prepared
            inp = inp.to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            _write_prediction(prediction, block, output, ndim, mask_block, inner_bb, postprocess)

    n_blocks = blocking.number_of_blocks
    # plain Python ints, so that block ids from numpy arrays / lists are handled the same way
    iteration_ids = range(n_blocks) if iter_list is None else [int(bid) for bid in iter_list]

    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # crop away the shift padding so the returned output matches original shape
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim + 1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output


class _Aborted(Exception):
    """@private
    Raised inside worker threads to unwind cleanly once `stop_event` is set."""


class _AtomicCounter:
    """@private
    Lock-guarded integer counter used for sentinel reference counting."""

    def __init__(self, value: int):
        self._value = value
        self._lock = threading.Lock()

    def decrement(self) -> int:
        """Decrement the counter by one and return the new value."""
        with self._lock:
            self._value -= 1
            return self._value


class _BlockJob:
    """@private
    A unit of work travelling through the prediction pipeline."""

    __slots__ = ("block", "inner_bb", "mask_block", "tensor", "prediction")

    def __init__(self, block, inner_bb, mask_block, tensor):
        self.block = block
        self.inner_bb = inner_bb
        self.mask_block = mask_block
        self.tensor = tensor  # CPU tensor [1, (C,) *spatial]; cleared after prediction
        self.prediction = None  # filled by the consumer


def _safe_get(q, stop_event, timeout=0.2):
    """@private
    Queue.get that aborts (raises _Aborted) once stop_event is set, to avoid deadlocks."""
    while not stop_event.is_set():
        try:
            return q.get(timeout=timeout)
        except queue.Empty:
            continue
    raise _Aborted()


def _safe_put(q, item, stop_event, timeout=0.2):
    """@private
    Queue.put that aborts (raises _Aborted) once stop_event is set, to avoid deadlocks."""
    while not stop_event.is_set():
        try:
            q.put(item, timeout=timeout)
            return
        except queue.Full:
            continue
    raise _Aborted()


def _broadcast_stop(q, count, stop_event):
    """@private
    Push `count` end-of-stream sentinels onto `q`, giving up silently once stop_event is set.

    A plain blocking `q.put` could hang forever here if the downstream stage failed while
    the bounded queue is full, because nobody would drain the queue anymore.
    """
    try:
        for _ in range(count):
            _safe_put(q, _STOP, stop_event)
    except _Aborted:
        pass  # abort path: downstream threads unwind via _safe_get instead


def _concurrent_write_safe(arr, block_shape, start):
    """@private
    Whether `arr` can be written to from multiple threads concurrently for the given block grid.

    - numpy arrays are always safe (writers touch disjoint in-memory regions).
    - hdf5 datasets are never safe (h5py is not thread-safe for concurrent writes).
    - zarr / n5 datasets are safe iff the block grid is aligned with the chunks (and the shards,
      for zarr v3), so that each chunk/shard is written by exactly one block.
    - unknown / unchunked backends are treated conservatively as unsafe.
    """
    if isinstance(arr, np.ndarray):
        return True

    module = type(arr).__module__
    if module.startswith("h5py"):
        return False

    chunks = getattr(arr, "chunks", None)
    if chunks is None:  # unknown backend or unchunked -> be conservative
        return False

    # zarr v3 exposes the shard shape via .shards (None if not sharded); the shard is the atomic
    # write unit when present, otherwise the chunk is. getattr covers z5py / older zarr (no shards).
    shards = getattr(arr, "shards", None)
    unit = shards if shards is not None else chunks

    # compare only the spatial axes: every block writes the full channel range at a disjoint
    # spatial bounding box, so channel-axis chunking never causes a write conflict.
    ndim = len(block_shape)
    unit_spatial = tuple(unit[-ndim:])
    if any(bs % u != 0 for bs, u in zip(block_shape, unit_spatial)):
        return False
    if any(s % u != 0 for s, u in zip(start, unit_spatial)):
        return False
    return True


def predict_with_halo_pipelined(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo (pipelined)",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    batch_size: int = 1,
    num_prefetch_workers: int = 4,
    queue_size: Optional[int] = None,
    num_write_workers: int = 1,
    write_queue_size: Optional[int] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo, pipelined for higher GPU throughput.

    This is an alternate implementation of `predict_with_halo` that decouples block
    loading, GPU prediction and output writing into a producer-consumer pipeline
    connected by queues:

        producers (CPU threads: load + preprocess) -> input queue
        -> consumer(s), one per GPU (stack a batch, predict, unstack) -> output queue
        -> writer thread(s) (postprocess + write).

    While the GPU works on one batch, the prefetch workers load and preprocess the
    next blocks and the writer drains finished predictions, keeping the GPU fed.
    Blocks can additionally be stacked into batches for one forward pass via `batch_size`.

    The pipeline is thread-based (not multiprocessing) so that lazy hdf5/zarr/n5 inputs
    (whose file handles are not fork/pickle-safe) work, and so that writers can share the
    output array directly. Note that heavy *Python-level* `preprocess`/`postprocess`
    callbacks will not parallelize across prefetch workers due to the GIL; the default
    `standardize` is numpy-vectorized and releases the GIL.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
            One prediction consumer thread (with its own model replica) is run per device.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It must operate on the leading batch axis; with the default `batch_size=1` it does not need changes.
        roi: A region of interest of the input for which to run prediction.
        iter_list: Optional list of block ids to iterate over.
        batch_size: The number of blocks stacked into a single forward pass. Trades GPU memory for throughput.
        num_prefetch_workers: The number of CPU threads used to load and preprocess blocks.
        queue_size: The maximum size of the input (prefetch) queue. Provides backpressure to bound memory use.
            If None, a value derived from the number of devices and `batch_size` is used.
        num_write_workers: The number of threads used to write predictions. Values > 1 are safe for in-memory
            numpy outputs, and for zarr/n5 outputs whose chunks (and shards, for zarr v3) are aligned with
            block_shape. For hdf5, misaligned zarr/n5, or other outputs this is automatically clamped to 1.
        write_queue_size: The maximum size of the output (write) queue. If None, a default is used.
        grid_shift: Not supported by this function; raises NotImplementedError if passed. Use `predict_with_halo`.

    Returns:
        The model output.

    Raises:
        NotImplementedError: If `grid_shift` is passed.
        ValueError: If `gpu_ids` is empty.
        Exception: The first exception raised in any of the worker threads is re-raised.
    """
    if grid_shift is not None:
        raise NotImplementedError(
            "grid_shift is not supported by predict_with_halo_pipelined. "
            "Use predict_with_halo for grid_shift, or pre-pad the input and use roi."
        )
    if len(gpu_ids) == 0:
        # without a consumer thread nobody would drain the input queue and the pipeline would hang
        raise ValueError("gpu_ids must contain at least one device, e.g. ['cpu'].")

    batch_size = max(1, int(batch_size))
    num_prefetch_workers = max(1, int(num_prefetch_workers))
    num_write_workers = max(1, int(num_write_workers))

    devices = [torch.device(gpu) for gpu in gpu_ids]
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_consumers = len(devices)

    shape0 = input_.shape
    shape_spatial = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial)
    assert len(block_shape) == len(halo) == ndim

    # blocking
    if roi is None:
        block_start = [0] * ndim
        blocking = bic.utils.Blocking(block_start, list(shape_spatial), block_shape)
    else:
        assert len(roi) == ndim
        block_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial)]
        blocking = bic.utils.Blocking(block_start, blocking_stop, block_shape)

    # output allocation
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial), dtype="float32")

    # guard against unsafe concurrent writes: numpy is always safe (disjoint regions),
    # zarr/n5 are safe when their chunks/shards are aligned with the blocks, hdf5 is not.
    if num_write_workers > 1:
        out_arrays = [o for o, _ in output] if isinstance(output, list) else [output]
        if any(not _concurrent_write_safe(o, block_shape, block_start) for o in out_arrays):
            warnings.warn(
                "num_write_workers > 1 requires either an in-memory numpy output or a zarr/n5 "
                "output whose chunks (and shards, for zarr v3) are aligned with block_shape; "
                "falling back to a single writer. HDF5 outputs are never safe for concurrent writes."
            )
            num_write_workers = 1

    # queue sizes
    if queue_size is None:
        queue_size = max(2 * n_consumers * batch_size, 2 * batch_size)
    queue_size = max(queue_size, batch_size)
    if write_queue_size is None:
        write_queue_size = max(2 * n_consumers, 4)

    n_blocks = blocking.number_of_blocks
    iteration_ids = list(range(n_blocks)) if iter_list is None else list(iter_list)
    total = len(iteration_ids)

    # pre-fill the block-id queue with all ids followed by one STOP per producer
    id_queue = queue.Queue()
    for bid in iteration_ids:
        id_queue.put(bid)
    for _ in range(num_prefetch_workers):
        id_queue.put(_STOP)

    input_queue = queue.Queue(maxsize=queue_size)
    output_queue = queue.Queue(maxsize=write_queue_size)

    stop_event = threading.Event()
    error_box = []
    error_lock = threading.Lock()
    progress_lock = threading.Lock()
    pbar = tqdm(total=total, disable=disable_tqdm, desc=tqdm_desc)

    remaining_producers = _AtomicCounter(num_prefetch_workers)
    remaining_consumers = _AtomicCounter(n_consumers)

    def record_error(exc):
        # keep only the first error and tell all threads to unwind
        with error_lock:
            if not error_box:
                error_box.append(exc)
        stop_event.set()

    def producer():
        try:
            while True:
                bid = id_queue.get()
                if bid is _STOP or stop_event.is_set():
                    break
                block = blocking.get_block(bid)
                result = _prepare_block_input(
                    input_, mask, block, block_shape, halo, with_channels, skip_block, preprocess
                )
                if result is _SKIP:
                    with progress_lock:
                        pbar.update(1)
                    continue
                tensor, mask_block, inner_bb = result
                _safe_put(input_queue, _BlockJob(block, inner_bb, mask_block, tensor), stop_event)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # the last producer to finish signals the consumers (skipped on the abort path,
            # where consumers unwind via _safe_get instead)
            if remaining_producers.decrement() == 0:
                _broadcast_stop(input_queue, n_consumers, stop_event)

    def consumer(worker_id):
        net, device = models[worker_id]
        try:
            while True:
                jobs = []
                got_stop = False
                while len(jobs) < batch_size:
                    item = _safe_get(input_queue, stop_event)
                    if item is _STOP:
                        got_stop = True
                        break
                    jobs.append(item)

                if jobs:  # run (possibly partial) batch
                    batch = torch.cat([job.tensor for job in jobs], dim=0).to(device)
                    with torch.no_grad():
                        prediction = net(batch) if prediction_function is None \
                            else prediction_function(net, batch)
                    if not torch.is_tensor(prediction):  # list/tuple of outputs -> take the first
                        prediction = prediction[0]
                    prediction = prediction.cpu().numpy()
                    for i, job in enumerate(jobs):
                        job.prediction = np.array(prediction[i])
                        job.tensor = None
                        _safe_put(output_queue, job, stop_event)

                if got_stop:
                    break
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # the last consumer to finish signals the writers
            if remaining_consumers.decrement() == 0:
                _broadcast_stop(output_queue, num_write_workers, stop_event)

    def writer():
        try:
            while True:
                job = _safe_get(output_queue, stop_event)
                if job is _STOP:
                    break
                _write_prediction(
                    job.prediction, job.block, output, ndim, job.mask_block, job.inner_bb, postprocess
                )
                with progress_lock:
                    pbar.update(1)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)

    writers = [threading.Thread(target=writer, name=f"predict-writer-{i}") for i in range(num_write_workers)]
    consumers = [threading.Thread(target=consumer, args=(i,), name=f"predict-consumer-{i}")
                 for i in range(n_consumers)]
    producers = [threading.Thread(target=producer, name=f"predict-producer-{i}")
                 for i in range(num_prefetch_workers)]
    threads = writers + consumers + producers

    try:
        # start downstream stages first so that the queues are drained from the beginning
        for t in threads:
            t.start()

        for t in producers:
            t.join()
        for t in consumers:
            t.join()
        for t in writers:
            t.join()
    finally:
        # also reached on KeyboardInterrupt: make all threads unwind before returning
        stop_event.set()
        for t in threads:
            t.join()
        pbar.close()

    if error_box:
        raise error_box[0]

    return output
def predict_with_padding(
    model: torch.nn.Module,
    input_: np.ndarray,
    min_divisible: Tuple[int, ...],
    device: Optional[Union[torch.device, str]] = None,
    with_channels: bool = False,
    prediction_function: Optional[Callable[..., Any]] = None,
) -> np.ndarray:
    """Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

    The input is reflect-padded at the end of each spatial axis up to the next multiple of
    `min_divisible`, and the padding is cropped from the spatial axes of the output again.
    The batch and channel axes of the output are never cropped.

    Args:
        model: The model.
        input_: The input for prediction.
        min_divisible: The minimal factors the input shape must be divisible by.
            For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
        device: The device of the model. If not given, will be derived from the model parameters.
        with_channels: Whether the input data contains channels.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, model_input)`.

    Returns:
        np.ndarray: The output of the model.
    """
    min_divisible = tuple(min_divisible)
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        spatial_shape = tuple(input_.shape[1:])
    else:
        assert len(min_divisible) == input_.ndim
        spatial_shape = tuple(input_.shape)

    # Number of voxels missing to the next multiple of the divisor, per spatial axis.
    pad_after = tuple((-sh) % md for sh, md in zip(spatial_shape, min_divisible))

    if any(pad_after):
        pad_width = tuple((0, pa) for pa in pad_after)
        if with_channels:  # the channel axis is never padded
            pad_width = ((0, 0),) + pad_width
        # Only the spatial axes are cropped afterwards: the number of output channels
        # is independent of the number of input channels.
        crop_padding = tuple(slice(0, sh) for sh in spatial_shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    if device is None:
        device = next(model.parameters()).device

    # add the batch axis (and the channel axis if the input has none)
    expand_dim = (None,) if with_channels else (None, None)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    if crop_padding is not None:
        # keep all leading (batch / channel) axes of the output untouched
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output
Run prediction with padding for a model that can only deal with inputs divisible by specific factors.
Arguments:
- model: The model.
- input_: The input for prediction.
- min_divisible: The minimal factors the input shape must be divisible by. For example, (16, 16) for a model that needs 2D inputs divisible by at least 16 pixels.
- device: The device of the model. If not given, will be derived from the model parameters.
- with_channels: Whether the input data contains channels.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:
np.ndarray: The output of the model.
def predict_with_halo(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
        roi: A region of interest of the input for which to run prediction.
        iter_list: Optional list of block ids to iterate over.
        grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
            The shift is realized by zero-padding the start of each spatial axis by
            `round(abs(shift) * block_size)` voxels (the sign of the shift is therefore ignored)
            and cropping that padding from the result again. It requires `input_` (and `mask`,
            if given) to be numpy arrays and `output` to be None.

    Returns:
        The model output.

    Raises:
        ValueError: If `grid_shift` is passed together with a user-provided `output`.
        TypeError: If `grid_shift` is passed and `input_` or `mask` is not a numpy array.
    """
    devices = [torch.device(gpu) for gpu in gpu_ids]
    # Re-use the passed model on the device it already lives on, copy it to all other devices.
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_workers = len(gpu_ids)

    # original shape (spatial only)
    shape0 = input_.shape
    shape_spatial0 = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial0)
    assert len(block_shape) == len(halo) == ndim

    # apply grid_shift via padding+cropping (zero padding)
    input_eff = input_
    mask_eff = mask

    if grid_shift is not None:
        assert len(grid_shift) == ndim, "grid_shift must match number of spatial dims"

        # Fail fast, before the (potentially large) input is copied for padding. The check uses
        # `is not None` like every other grid_shift check, so array-like shifts are handled too.
        if output is not None:
            raise ValueError(
                "grid_shift is not supported together with a user-provided `output`, because "
                "grid_shift requires internal zero-padding and a final cropping step. "
                "Pass `output=None` (let this function allocate the output) or disable `grid_shift`. "
                "Or pad the input manually beforehand."
            )

        pad_vox = tuple(int(np.rint(abs(gs) * bs)) for gs, bs in zip(grid_shift, block_shape))

        if not isinstance(input_eff, np.ndarray):
            raise TypeError("grid_shift padding currently requires input_ to be a numpy array")

        input_eff, pad_left = _pad_for_shift_left(input_eff, pad_vox, with_channels=with_channels, mode="constant",
                                                  constant_values=0)

        if mask_eff is not None:
            if not isinstance(mask_eff, np.ndarray):
                raise TypeError("grid_shift padding currently requires mask to be a numpy array")
            # The mask never has a channel axis; the zero padding marks the padded voxels as masked out.
            mask_eff, _ = _pad_for_shift_left(mask_eff, pad_vox, with_channels=False, mode="constant",
                                              constant_values=0)
    else:
        pad_left = (0,) * ndim

    # shapes after shift-padding
    shape_eff = input_eff.shape
    shape_spatial_eff = shape_eff[1:] if with_channels else shape_eff

    # blocking (on the padded input)
    # NOTE(review): with grid_shift the roi is interpreted in the coordinates of the padded input,
    # it is not offset by the padding - confirm that this is intended.
    if roi is None:
        blocking = bic.utils.Blocking([0] * ndim, list(shape_spatial_eff), block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial_eff)]
        blocking = bic.utils.Blocking(blocking_start, blocking_stop, block_shape)

    # output allocation (for padded shape)
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial_eff), dtype="float32")

    def predict_block(block_id):
        # Distribute the blocks over the model replicas in a round-robin fashion.
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.get_block(block_id)
            offset = list(block.begin)
            # The bounding box of the inner block (without halo) within the loaded / predicted block.
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask_eff is not None:
                mask_block, _ = _load_block(mask_eff, offset, block_shape, halo, with_channels=False)
                mask_block = mask_block[inner_bb].astype("bool")
                # Nothing to predict if the inner block is fully masked out.
                if not mask_block.any():
                    return

            inp, _ = _load_block(input_eff, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # Crop away the halo; keep the channel axis if the prediction has one.
            if prediction.ndim == ndim + 1:
                inner_bb_pred = (slice(None),) + inner_bb
            else:
                inner_bb_pred = inner_bb
            prediction = prediction[inner_bb_pred]

            if mask_eff is not None:
                if prediction.ndim == ndim + 1:
                    mb = np.broadcast_to(mask_block[None], prediction.shape)
                else:
                    mb = mask_block
                prediction[~mb] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.number_of_blocks
    iteration_ids = range(n_blocks) if iter_list is None else np.array(iter_list)

    with futures.ThreadPoolExecutor(n_workers) as tp:
        # Consume the iterator so that all blocks are processed and worker exceptions are re-raised.
        list(tqdm(tp.map(predict_block, iteration_ids),
                  total=len(iteration_ids),
                  disable=disable_tqdm,
                  desc=tqdm_desc))

    # crop away the shift padding so the returned output matches original shape
    if grid_shift is not None:
        output = _crop_after_shift_left(output, pad_left, with_channels=(output.ndim == ndim+1),
                                        original_shape_spatial=tuple(shape_spatial0))

    return output
Run block-wise network prediction with a halo.
Arguments:
- input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
- model: The network.
- gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
- block_shape: The shape of the inner block to use for prediction.
- halo: The shape of the halo to use for prediction
- output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
- preprocess: Function to preprocess input data before passing it to the network.
- postprocess: Function to postprocess the network predictions.
- with_channels: Whether the input has a channel axis.
- skip_block: Function to evaluate whether a given input block will be skipped.
- mask: Elements outside the mask will be ignored in the prediction.
- disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
- tqdm_desc: Description shown by the tqdm output.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
- roi: A region of interest of the input for which to run prediction.
- iter_list: Optional list of block ids to iterate over.
- grid_shift: Per-axis fractional shift of the grid in units of the block size. E.g. (0, 0.25, 0).
Returns:
The model output.
def predict_with_halo_pipelined(
    input_: ArrayLike,
    model: torch.nn.Module,
    gpu_ids: List[Union[str, int]],
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
    output: Optional[Union[ArrayLike, List[Tuple[ArrayLike, slice]]]] = None,
    preprocess: Callable[[Union[torch.Tensor, np.ndarray]], Union[torch.Tensor, np.ndarray]] = standardize,
    postprocess: Callable[[np.ndarray], np.ndarray] = None,
    with_channels: bool = False,
    skip_block: Callable[[Any], bool] = None,
    mask: Optional[ArrayLike] = None,
    disable_tqdm: bool = False,
    tqdm_desc: str = "predict with halo (pipelined)",
    prediction_function: Optional[Callable] = None,
    roi: Optional[Tuple[slice]] = None,
    iter_list: Optional[List[int]] = None,
    batch_size: int = 1,
    num_prefetch_workers: int = 4,
    queue_size: Optional[int] = None,
    num_write_workers: int = 1,
    write_queue_size: Optional[int] = None,
    grid_shift: Optional[Tuple[float, ...]] = None,
) -> ArrayLike:
    """Run block-wise network prediction with a halo, pipelined for higher GPU throughput.

    This is an alternate implementation of `predict_with_halo` that decouples block
    loading, GPU prediction and output writing into a producer-consumer pipeline
    connected by queues:

        producers (CPU threads: load + preprocess) -> input queue
            -> consumer(s), one per GPU (stack a batch, predict, unstack) -> output queue
            -> writer thread(s) (postprocess + write).

    While the GPU works on one batch, the prefetch workers load and preprocess the
    next blocks and the writer drains finished predictions, keeping the GPU fed.
    Blocks can additionally be stacked into batches for one forward pass via `batch_size`.

    The pipeline is thread-based (not multiprocessing) so that lazy hdf5/zarr/n5 inputs
    (whose file handles are not fork/pickle-safe) work, and so that writers can share the
    output array directly. Note that heavy *Python-level* `preprocess`/`postprocess`
    callbacks will not parallelize across prefetch workers due to the GIL; the default
    `standardize` is numpy-vectorized and releases the GIL.

    Args:
        input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
        model: The network.
        gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`.
            One prediction consumer thread (with its own model replica) is run per device.
        block_shape: The shape of the inner block to use for prediction.
        halo: The shape of the halo to use for prediction.
        output: The output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
        preprocess: Function to preprocess input data before passing it to the network.
        postprocess: Function to postprocess the network predictions.
        with_channels: Whether the input has a channel axis.
        skip_block: Function to evaluate whether a given input block will be skipped.
        mask: Elements outside the mask will be ignored in the prediction.
        disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
        tqdm_desc: Description shown by the tqdm output.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It must operate on the leading batch axis; with the default `batch_size=1` it does not need changes.
        roi: A region of interest of the input for which to run prediction.
        iter_list: Optional list of block ids to iterate over.
        batch_size: The number of blocks stacked into a single forward pass. Trades GPU memory for throughput.
        num_prefetch_workers: The number of CPU threads used to load and preprocess blocks.
        queue_size: The maximum size of the input (prefetch) queue. Provides backpressure to bound memory use.
            If None, a value derived from the number of devices and `batch_size` is used.
        num_write_workers: The number of threads used to write predictions. Values > 1 are safe for in-memory
            numpy outputs, and for zarr/n5 outputs whose chunks (and shards, for zarr v3) are aligned with
            block_shape. For hdf5, misaligned zarr/n5, or other outputs this is automatically clamped to 1.
        write_queue_size: The maximum size of the output (write) queue. If None, a default is used.
        grid_shift: Not supported by this function; raises NotImplementedError if passed. Use `predict_with_halo`.

    Returns:
        The model output.

    Raises:
        NotImplementedError: If `grid_shift` is passed.
        Exception: The first exception recorded by any of the pipeline threads is re-raised
            after all threads have been joined.
    """
    if grid_shift is not None:
        raise NotImplementedError(
            "grid_shift is not supported by predict_with_halo_pipelined. "
            "Use predict_with_halo for grid_shift, or pre-pad the input and use roi."
        )

    batch_size = max(1, int(batch_size))
    num_prefetch_workers = max(1, int(num_prefetch_workers))
    num_write_workers = max(1, int(num_write_workers))

    devices = [torch.device(gpu) for gpu in gpu_ids]
    # Re-use the passed model on the device it already lives on, copy it to all other devices.
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]
    n_consumers = len(devices)

    shape0 = input_.shape
    shape_spatial = shape0[1:] if with_channels else shape0
    ndim = len(shape_spatial)
    assert len(block_shape) == len(halo) == ndim

    # blocking
    if roi is None:
        block_start = [0] * ndim
        blocking = bic.utils.Blocking(block_start, list(shape_spatial), block_shape)
    else:
        assert len(roi) == ndim
        block_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape_spatial)]
        blocking = bic.utils.Blocking(block_start, blocking_stop, block_shape)

    # output allocation
    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + tuple(shape_spatial), dtype="float32")

    # guard against unsafe concurrent writes: numpy is always safe (disjoint regions),
    # zarr/n5 are safe when their chunks/shards are aligned with the blocks, hdf5 is not.
    if num_write_workers > 1:
        out_arrays = [o for o, _ in output] if isinstance(output, list) else [output]
        if any(not _concurrent_write_safe(o, block_shape, block_start) for o in out_arrays):
            warnings.warn(
                "num_write_workers > 1 requires either an in-memory numpy output or a zarr/n5 "
                "output whose chunks (and shards, for zarr v3) are aligned with block_shape; "
                "falling back to a single writer. HDF5 outputs are never safe for concurrent writes."
            )
            num_write_workers = 1

    # queue sizes
    if queue_size is None:
        queue_size = max(2 * n_consumers * batch_size, 2 * batch_size)
    # The input queue must be able to hold at least one full batch.
    queue_size = max(queue_size, batch_size)
    if write_queue_size is None:
        write_queue_size = max(2 * n_consumers, 4)

    n_blocks = blocking.number_of_blocks
    iteration_ids = list(range(n_blocks)) if iter_list is None else list(iter_list)
    total = len(iteration_ids)

    # pre-fill the block-id queue with all ids followed by one STOP per producer
    id_queue = queue.Queue()
    for bid in iteration_ids:
        id_queue.put(bid)
    for _ in range(num_prefetch_workers):
        id_queue.put(_STOP)

    input_queue = queue.Queue(maxsize=queue_size)
    output_queue = queue.Queue(maxsize=write_queue_size)

    stop_event = threading.Event()
    error_box = []
    error_lock = threading.Lock()
    progress_lock = threading.Lock()
    pbar = tqdm(total=total, disable=disable_tqdm, desc=tqdm_desc)

    remaining_producers = _AtomicCounter(num_prefetch_workers)
    remaining_consumers = _AtomicCounter(n_consumers)

    def record_error(exc):
        # Only the first error is kept; setting the stop_event makes all threads unwind.
        with error_lock:
            if not error_box:
                error_box.append(exc)
        stop_event.set()

    def producer():
        try:
            while True:
                bid = id_queue.get()
                if bid is _STOP or stop_event.is_set():
                    break
                block = blocking.get_block(bid)
                result = _prepare_block_input(
                    input_, mask, block, block_shape, halo, with_channels, skip_block, preprocess
                )
                if result is _SKIP:
                    # Skipped blocks never reach the writer, so they are counted here.
                    with progress_lock:
                        pbar.update(1)
                    continue
                tensor, mask_block, inner_bb = result
                _safe_put(input_queue, _BlockJob(block, inner_bb, mask_block, tensor), stop_event)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # the last producer to finish signals the consumers (skipped on the abort path,
            # where consumers unwind via _safe_get instead)
            if remaining_producers.decrement() == 0 and not stop_event.is_set():
                # The bounded input queue may be full and a consumer may still fail after the
                # check above; a plain blocking put would then never return and hang the join
                # below. The abort-aware put is used instead (relies on `_safe_put` raising
                # `_Aborted` once the stop_event is set, like in the loop above).
                try:
                    for _ in range(n_consumers):
                        _safe_put(input_queue, _STOP, stop_event)
                except _Aborted:
                    pass

    def consumer(worker_id):
        net, device = models[worker_id]
        try:
            while True:
                # Collect up to batch_size jobs; a STOP flushes the partial batch and ends the loop.
                jobs = []
                got_stop = False
                while len(jobs) < batch_size:
                    item = _safe_get(input_queue, stop_event)
                    if item is _STOP:
                        got_stop = True
                        break
                    jobs.append(item)

                if jobs:  # run (possibly partial) batch
                    batch = torch.cat([job.tensor for job in jobs], dim=0).to(device)
                    with torch.no_grad():
                        prediction = net(batch) if prediction_function is None \
                            else prediction_function(net, batch)
                    if not torch.is_tensor(prediction):  # list/tuple of outputs -> take the first
                        prediction = prediction[0]
                    prediction = prediction.cpu().numpy()
                    for i, job in enumerate(jobs):
                        # Copy the per-block prediction so it does not keep the whole batch alive.
                        job.prediction = np.array(prediction[i])
                        job.tensor = None
                        _safe_put(output_queue, job, stop_event)

                if got_stop:
                    break
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)
        finally:
            # The last consumer to finish signals the writers. As for the producers, the
            # abort-aware put avoids blocking forever on a full output queue if a writer
            # fails after the stop_event check.
            if remaining_consumers.decrement() == 0 and not stop_event.is_set():
                try:
                    for _ in range(num_write_workers):
                        _safe_put(output_queue, _STOP, stop_event)
                except _Aborted:
                    pass

    def writer():
        try:
            while True:
                job = _safe_get(output_queue, stop_event)
                if job is _STOP:
                    break
                _write_prediction(
                    job.prediction, job.block, output, ndim, job.mask_block, job.inner_bb, postprocess
                )
                with progress_lock:
                    pbar.update(1)
        except _Aborted:
            pass
        except Exception as e:  # noqa
            record_error(e)

    writers = [threading.Thread(target=writer, name=f"predict-writer-{i}") for i in range(num_write_workers)]
    consumers = [threading.Thread(target=consumer, args=(i,), name=f"predict-consumer-{i}")
                 for i in range(n_consumers)]
    producers = [threading.Thread(target=producer, name=f"predict-producer-{i}")
                 for i in range(num_prefetch_workers)]
    threads = writers + consumers + producers

    try:
        # Start the pipeline back to front so that every stage has a running downstream stage.
        for t in writers:
            t.start()
        for t in consumers:
            t.start()
        for t in producers:
            t.start()

        # Join front to back, following the order in which the stages finish.
        for t in producers:
            t.join()
        for t in consumers:
            t.join()
        for t in writers:
            t.join()
    finally:
        # Also reached if the main thread is interrupted: make all threads unwind and wait for them.
        stop_event.set()
        for t in threads:
            t.join()
        pbar.close()

    if error_box:
        raise error_box[0]

    return output
Run block-wise network prediction with a halo, pipelined for higher GPU throughput.
This is an alternate implementation of predict_with_halo that decouples block
loading, GPU prediction and output writing into a producer-consumer pipeline
connected by queues:
producers (CPU threads: load + preprocess) -> input queue
-> consumer(s), one per GPU (stack a batch, predict, unstack) -> output queue
-> writer thread(s) (postprocess + write).
While the GPU works on one batch, the prefetch workers load and preprocess the
next blocks and the writer drains finished predictions, keeping the GPU fed.
Blocks can additionally be stacked into batches for one forward pass via batch_size.
The pipeline is thread-based (not multiprocessing) so that lazy hdf5/zarr/n5 inputs
(whose file handles are not fork/pickle-safe) work, and so that writers can share the
output array directly. Note that heavy Python-level preprocess/postprocess
callbacks will not parallelize across prefetch workers due to the GIL; the default
standardize is numpy-vectorized and releases the GIL.
Arguments:
- input_: The input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar.
- model: The network.
- gpu_ids: List of device ids to use for prediction. To run prediction on the CPU, pass `["cpu"]`. One prediction consumer thread (with its own model replica) is run per device.
- block_shape: The shape of the inner block to use for prediction.
- halo: The shape of the halo to use for prediction.
- output: The output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and a slice for the corresponding channel.
- preprocess: Function to preprocess input data before passing it to the network.
- postprocess: Function to postprocess the network predictions.
- with_channels: Whether the input has a channel axis.
- skip_block: Function to evaluate whether a given input block will be skipped.
- mask: Elements outside the mask will be ignored in the prediction.
- disable_tqdm: Flag that allows to disable tqdm output (e.g. if function is called multiple times).
- tqdm_desc: Description shown by the tqdm output.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
It must operate on the leading batch axis; with the default `batch_size=1` it does not need changes.
- roi: A region of interest of the input for which to run prediction.
- iter_list: Optional list of block ids to iterate over.
- batch_size: The number of blocks stacked into a single forward pass. Trades GPU memory for throughput.
- num_prefetch_workers: The number of CPU threads used to load and preprocess blocks.
- queue_size: The maximum size of the input (prefetch) queue. Provides backpressure to bound memory use.
If None, a value derived from the number of devices and `batch_size` is used.
- num_write_workers: The number of threads used to write predictions. Values > 1 are safe for in-memory numpy outputs, and for zarr/n5 outputs whose chunks (and shards, for zarr v3) are aligned with block_shape. For hdf5, misaligned zarr/n5, or other outputs this is automatically clamped to 1.
- write_queue_size: The maximum size of the output (write) queue. If None, a default is used.
- grid_shift: Not supported by this function; raises NotImplementedError if passed. Use `predict_with_halo`.
Returns:
The model output.