torch_em.util.memory
Helper functionality for empirically determining the largest batch size or patch shape that fits into GPU memory for a given network.
The maximum is found by running a forward pass (prediction) with dummy data and increasing the respective parameter until the GPU runs out of memory. The search uses exponential bracketing followed by a binary search, so it converges in a logarithmic number of steps.
"""Helper functionality for empirically determining the largest batch size or patch shape
that fits into GPU memory for a given network.

The maximum is found by running a forward pass (prediction) with dummy data and increasing
the respective parameter until the GPU runs out of memory, using exponential bracketing
followed by a binary search to converge in a logarithmic number of steps.
"""

import gc
import warnings
from typing import Any, Callable, Optional, Tuple, Union

import torch


def _is_oom_error(exc: BaseException) -> bool:
    """Check whether an exception was raised because the GPU ran out of memory.

    Args:
        exc: The exception to classify.

    Returns:
        True if `exc` is an instance of `torch.cuda.OutOfMemoryError` (when this torch version
        provides it) or a `RuntimeError` whose message contains "out of memory".
    """
    # The dedicated exception type does not exist in every torch version, hence the getattr.
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(exc, oom_type):
        return True
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower()


def _resolve_device(model: torch.nn.Module, device: Optional[Union[torch.device, str]]) -> torch.device:
    """Resolve the device and ensure it is a CUDA device (OOM cannot be detected on CPU).

    Args:
        model: The model. Its first parameter determines the device if `device` is None.
        device: The requested device, or None to derive it from the model.

    Returns:
        The resolved CUDA device.

    Raises:
        ValueError: If `device` is None and the model has no parameters to derive it from.
        RuntimeError: If the resolved device is not a CUDA device or CUDA is not available.
    """
    if device is None:
        # Use a default for `next` so that a parameter-free model yields a clear error
        # instead of leaking a bare StopIteration to the caller.
        first_param = next(model.parameters(), None)
        if first_param is None:
            raise ValueError(
                "Could not determine the device from the model because it has no parameters. "
                "Please pass 'device' explicitly."
            )
        device = first_param.device
    device = torch.device(device)
    if device.type != "cuda" or not torch.cuda.is_available():
        raise RuntimeError(
            "compute_max_batch_size and compute_max_patch_shape require a CUDA device, because running out "
            f"of GPU memory is used as the termination signal, but got a model / device of type '{device.type}'. "
            "Move the model to a GPU and / or pass device='cuda'."
        )
    return device


def _resolve_in_channels(model: torch.nn.Module, in_channels: Optional[int]) -> int:
    """Resolve the number of input channels, falling back to the model's `in_channels` attribute.

    Args:
        model: The model, optionally exposing an `in_channels` attribute.
        in_channels: The explicitly requested number of input channels, or None.

    Returns:
        The number of input channels as an int.

    Raises:
        ValueError: If `in_channels` is None and the model has no `in_channels` attribute.
    """
    if in_channels is None:
        in_channels = getattr(model, "in_channels", None)
    if in_channels is None:
        raise ValueError(
            "Could not determine the number of input channels from the model. Please pass 'in_channels' explicitly."
        )
    return int(in_channels)


def _resolve_min_divisible(
    model: torch.nn.Module, ndim: int, min_divisible: Optional[Tuple[int, ...]]
) -> Tuple[int, ...]:
    """Resolve the per-axis divisibility constraint, defaulting to (2 ** depth,) * ndim for U-Nets.

    Args:
        model: The model, optionally exposing a `depth` attribute.
        ndim: The number of spatial dimensions.
        min_divisible: The explicitly requested per-axis factors, or None to derive them
            from the model's `depth` attribute (falling back to a factor of 1).

    Returns:
        The per-axis factors as a tuple of ints of length `ndim`.

    Raises:
        ValueError: If the number of factors does not match `ndim` or a factor is not positive.
    """
    if min_divisible is None:
        depth = getattr(model, "depth", None)
        factor = 2 ** depth if depth is not None else 1
        min_divisible = (factor,) * ndim
    if len(min_divisible) != ndim:
        raise ValueError(f"min_divisible {min_divisible} does not match the number of dimensions {ndim}.")
    resolved = tuple(int(d) for d in min_divisible)
    # The factors are multiplied by the search multiplier to build the patch shape, so a
    # non-positive factor would yield an empty or invalid patch.
    if any(d < 1 for d in resolved):
        raise ValueError(f"min_divisible must contain only positive integers, got {resolved}.")
    return resolved


def _attempt_forward(
    model: torch.nn.Module,
    device: torch.device,
    dtype: torch.dtype,
    in_channels: int,
    batch_size: int,
    patch_shape: Tuple[int, ...],
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]],
) -> bool:
    """Run a single forward pass with dummy data and report whether it fit into memory.

    Args:
        model: The model to run.
        device: The device on which the dummy input is allocated.
        dtype: The data type of the dummy input.
        in_channels: The number of channels of the dummy input.
        batch_size: The batch size of the dummy input.
        patch_shape: The spatial shape of the dummy input.
        prediction_function: Optional wrapper that is called as `prediction_function(model, input)`
            instead of calling the model directly.

    Returns:
        True if the forward pass succeeded, False if the GPU ran out of memory.

    Raises:
        Exception: Any other (non-OOM) exception is re-raised, since it indicates a genuine
            problem such as an invalid patch shape rather than a memory limit.
    """
    inp = out = None
    try:
        inp = torch.empty((batch_size, in_channels, *patch_shape), dtype=dtype, device=device).normal_()
        with torch.no_grad():
            out = model(inp) if prediction_function is None else prediction_function(model, inp)
        torch.cuda.synchronize(device)  # Surface asynchronous / lazy OOM errors here instead of later.
        return True
    except Exception as exc:
        if _is_oom_error(exc):
            return False
        raise
    finally:
        # Drop the references to the tensors and release the cached blocks, so that the
        # next attempt starts from a clean state.
        del inp, out
        gc.collect()
        torch.cuda.empty_cache()


def _search_max_int(fits: Callable[[int], bool], upper_bound: int) -> int:
    """Find the largest integer in [1, upper_bound] for which `fits` returns True.

    Assumes monotonicity: if a value fits, every smaller value fits as well. First an upper
    bracket is found by doubling, then the exact value is determined by binary search.
    The return value equals `upper_bound` exactly if the bound was reached without a failure.

    Args:
        fits: Predicate that reports whether a candidate value fits.
        upper_bound: The largest value that is considered.

    Returns:
        The largest fitting integer in [1, upper_bound].

    Raises:
        ValueError: If `upper_bound` is smaller than one.
        RuntimeError: If not even the value 1 fits.
    """
    if upper_bound < 1:
        raise ValueError(f"upper_bound must be a positive integer, got {upper_bound}.")
    if not fits(1):
        raise RuntimeError(
            "The model does not fit into memory even for the smallest configuration (batch size 1 / the smallest "
            "patch shape). Reduce the patch shape, use a smaller model or run on a device with more memory."
        )

    # Phase 1: exponential bracketing - double the candidate until it fails or exceeds the bound.
    last_ok = 1
    candidate = 2
    while candidate <= upper_bound and fits(candidate):
        last_ok = candidate
        candidate *= 2

    if candidate > upper_bound:
        # We never observed a failure within the bound. Check the bound itself.
        if last_ok == upper_bound or fits(upper_bound):
            return upper_bound
        first_fail = upper_bound
    else:
        # The loop terminated because `fits(candidate)` returned False.
        first_fail = candidate

    # Phase 2: binary search in the open interval (last_ok, first_fail).
    lo, hi = last_ok, first_fail
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    return lo


def compute_max_batch_size(
    model: torch.nn.Module,
    patch_shape: Tuple[int, ...],
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_batch_size: int = 1024,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> int:
    """Empirically determine the largest batch size that fits into GPU memory for a fixed patch shape.

    The batch size is increased (forward pass with dummy data, exponential bracketing followed by
    binary search) until the GPU runs out of memory. This requires a CUDA device, since running out
    of memory is used as the termination signal.

    Args:
        model: The model.
        patch_shape: The spatial shape of a single sample, without batch or channel axis,
            e.g. (512, 512) for 2D or (64, 128, 128) for 3D.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum, to stay clear
            of the out-of-memory boundary. The returned batch size is at least one.
        max_batch_size: The upper bound for the search. If the model does not run out of memory at this
            batch size, it is returned as is (without applying the safety factor) and a warning is
            issued, since the true maximum may be larger.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)`.

    Returns:
        The maximum batch size.

    Raises:
        ValueError: If 'safety_factor' is not in (0, 1], if 'max_batch_size' is smaller than one or
            if the number of input channels cannot be determined.
        RuntimeError: If the device is not a CUDA device or the model does not fit into memory
            even for a batch size of one.
    """
    # A factor above one would return a batch size that is known not to fit; reject it early,
    # before any GPU work is done. The negated chained comparison also rejects NaN.
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in the interval (0, 1], got {safety_factor}.")

    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    # Search in eval mode and restore the previous mode afterwards, even if the search fails.
    was_training = model.training
    model.eval()
    try:
        def fits(batch_size: int) -> bool:
            return _attempt_forward(
                model, device, dtype, in_channels, batch_size, patch_shape, prediction_function
            )

        max_fitting = _search_max_int(fits, max_batch_size)
    finally:
        model.train(was_training)

    if max_fitting >= max_batch_size:
        warnings.warn(
            f"The batch size search reached the upper bound 'max_batch_size'={max_batch_size} without running "
            "out of memory. The true maximum may be larger; increase 'max_batch_size' to search further.",
            stacklevel=2,  # Attribute the warning to the caller rather than to this module.
        )
        return max_fitting

    return max(1, int(max_fitting * safety_factor))


def compute_max_patch_shape(
    model: torch.nn.Module,
    ndim: int,
    batch_size: int = 1,
    min_divisible: Optional[Tuple[int, ...]] = None,
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_scale_factor: int = 128,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> Tuple[int, ...]:
    """Empirically determine the largest patch shape that fits into GPU memory for a fixed batch size.

    The patch shape is grown isotropically as integer multiples of `min_divisible`, i.e. the candidate
    shapes are `(k * d0, k * d1, ...)` for increasing `k`, so that the network's divisibility constraints
    are always satisfied. The multiplier is increased (forward pass with dummy data, exponential bracketing
    followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running
    out of memory is used as the termination signal.

    Args:
        model: The model.
        ndim: The number of spatial dimensions, i.e. 2 for a 2D and 3 for a 3D model.
        batch_size: The (fixed) batch size to use for the search.
        min_divisible: The factors each spatial axis must be divisible by, which also define the smallest
            patch shape and the increment of the search. By default this is derived from the model's 'depth'
            attribute as (2 ** depth,) * ndim (the constraint for a U-Net), falling back to (1,) * ndim.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum multiplier, to stay
            clear of the out-of-memory boundary. The returned patch shape is at least 'min_divisible'.
        max_scale_factor: The upper bound for the multiplier search. If the model does not run out of memory
            at this multiplier, the corresponding patch shape is returned as is (without applying the
            safety factor) and a warning is issued, since the true maximum may be larger.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)`.

    Returns:
        The maximum patch shape, as a tuple of length 'ndim'.

    Raises:
        ValueError: If 'safety_factor' is not in (0, 1], if 'min_divisible' is invalid, if
            'max_scale_factor' is smaller than one or if the number of input channels cannot be determined.
        RuntimeError: If the device is not a CUDA device or the model does not fit into memory
            even for the smallest patch shape.
    """
    # A factor above one would return a patch shape that is known not to fit; reject it early,
    # before any GPU work is done. The negated chained comparison also rejects NaN.
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in the interval (0, 1], got {safety_factor}.")

    min_divisible = _resolve_min_divisible(model, ndim, min_divisible)
    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    def scale(k: int) -> Tuple[int, ...]:
        # Map the scalar multiplier to the corresponding patch shape.
        return tuple(k * d for d in min_divisible)

    # Search in eval mode and restore the previous mode afterwards, even if the search fails.
    was_training = model.training
    model.eval()
    try:
        def fits(k: int) -> bool:
            return _attempt_forward(
                model, device, dtype, in_channels, batch_size, scale(k), prediction_function
            )

        max_fitting = _search_max_int(fits, max_scale_factor)
    finally:
        model.train(was_training)

    if max_fitting >= max_scale_factor:
        warnings.warn(
            f"The patch shape search reached the upper bound 'max_scale_factor'={max_scale_factor} without "
            "running out of memory. The true maximum may be larger; increase 'max_scale_factor' to search further.",
            stacklevel=2,  # Attribute the warning to the caller rather than to this module.
        )
        return scale(max_fitting)

    return scale(max(1, int(max_fitting * safety_factor)))
def compute_max_batch_size(
    model: torch.nn.Module,
    patch_shape: Tuple[int, ...],
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_batch_size: int = 1024,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> int:
    """Empirically determine the largest batch size that fits into GPU memory for a fixed patch shape.

    The batch size is increased (forward pass with dummy data, exponential bracketing followed by
    binary search) until the GPU runs out of memory. This requires a CUDA device, since running out
    of memory is used as the termination signal.

    Args:
        model: The model.
        patch_shape: The spatial shape of a single sample, without batch or channel axis,
            e.g. (512, 512) for 2D or (64, 128, 128) for 3D.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum, to stay clear
            of the out-of-memory boundary. The returned batch size is at least one.
        max_batch_size: The upper bound for the search. If the model does not run out of memory at this
            batch size, it is returned as is (without applying the safety factor) and a warning is
            issued, since the true maximum may be larger.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)`.

    Returns:
        The maximum batch size.

    Raises:
        ValueError: If 'safety_factor' is not in the interval (0, 1].
    """
    # A factor above one would return a batch size that is known not to fit; reject it early,
    # before any GPU work is done. The negated chained comparison also rejects NaN.
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in the interval (0, 1], got {safety_factor}.")

    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    # Search in eval mode and restore the previous mode afterwards, even if the search fails.
    was_training = model.training
    model.eval()
    try:
        def fits(batch_size: int) -> bool:
            return _attempt_forward(
                model, device, dtype, in_channels, batch_size, patch_shape, prediction_function
            )

        max_fitting = _search_max_int(fits, max_batch_size)
    finally:
        model.train(was_training)

    if max_fitting >= max_batch_size:
        warnings.warn(
            f"The batch size search reached the upper bound 'max_batch_size'={max_batch_size} without running "
            "out of memory. The true maximum may be larger; increase 'max_batch_size' to search further.",
            stacklevel=2,  # Attribute the warning to the caller rather than to this module.
        )
        return max_fitting

    return max(1, int(max_fitting * safety_factor))
Empirically determine the largest batch size that fits into GPU memory for a fixed patch shape.
The batch size is increased (forward pass with dummy data, exponential bracketing followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running out of memory is used as the termination signal.
Arguments:
- model: The model.
- patch_shape: The spatial shape of a single sample, without batch or channel axis, e.g. (512, 512) for 2D or (64, 128, 128) for 3D.
- in_channels: The number of input channels. By default this is derived from the model's 'in_channels' attribute.
- device: The device of the model. If not given, will be derived from the model parameters. Must be a CUDA device.
- dtype: The data type of the dummy input data.
- safety_factor: Factor in (0, 1] applied to the empirically determined maximum, to stay clear of the out-of-memory boundary. The returned batch size is at least one.
- max_batch_size: The upper bound for the search. If the model does not run out of memory at this batch size, it is returned as is (without applying safety_factor) and a warning is issued, since the true maximum may be larger.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:
The maximum batch size.
def compute_max_patch_shape(
    model: torch.nn.Module,
    ndim: int,
    batch_size: int = 1,
    min_divisible: Optional[Tuple[int, ...]] = None,
    in_channels: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    dtype: torch.dtype = torch.float32,
    safety_factor: float = 0.9,
    max_scale_factor: int = 128,
    prediction_function: Optional[Callable[[torch.nn.Module, torch.Tensor], Any]] = None,
) -> Tuple[int, ...]:
    """Empirically determine the largest patch shape that fits into GPU memory for a fixed batch size.

    The patch shape is grown isotropically as integer multiples of `min_divisible`, i.e. the candidate
    shapes are `(k * d0, k * d1, ...)` for increasing `k`, so that the network's divisibility constraints
    are always satisfied. The multiplier is increased (forward pass with dummy data, exponential bracketing
    followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running
    out of memory is used as the termination signal.

    Args:
        model: The model.
        ndim: The number of spatial dimensions, i.e. 2 for a 2D and 3 for a 3D model.
        batch_size: The (fixed) batch size to use for the search.
        min_divisible: The factors each spatial axis must be divisible by, which also define the smallest
            patch shape and the increment of the search. By default this is derived from the model's 'depth'
            attribute as (2 ** depth,) * ndim (the constraint for a U-Net), falling back to (1,) * ndim.
        in_channels: The number of input channels. By default this is derived from the model's
            'in_channels' attribute.
        device: The device of the model. If not given, will be derived from the model parameters.
            Must be a CUDA device.
        dtype: The data type of the dummy input data.
        safety_factor: Factor in (0, 1] applied to the empirically determined maximum multiplier, to stay
            clear of the out-of-memory boundary. The returned patch shape is at least 'min_divisible'.
        max_scale_factor: The upper bound for the multiplier search. If the model does not run out of memory
            at this multiplier, the corresponding patch shape is returned as is (without applying the
            safety factor) and a warning is issued, since the true maximum may be larger.
        prediction_function: A wrapper function for prediction to enable custom prediction procedures.
            It is called as `prediction_function(model, input)`.

    Returns:
        The maximum patch shape, as a tuple of length 'ndim'.

    Raises:
        ValueError: If 'safety_factor' is not in the interval (0, 1].
    """
    # A factor above one would return a patch shape that is known not to fit; reject it early,
    # before any GPU work is done. The negated chained comparison also rejects NaN.
    if not 0 < safety_factor <= 1:
        raise ValueError(f"safety_factor must be in the interval (0, 1], got {safety_factor}.")

    min_divisible = _resolve_min_divisible(model, ndim, min_divisible)
    device = _resolve_device(model, device)
    in_channels = _resolve_in_channels(model, in_channels)

    def scale(k: int) -> Tuple[int, ...]:
        # Map the scalar multiplier to the corresponding patch shape.
        return tuple(k * d for d in min_divisible)

    # Search in eval mode and restore the previous mode afterwards, even if the search fails.
    was_training = model.training
    model.eval()
    try:
        def fits(k: int) -> bool:
            return _attempt_forward(
                model, device, dtype, in_channels, batch_size, scale(k), prediction_function
            )

        max_fitting = _search_max_int(fits, max_scale_factor)
    finally:
        model.train(was_training)

    if max_fitting >= max_scale_factor:
        warnings.warn(
            f"The patch shape search reached the upper bound 'max_scale_factor'={max_scale_factor} without "
            "running out of memory. The true maximum may be larger; increase 'max_scale_factor' to search further.",
            stacklevel=2,  # Attribute the warning to the caller rather than to this module.
        )
        return scale(max_fitting)

    return scale(max(1, int(max_fitting * safety_factor)))
Empirically determine the largest patch shape that fits into GPU memory for a fixed batch size.
The patch shape is grown isotropically as integer multiples of min_divisible, i.e. the candidate
shapes are (k * d0, k * d1, ...) for increasing k, so that the network's divisibility constraints
are always satisfied. The multiplier is increased (forward pass with dummy data, exponential bracketing
followed by binary search) until the GPU runs out of memory. This requires a CUDA device, since running
out of memory is used as the termination signal.
Arguments:
- model: The model.
- ndim: The number of spatial dimensions, i.e. 2 for a 2D and 3 for a 3D model.
- batch_size: The (fixed) batch size to use for the search.
- min_divisible: The factors each spatial axis must be divisible by, which also define the smallest patch shape and the increment of the search. By default this is derived from the model's 'depth' attribute as (2 ** depth,) * ndim (the constraint for a U-Net), falling back to (1,) * ndim.
- in_channels: The number of input channels. By default this is derived from the model's 'in_channels' attribute.
- device: The device of the model. If not given, will be derived from the model parameters. Must be a CUDA device.
- dtype: The data type of the dummy input data.
- safety_factor: Factor in (0, 1] applied to the empirically determined maximum multiplier, to stay clear of the out-of-memory boundary. The returned patch shape is at least 'min_divisible'.
- max_scale_factor: The upper bound for the multiplier search. If the model does not run out of memory at this multiplier, the corresponding patch shape is returned as is (without applying safety_factor) and a warning is issued, since the true maximum may be larger.
- prediction_function: A wrapper function for prediction to enable custom prediction procedures.
Returns:
The maximum patch shape, as a tuple of length 'ndim'.