torch_em.util.memory

Helper functionality for empirically determining the largest batch size or patch shape that fits into GPU memory for a given network.

The maximum is found by running a forward pass (prediction) with dummy data and increasing the respective parameter until the GPU runs out of memory, using an exponential bracketing followed by a binary search to converge in a logarithmic number of steps.

  1"""Helper functionality for empirically determining the largest batch size or patch shape
  2that fits into GPU memory for a given network.
  3
  4The maximum is found by running a forward pass (prediction) with dummy data and increasing
  5the respective parameter until the GPU runs out of memory, using an exponential bracketing
  6followed by a binary search to converge in a logarithmic number of steps.
  7"""
  8
  9import gc
 10import warnings
 11from typing import Any, Callable, Optional, Tuple, Union
 12
 13import torch
 14
 15
 16def _is_oom_error(exc: BaseException) -> bool:
 17    """Check whether an exception was raised because the GPU ran out of memory."""
 18    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
 19    if oom_type is not None and isinstance(exc, oom_type):
 20        return True
 21    return isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower()
 22
 23
 24def _resolve_device(model: torch.nn.Module, device: Optional[Union[torch.device, str]]) -> torch.device:
 25    """Resolve the device and ensure it is a CUDA device (OOM cannot be detected on CPU)."""
 26    if device is None:
 27        device = next(model.parameters()).device
 28    device = torch.device(device)
 29    if device.type != "cuda" or not torch.cuda.is_available():
 30        raise RuntimeError(
 31            "compute_max_batch_size and compute_max_patch_shape require a CUDA device, because running out "
 32            f"of GPU memory is used as the termination signal, but got a model / device of type '{device.type}'. "
 33            "Move the model to a GPU and / or pass device='cuda'."
 34        )
 35    return device
 36
 37
 38def _resolve_in_channels(model: torch.nn.Module, in_channels: Optional[int]) -> int:
 39    """Resolve the number of input channels, falling back to the model's `in_channels` attribute."""
 40    if in_channels is None:
 41        in_channels = getattr(model, "in_channels", None)
 42    if in_channels is None:
 43        raise ValueError(
 44            "Could not determine the number of input channels from the model. Please pass 'in_channels' explicitly."
 45        )
 46    return int(in_channels)
 47
 48
 49def _resolve_min_divisible(
 50    model: torch.nn.Module, ndim: int, min_divisible: Optional[Tuple[int, ...]]
 51) -> Tuple[int, ...]:
 52    """Resolve the per-axis divisibility constraint, defaulting to (2 ** depth,) * ndim for U-Nets."""
 53    if min_divisible is None:
 54        depth = getattr(model, "depth", None)
 55        factor = 2 ** depth if depth is not None else 1
 56        min_divisible = (factor,) * ndim
 57    if len(min_divisible) != ndim:
 58        raise ValueError(f"min_divisible {min_divisible} does not match the number of dimensions {ndim}.")
 59    return tuple(int(d) for d in min_divisible)
 60
 61
def _attempt_forward(
    model: torch.nn.Module,
    device: torch.device,
    dtype: torch.dtype,
    in_channels: int,
    batch_size: int,
    patch_shape: Tuple[int, ...],
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]],
) -> bool:
    """Run a single forward pass with dummy data and report whether it fit into memory.

    The dummy input has shape (batch_size, in_channels, *patch_shape) and the given dtype, and is
    filled with normally distributed values. It is passed to `model(inp)`, or to
    `prediction_function(model, inp)` if a prediction function is given, under `torch.no_grad()`.

    Args:
        model: The model to run.
        device: The (CUDA) device on which the dummy input is allocated and synchronized.
        dtype: The data type of the dummy input.
        in_channels: The number of input channels of the dummy input.
        batch_size: The batch size of the dummy input.
        patch_shape: The spatial shape of the dummy input.
        prediction_function: Optional wrapper called as `prediction_function(model, inp)`.

    Returns:
        True if the forward pass succeeded, False if the GPU ran out of memory.

    Raises:
        Any other (non-OOM) exception is re-raised, since it indicates a genuine problem
        such as an invalid patch shape rather than a memory limit.
    """
    inp = out = None  # Pre-bind so that the `del` in `finally` works even if the allocation itself fails.
    try:
        inp = torch.empty((batch_size, in_channels, *patch_shape), dtype=dtype, device=device).normal_()
        with torch.no_grad():
            out = model(inp) if prediction_function is None else prediction_function(model, inp)
        torch.cuda.synchronize(device)  # Surface asynchronous / lazy OOM errors here instead of later.
        return True
    except Exception as exc:
        if _is_oom_error(exc):
            return False
        raise
    finally:
        # Drop the references to the dummy tensors and return cached GPU memory, so that the next
        # attempt starts from a clean state. This also runs when a non-OOM error is re-raised.
        del inp, out
        gc.collect()
        torch.cuda.empty_cache()
 92
 93
 94def _search_max_int(fits: Callable[[int], bool], upper_bound: int) -> int:
 95    """Find the largest integer in [1, upper_bound] for which `fits` returns True.
 96
 97    Assumes monotonicity: if a value fits, every smaller value fits as well. First an upper
 98    bracket is found by doubling, then the exact value is determined by binary search.
 99    The return value equals `upper_bound` exactly if the bound was reached without a failure.
100    """
101    if upper_bound < 1:
102        raise ValueError(f"upper_bound must be a positive integer, got {upper_bound}.")
103    if not fits(1):
104        raise RuntimeError(
105            "The model does not fit into memory even for the smallest configuration (batch size 1 / the smallest "
106            "patch shape). Reduce the patch shape, use a smaller model or run on a device with more memory."
107        )
108
109    # Phase 1: exponential bracketing - double the candidate until it fails or exceeds the bound.
110    last_ok = 1
111    candidate = 2
112    while candidate <= upper_bound and fits(candidate):
113        last_ok = candidate
114        candidate *= 2
115
116    if candidate > upper_bound:
117        # We never observed a failure within the bound. Check the bound itself.
118        if last_ok == upper_bound or fits(upper_bound):
119            return upper_bound
120        first_fail = upper_bound
121    else:
122        # The loop terminated because `fits(candidate)` returned False.
123        first_fail = candidate
124
125    # Phase 2: binary search in the open interval (last_ok, first_fail).
126    lo, hi = last_ok, first_fail
127    while hi - lo > 1:
128        mid = (lo + hi) // 2
129        if fits(mid):
130            lo = mid
131        else:
132            hi = mid
133    return lo
134
135
def compute_max_batch_size(
    model: torch.nn.Module,
    patch_shape: Tuple[int, ...],
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_batch_size: int = 1024,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> int:
    """Empirically determine the largest batch size that fits into GPU memory for a fixed patch shape.

    The batch size is increased (forward pass with dummy data, exponential bracketing followed by
    binary search) until the GPU runs out of memory. This requires a CUDA device, since running out
    of memory is used as the termination signal. The model is run in eval mode during the search;
    afterwards the training flag of every submodule is restored to its previous value.

    Args:
        model: The model.
        patch_shape: The spatial shape of a single sample, without batch or channel axis,
            e.g. (512, 512) for 2D or (64, 128, 128) for 3D.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum, to stay clear
            of the out-of-memory boundary. The returned batch size is at least one.
        max_batch_size: The upper bound for the search. If the model does not run out of memory at this
            batch size, it is returned and a warning is issued (the true maximum may be larger).
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)` instead of `model(input)`.

    Returns:
        The maximum batch size.

    Raises:
        ValueError: If 'safety_factor' is not in (0, 1], if 'max_batch_size' is smaller than one,
            or if the number of input channels cannot be determined.
        RuntimeError: If no CUDA device is used / available, or if the model does not fit into memory
            even with batch size one.
    """
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in (0, 1], got {safety_factor}.")
    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    def fits(batch_size: int) -> bool:
        return _attempt_forward(model, device, dtype, in_channels, batch_size, patch_shape, prediction_function)

    # `model.train(flag)` would set *all* submodules to the same mode and clobber mixed states
    # (e.g. submodules deliberately kept in eval mode), so record and restore each flag individually.
    training_states = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        max_fitting = _search_max_int(fits, max_batch_size)
    finally:
        for module, was_training in training_states:
            module.training = was_training

    if max_fitting >= max_batch_size:
        # The bound itself was verified to fit, so it is not the OOM boundary: no safety factor needed.
        warnings.warn(
            f"The batch size search reached the upper bound 'max_batch_size'={max_batch_size} without running "
            "out of memory. The true maximum may be larger; increase 'max_batch_size' to search further."
        )
        return max_fitting

    return max(1, int(max_fitting * safety_factor))
193
194
def compute_max_patch_shape(
    model: torch.nn.Module,
    ndim: int,
    batch_size: int = 1,
    min_divisible: Optional[Tuple[int, ...]] = None,
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_scale_factor: int = 128,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> Tuple[int, ...]:
    """Empirically determine the largest patch shape that fits into GPU memory for a fixed batch size.

    The patch shape is grown isotropically as integer multiples of `min_divisible`, i.e. the candidate
    shapes are `(k * d0, k * d1, ...)` for increasing `k`, so that the network's divisibility constraints
    are always satisfied. The multiplier is increased (forward pass with dummy data, exponential bracketing
    followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running
    out of memory is used as the termination signal. The model is run in eval mode during the search;
    afterwards the training flag of every submodule is restored to its previous value.

    Args:
        model: The model.
        ndim: The number of spatial dimensions, i.e. 2 for a 2D and 3 for a 3D model.
        batch_size: The (fixed) batch size to use for the search. Must be at least one.
        min_divisible: The factors each spatial axis must be divisible by, which also define the smallest
            patch shape and the increment of the search. By default this is derived from the model's 'depth'
            attribute as (2 ** depth,) * ndim (the constraint for a U-Net), falling back to (1,) * ndim.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum multiplier, to stay
            clear of the out-of-memory boundary. The returned patch shape is at least 'min_divisible'.
        max_scale_factor: The upper bound for the multiplier search. If the model does not run out of memory
            at this multiplier, the corresponding patch shape is returned and a warning is issued.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)` instead of `model(input)`.

    Returns:
        The maximum patch shape, as a tuple of length 'ndim'.

    Raises:
        ValueError: If 'batch_size' is smaller than one, 'safety_factor' is not in (0, 1], 'min_divisible'
            does not match 'ndim', 'max_scale_factor' is smaller than one, or the number of input
            channels cannot be determined.
        RuntimeError: If no CUDA device is used / available, or if the model does not fit into memory
            even for the smallest patch shape.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in (0, 1], got {safety_factor}.")
    min_divisible = _resolve_min_divisible(model, ndim, min_divisible)
    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    def scale(k: int) -> Tuple[int, ...]:
        return tuple(k * d for d in min_divisible)

    def fits(k: int) -> bool:
        return _attempt_forward(model, device, dtype, in_channels, batch_size, scale(k), prediction_function)

    # `model.train(flag)` would set *all* submodules to the same mode and clobber mixed states
    # (e.g. submodules deliberately kept in eval mode), so record and restore each flag individually.
    training_states = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        max_fitting = _search_max_int(fits, max_scale_factor)
    finally:
        for module, was_training in training_states:
            module.training = was_training

    if max_fitting >= max_scale_factor:
        # The bound itself was verified to fit, so it is not the OOM boundary: no safety factor needed.
        warnings.warn(
            f"The patch shape search reached the upper bound 'max_scale_factor'={max_scale_factor} without "
            "running out of memory. The true maximum may be larger; increase 'max_scale_factor' to search further."
        )
        return scale(max_fitting)

    return scale(max(1, int(max_fitting * safety_factor)))
def compute_max_batch_size( model: torch.nn.modules.module.Module, patch_shape: Tuple[int, ...], in_channels: Optional[int] = None, device: Union[torch.device, str, NoneType] = None, dtype: torch.dtype = torch.float32, safety_factor: float = 0.9, max_batch_size: int = 1024, prediction_function: Optional[Callable[[Any], Any]] = None) -> int:
def compute_max_batch_size(
    model: torch.nn.Module,
    patch_shape: Tuple[int, ...],
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_batch_size: int = 1024,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> int:
    """Empirically determine the largest batch size that fits into GPU memory for a fixed patch shape.

    The batch size is increased (forward pass with dummy data, exponential bracketing followed by
    binary search) until the GPU runs out of memory. This requires a CUDA device, since running out
    of memory is used as the termination signal. The model is run in eval mode during the search;
    afterwards the training flag of every submodule is restored to its previous value.

    Args:
        model: The model.
        patch_shape: The spatial shape of a single sample, without batch or channel axis,
            e.g. (512, 512) for 2D or (64, 128, 128) for 3D.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum, to stay clear
            of the out-of-memory boundary. The returned batch size is at least one.
        max_batch_size: The upper bound for the search. If the model does not run out of memory at this
            batch size, it is returned and a warning is issued (the true maximum may be larger).
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)` instead of `model(input)`.

    Returns:
        The maximum batch size.

    Raises:
        ValueError: If 'safety_factor' is not in (0, 1], if 'max_batch_size' is smaller than one,
            or if the number of input channels cannot be determined.
        RuntimeError: If no CUDA device is used / available, or if the model does not fit into memory
            even with batch size one.
    """
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in (0, 1], got {safety_factor}.")
    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    def fits(batch_size: int) -> bool:
        return _attempt_forward(model, device, dtype, in_channels, batch_size, patch_shape, prediction_function)

    # `model.train(flag)` would set *all* submodules to the same mode and clobber mixed states
    # (e.g. submodules deliberately kept in eval mode), so record and restore each flag individually.
    training_states = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        max_fitting = _search_max_int(fits, max_batch_size)
    finally:
        for module, was_training in training_states:
            module.training = was_training

    if max_fitting >= max_batch_size:
        # The bound itself was verified to fit, so it is not the OOM boundary: no safety factor needed.
        warnings.warn(
            f"The batch size search reached the upper bound 'max_batch_size'={max_batch_size} without running "
            "out of memory. The true maximum may be larger; increase 'max_batch_size' to search further."
        )
        return max_fitting

    return max(1, int(max_fitting * safety_factor))

Empirically determine the largest batch size that fits into GPU memory for a fixed patch shape.

The batch size is increased (forward pass with dummy data, exponential bracketing followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running out of memory is used as the termination signal.

Arguments:
  • model: The model.
  • patch_shape: The spatial shape of a single sample, without batch or channel axis, e.g. (512, 512) for 2D or (64, 128, 128) for 3D.
  • in_channels: The number of input channels. By default this is derived from the model's 'in_channels' attribute.
  • device: The device of the model. If not given, will be derived from the model parameters. Must be a CUDA device.
  • dtype: The data type of the dummy input data.
  • safety_factor: Factor in (0, 1] applied to the empirically determined maximum, to stay clear of the out-of-memory boundary. The returned batch size is at least one.
  • max_batch_size: The upper bound for the search. If the model does not run out of memory at this batch size, it is returned and a warning is issued (the true maximum may be larger).
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures. It is called as prediction_function(model, input) instead of model(input).
Returns:

The maximum batch size.

def compute_max_patch_shape( model: torch.nn.modules.module.Module, ndim: int, batch_size: int = 1, min_divisible: Optional[Tuple[int, ...]] = None, in_channels: Optional[int] = None, device: Union[torch.device, str, NoneType] = None, dtype: torch.dtype = torch.float32, safety_factor: float = 0.9, max_scale_factor: int = 128, prediction_function: Optional[Callable[[Any], Any]] = None) -> Tuple[int, ...]:
def compute_max_patch_shape(
    model: torch.nn.Module,
    ndim: int,
    batch_size: int = 1,
    min_divisible: Optional[Tuple[int, ...]] = None,
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_scale_factor: int = 128,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> Tuple[int, ...]:
    """Empirically determine the largest patch shape that fits into GPU memory for a fixed batch size.

    The patch shape is grown isotropically as integer multiples of `min_divisible`, i.e. the candidate
    shapes are `(k * d0, k * d1, ...)` for increasing `k`, so that the network's divisibility constraints
    are always satisfied. The multiplier is increased (forward pass with dummy data, exponential bracketing
    followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running
    out of memory is used as the termination signal. The model is run in eval mode during the search;
    afterwards the training flag of every submodule is restored to its previous value.

    Args:
        model: The model.
        ndim: The number of spatial dimensions, i.e. 2 for a 2D and 3 for a 3D model.
        batch_size: The (fixed) batch size to use for the search. Must be at least one.
        min_divisible: The factors each spatial axis must be divisible by, which also define the smallest
            patch shape and the increment of the search. By default this is derived from the model's 'depth'
            attribute as (2 ** depth,) * ndim (the constraint for a U-Net), falling back to (1,) * ndim.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum multiplier, to stay
            clear of the out-of-memory boundary. The returned patch shape is at least 'min_divisible'.
        max_scale_factor: The upper bound for the multiplier search. If the model does not run out of memory
            at this multiplier, the corresponding patch shape is returned and a warning is issued.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)` instead of `model(input)`.

    Returns:
        The maximum patch shape, as a tuple of length 'ndim'.

    Raises:
        ValueError: If 'batch_size' is smaller than one, 'safety_factor' is not in (0, 1], 'min_divisible'
            does not match 'ndim', 'max_scale_factor' is smaller than one, or the number of input
            channels cannot be determined.
        RuntimeError: If no CUDA device is used / available, or if the model does not fit into memory
            even for the smallest patch shape.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in (0, 1], got {safety_factor}.")
    min_divisible = _resolve_min_divisible(model, ndim, min_divisible)
    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    def scale(k: int) -> Tuple[int, ...]:
        return tuple(k * d for d in min_divisible)

    def fits(k: int) -> bool:
        return _attempt_forward(model, device, dtype, in_channels, batch_size, scale(k), prediction_function)

    # `model.train(flag)` would set *all* submodules to the same mode and clobber mixed states
    # (e.g. submodules deliberately kept in eval mode), so record and restore each flag individually.
    training_states = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        max_fitting = _search_max_int(fits, max_scale_factor)
    finally:
        for module, was_training in training_states:
            module.training = was_training

    if max_fitting >= max_scale_factor:
        # The bound itself was verified to fit, so it is not the OOM boundary: no safety factor needed.
        warnings.warn(
            f"The patch shape search reached the upper bound 'max_scale_factor'={max_scale_factor} without "
            "running out of memory. The true maximum may be larger; increase 'max_scale_factor' to search further."
        )
        return scale(max_fitting)

    return scale(max(1, int(max_fitting * safety_factor)))

Empirically determine the largest patch shape that fits into GPU memory for a fixed batch size.

The patch shape is grown isotropically as integer multiples of min_divisible, i.e. the candidate shapes are (k * d0, k * d1, ...) for increasing k, so that the network's divisibility constraints are always satisfied. The multiplier is increased (forward pass with dummy data, exponential bracketing followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running out of memory is used as the termination signal.

Arguments:
  • model: The model.
  • ndim: The number of spatial dimensions, i.e. 2 for a 2D and 3 for a 3D model.
  • batch_size: The (fixed) batch size to use for the search.
  • min_divisible: The factors each spatial axis must be divisible by, which also define the smallest patch shape and the increment of the search. By default this is derived from the model's 'depth' attribute as (2 ** depth,) * ndim (the constraint for a U-Net), falling back to (1,) * ndim.
  • in_channels: The number of input channels. By default this is derived from the model's 'in_channels' attribute.
  • device: The device of the model. If not given, will be derived from the model parameters. Must be a CUDA device.
  • dtype: The data type of the dummy input data.
  • safety_factor: Factor in (0, 1] applied to the empirically determined maximum multiplier, to stay clear of the out-of-memory boundary. The returned patch shape is at least 'min_divisible'.
  • max_scale_factor: The upper bound for the multiplier search. If the model does not run out of memory at this multiplier, the corresponding patch shape is returned and a warning is issued.
  • prediction_function: A wrapper function for prediction to enable custom prediction procedures. It is called as prediction_function(model, input) instead of model(input).
Returns:

The maximum patch shape, as a tuple of length 'ndim'.