torch_em.util.util
import os
import warnings
from collections import OrderedDict
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch_em
from matplotlib import colors
from numpy.typing import ArrayLike

try:
    from torch._dynamo.eval_frame import OptimizedModule
except ImportError:
    OptimizedModule = None

# torch doesn't support most unsigned types,
# so we map them to their signed equivalent
DTYPE_MAP = {
    np.dtype("uint16"): np.int16,
    np.dtype("uint32"): np.int32,
    np.dtype("uint64"): np.int64
}
"""@private
"""


# This is a fairly brittle way to check if a module is compiled.
# Would be good to find a better solution, ideally something like model.is_compiled().
def is_compiled(model):
    """@private
    """
    if OptimizedModule is None:
        return False
    return isinstance(model, OptimizedModule)


def auto_compile(
    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
) -> torch.nn.Module:
    """Automatically compile a model for pytorch >= 2.

    Args:
        model: The model.
        compile_model: Whether to compile the model.
            If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior
            specified by 'default_compile' will be used. If a string is given it will be
            interpreted as the compile mode (torch.compile(model, mode=compile_model)).
        default_compile: Whether to use the default compilation behavior for torch 2.

    Returns:
        The compiled model.
    """
    torch_major = int(torch.__version__.split(".")[0])

    if compile_model is None:
        if torch_major < 2:
            compile_model = False
        elif is_compiled(model):  # model is already compiled
            compile_model = False
        else:
            compile_model = default_compile

    if compile_model:
        if torch_major < 2:
            raise RuntimeError("Model compilation is only supported for pytorch 2")

        print("Compiling pytorch model ...")
        if isinstance(compile_model, str):
            model = torch.compile(model, mode=compile_model)
        else:
            model = torch.compile(model)

    return model


def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
    """Ensure that the input is a torch tensor, by converting it if necessary.

    Args:
        tensor: The input object, either a torch tensor or a numpy-array like object.
        dtype: The required data type for the output tensor.

    Returns:
        The input, converted to a torch tensor if necessary.
    """
    if isinstance(tensor, np.ndarray):
        if np.dtype(tensor.dtype) in DTYPE_MAP:
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        # Try to convert the tensor, even if it has wrong byte-order
        try:
            tensor = torch.from_numpy(tensor if tensor.flags.writeable else tensor.copy())
        except ValueError:
            tensor = tensor.view(tensor.dtype.newbyteorder())
            if np.dtype(tensor.dtype) in DTYPE_MAP:
                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
            tensor = torch.from_numpy(tensor)

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    return tensor


def validate_roi(roi, shape, patch_shape=None):
    """Normalize an ROI to explicit slices and validate that it is non-empty."""
    if roi is None:
        return None
    if isinstance(roi, slice):
        roi = (roi,)
    if not isinstance(roi, tuple):
        raise TypeError(f"Invalid roi type: {type(roi)}")
    if len(roi) > len(shape):
        raise ValueError(f"Invalid roi {roi} for data shape {shape}: too many dimensions")

    normalized_roi = []
    for this_roi, dim in zip(roi, shape):
        if not isinstance(this_roi, slice):
            raise TypeError(f"Invalid roi entry: {this_roi}. Only slices are supported")
        step = 1 if this_roi.step is None else this_roi.step
        if step != 1:
            raise ValueError(f"Invalid roi {roi}: slice steps other than 1 are not supported")
        start, stop, _ = this_roi.indices(dim)
        normalized_roi.append(slice(start, stop))

    # Dimensions not covered by the roi default to their full extent.
    if len(roi) < len(shape):
        normalized_roi.extend(slice(0, dim) for dim in shape[len(roi):])

    roi_shape = tuple(sl.stop - sl.start for sl in normalized_roi)
    if any(sh <= 0 for sh in roi_shape):
        msg = f"Invalid roi {roi} for data shape {shape}: it results in an empty region"
        if patch_shape is not None:
            msg += f" for patch_shape {patch_shape}"
        raise ValueError(msg)

    return tuple(normalized_roi)


def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Ensure that the input is a torch tensor of a given dimensionality with channels.

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor.
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    if ndim == 2:
        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 2:
            tensor = tensor[None]
        elif tensor.ndim == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
        elif tensor.ndim == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            tensor = tensor[0, 0]
    elif ndim == 3:
        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 3:
            tensor = tensor[None]
        elif tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    else:
        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
        if tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    return tensor


def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array, by converting it if necessary.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array.

    Returns:
        The input converted to a numpy array if necessary.
    """
    if torch.is_tensor(array):
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    if dtype is not None:
        array = np.require(array, dtype=dtype)
    return array


def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality.
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    if ndim == 2:
        assert array.ndim in (2, 3, 4, 5), str(array.ndim)
        if array.ndim == 3:
            assert array.shape[0] == 1
            array = array[0]
        elif array.ndim == 4:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
        elif array.ndim == 5:
            assert array.shape[:3] == (1, 1, 1)
            array = array[0, 0, 0]
    else:
        assert array.ndim in (3, 4, 5), str(array.ndim)
        if array.ndim == 4:
            assert array.shape[0] == 1, f"{array.shape}"
            array = array[0]
        elif array.ndim == 5:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
    return array


def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels do not have the patch shape they will be padded.

    Args:
        raw: The input raw data.
        labels: The input labels.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The raw data.
        The labels.
    """
    raw_shape = raw.shape
    if labels is not None:
        labels_shape = labels.shape

    # In case the input has channels and they are channels first
    # IMPORTANT: for ImageCollectionDataset
    if have_raw_channels and channel_first:
        raw_shape = raw_shape[1:]

    if have_label_channels and channel_first and labels is not None:
        labels_shape = labels_shape[1:]

    # Extract the pad_width and pad the raw inputs
    if any(sh < psh for sh, psh in zip(raw_shape, patch_shape)):
        pw = [(0, max(0, psh - sh)) for sh, psh in zip(raw_shape, patch_shape)]

        if have_raw_channels and channel_first:
            pad_width = [(0, 0), *pw]
        elif have_raw_channels and not channel_first:
            pad_width = [*pw, (0, 0)]
        else:
            pad_width = pw

        raw = np.pad(array=raw, pad_width=pad_width)

    # Extract the pad width and pad the label inputs
    if labels is not None and any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
        pw = [(0, max(0, psh - sh)) for sh, psh in zip(labels_shape, patch_shape)]

        if have_label_channels and channel_first:
            pad_width = [(0, 0), *pw]
        elif have_label_channels and not channel_first:
            pad_width = [*pw, (0, 0)]
        else:
            pad_width = pw

        labels = np.pad(array=labels, pad_width=pad_width)

    if labels is None:
        return raw
    else:
        return raw, labels


def get_constructor_arguments(obj):
    """@private
    """
    # All relevant torch_em classes have 'init_kwargs' to directly recover the init call.
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    def _get_args(obj, param_names):
        return {name: getattr(obj, name) for name in param_names}

    # We don't need to find the constructor arguments for optimizers/schedulers because we deserialize the state later.
    if isinstance(
        obj, (
            torch.optim.Optimizer,
            torch.optim.lr_scheduler._LRScheduler,
            # ReduceLROnPlateau does not inherit from _LRScheduler
            torch.optim.lr_scheduler.ReduceLROnPlateau
        )
    ):
        return {}

    # recover the arguments for torch dataloader
    elif isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em
        sampler = getattr(obj, "sampler", None)
        if sampler is not None and not isinstance(
            sampler,
            (
                torch.utils.data.RandomSampler,
                torch.utils.data.SequentialSampler,
                torch.utils.data.SubsetRandomSampler,
            ),
        ):
            warnings.warn(
                f"DataLoader uses sampler {type(sampler).__name__}, but only its effective `shuffle` setting "
                "is serialized. `DefaultTrainer.from_checkpoint` will recreate the loader without the original "
                "sampler, so sampling behavior may change."
            )
        shuffle = getattr(obj, "shuffle", None)
        if shuffle is None:
            shuffle = getattr(sampler, "shuffle", None)
        if shuffle is None:
            # Only randomized samplers map to shuffle=True. SequentialSampler is handled
            # by the default fallback of shuffle=False and does not need a special case.
            shuffle = isinstance(sampler, (torch.utils.data.RandomSampler, torch.utils.data.SubsetRandomSampler))

        return {
            **_get_args(
                obj, [
                    "batch_size", "num_workers", "pin_memory", "drop_last",
                    "persistent_workers", "prefetch_factor", "timeout"
                ]
            ),
            "shuffle": shuffle,
        }

    # TODO support common torch losses (e.g. CrossEntropy, BCE)
    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
        "For this object, empty constructor arguments will be used.\n" +
        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}


def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint.
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    # try to load from file
    if isinstance(checkpoint, str):
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    else:
        trainer = checkpoint
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer


def get_normalizer(trainer):
    """@private
    """
    dataset = trainer.train_loader.dataset
    # Unwrap (possibly nested) concat datasets to get at the first real dataset.
    while (
        isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset) or
        isinstance(dataset, torch.utils.data.dataset.ConcatDataset)
    ):
        dataset = dataset.datasets[0]

    if isinstance(dataset, torch.utils.data.dataset.Subset):
        dataset = dataset.dataset

    preprocessor = dataset.raw_transform

    if hasattr(preprocessor, "normalizer"):
        return preprocessor.normalizer
    else:
        return preprocessor


def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: Optional[str] = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load a model from a trainer checkpoint or a serialized torch model.

    This function can either load the model directly (`model` is not passed),
    or deserialize the model state and then load it (`model` is passed).

    The `checkpoint` argument must either point to the checkpoint directory of a torch_em trainer
    or to a serialized torch model.

    Args:
        checkpoint: The path to the checkpoint folder or serialized torch model.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint.
        state_key: The name of the model state to load. Set to None if the model state is stored top-level.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None and os.path.isdir(checkpoint):  # Load the model and its state from a torch_em checkpoint.
        model = get_trainer(checkpoint, name=name, device=device).model

    elif model is None:  # Load the model from a serialized model.
        model = torch.load(checkpoint, map_location=device, weights_only=False)

    else:  # Load the model state from a checkpoint.
        if os.path.isdir(checkpoint):  # From a torch_em checkpoint.
            ckpt = os.path.join(checkpoint, f"{name}.pt")
        else:  # From a serialized path.
            ckpt = checkpoint

        state = torch.load(ckpt, map_location=device, weights_only=False)
        if state_key is not None:
            state = state[state_key]

        # To enable loading compiled models.
        compiled_prefix = "_orig_mod."
        state = OrderedDict(
            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()]
        )

        model.load_state_dict(state)

    if device is not None:
        model.to(device)

    return model


def model_is_equal(model1, model2):
    """@private
    """
    # Models are equal iff all corresponding parameter tensors are equal.
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True


def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    unique_labels = np.unique(labels)
    have_zero = 0 in unique_labels
    # Background (label 0), if present, is always mapped to black.
    cmap = [[0, 0, 0]] if have_zero else []
    cmap += np.random.rand(len(unique_labels), 3).tolist()
    cmap = colors.ListedColormap(cmap)
    return cmap
def auto_compile(
    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
) -> torch.nn.Module:
    """Automatically compile a model for pytorch >= 2.

    Args:
        model: The model.
        compile_model: Whether to compile the model.
            If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior
            specified by 'default_compile' will be used. If a string is given it will be
            interpreted as the compile mode (torch.compile(model, mode=compile_model)).
        default_compile: Whether to use the default compilation behavior for torch 2.

    Returns:
        The compiled model.
    """
    major_version = int(torch.__version__.split(".")[0])

    if compile_model is None:
        # Decide automatically: never compile on torch 1, never re-compile an
        # already compiled model, otherwise follow `default_compile`.
        if major_version < 2 or is_compiled(model):
            compile_model = False
        else:
            compile_model = default_compile

    if not compile_model:
        return model

    if major_version < 2:
        raise RuntimeError("Model compilation is only supported for pytorch 2")

    print("Compiling pytorch model ...")
    # A string value selects the compilation mode; otherwise use the default mode.
    mode = compile_model if isinstance(compile_model, str) else None
    return torch.compile(model, mode=mode)
Automatically compile a model for pytorch >= 2.
Arguments:
- model: The model.
- compile_model: Whether to compile the model. If None, it will not be compiled for torch < 2, and for torch > 2 the behavior specified by 'default_compile' will be used. If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model))
- default_compile: Whether to use the default compilation behavior for torch 2.
Returns:
The compiled model.
def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
    """Ensure that the input is a torch tensor, by converting it if necessary.

    Args:
        tensor: The input object, either a torch tensor or a numpy-array like object.
        dtype: The required data type for the output tensor.

    Returns:
        The input, converted to a torch tensor if necessary.
    """
    if isinstance(tensor, np.ndarray):
        if np.dtype(tensor.dtype) in DTYPE_MAP:
            # torch does not support these unsigned dtypes; map to the signed equivalent.
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        try:
            # from_numpy requires a writeable array, so copy read-only inputs.
            source = tensor if tensor.flags.writeable else tensor.copy()
            tensor = torch.from_numpy(source)
        except ValueError:
            # The array has a non-native byte order: reinterpret it with the
            # swapped order, re-apply the dtype mapping, and convert again.
            swapped = tensor.view(tensor.dtype.newbyteorder())
            if np.dtype(swapped.dtype) in DTYPE_MAP:
                swapped = swapped.astype(DTYPE_MAP[swapped.dtype])
            tensor = torch.from_numpy(swapped)

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    return tensor if dtype is None else tensor.to(dtype=dtype)
Ensure that the input is a torch tensor, by converting it if necessary.
Arguments:
- tensor: The input object, either a torch tensor or a numpy-array like object.
- dtype: The required data type for the output tensor.
Returns:
The input, converted to a torch tensor if necessary.
def validate_roi(roi, shape, patch_shape=None):
    """Normalize an ROI to explicit slices and validate that it is non-empty.

    Args:
        roi: The region of interest: a slice, a tuple of slices, or None.
        shape: The shape of the underlying data.
        patch_shape: Optional patch shape, used only to enrich the error message.

    Returns:
        A tuple of slices with explicit start/stop covering every dimension of
        `shape`, or None if `roi` is None.
    """
    if roi is None:
        return None

    roi = (roi,) if isinstance(roi, slice) else roi
    if not isinstance(roi, tuple):
        raise TypeError(f"Invalid roi type: {type(roi)}")
    if len(roi) > len(shape):
        raise ValueError(f"Invalid roi {roi} for data shape {shape}: too many dimensions")

    normalized = []
    for dim_roi, extent in zip(roi, shape):
        if not isinstance(dim_roi, slice):
            raise TypeError(f"Invalid roi entry: {dim_roi}. Only slices are supported")
        step = dim_roi.step if dim_roi.step is not None else 1
        if step != 1:
            raise ValueError(f"Invalid roi {roi}: slice steps other than 1 are not supported")
        # slice.indices resolves None and negative bounds against the extent.
        begin, end, _ = dim_roi.indices(extent)
        normalized.append(slice(begin, end))

    # Dimensions the roi does not cover default to their full extent.
    normalized += [slice(0, extent) for extent in shape[len(roi):]]

    if any(sl.stop - sl.start <= 0 for sl in normalized):
        message = f"Invalid roi {roi} for data shape {shape}: it results in an empty region"
        if patch_shape is not None:
            message += f" for patch_shape {patch_shape}"
        raise ValueError(message)

    return tuple(normalized)
Normalize an ROI to explicit slices and validate that it is non-empty.
def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Ensure that the input is a torch tensor of a given dimensionality with channels.

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor.
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)

    if ndim == 2:
        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 2:
            tensor = tensor[None]  # add a singleton channel axis
        elif tensor.ndim == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]  # drop one singleton leading axis
        elif tensor.ndim == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            tensor = tensor[0, 0]  # drop two singleton leading axes
        return tensor

    if ndim == 3:
        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 3:
            tensor = tensor[None]
        elif tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
        return tensor

    # ndim == 4: the output stays 4d, only a singleton leading axis may be dropped.
    assert tensor.ndim in (4, 5), f"{tensor.ndim}"
    if tensor.ndim == 5:
        assert tensor.shape[0] == 1, f"{tensor.shape}"
        tensor = tensor[0]
    return tensor
Ensure that the input is a torch tensor of a given dimensionality with channels.
Arguments:
- tensor: The input tensor or numpy-array like data.
- ndim: The dimensionality of the output tensor.
- dtype: The data type of the output tensor.
Returns:
The input converted to a torch tensor of the requested dimensionality.
def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array, by converting it if necessary.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array.

    Returns:
        The input converted to a numpy array if necessary.
    """
    if torch.is_tensor(array):
        # Detach from the autograd graph and move to host memory before converting.
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    return array if dtype is None else np.require(array, dtype=dtype)
Ensure that the input is a numpy array, by converting it if necessary.
Arguments:
- array: The input torch tensor or numpy array.
- dtype: The dtype of the output array.
Returns:
The input converted to a numpy array if necessary.
def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality.
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)

    if ndim == 2:
        assert array.ndim in (2, 3, 4, 5), str(array.ndim)
        # Drop leading singleton axes until the array is 2d.
        if array.ndim == 3:
            assert array.shape[0] == 1
            array = array[0]
        elif array.ndim == 4:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
        elif array.ndim == 5:
            assert array.shape[:3] == (1, 1, 1)
            array = array[0, 0, 0]
        return array

    # ndim == 3: drop leading singleton axes until the array is 3d.
    assert array.ndim in (3, 4, 5), str(array.ndim)
    if array.ndim == 4:
        assert array.shape[0] == 1, f"{array.shape}"
        array = array[0]
    elif array.ndim == 5:
        assert array.shape[:2] == (1, 1)
        array = array[0, 0]
    return array
Ensure that the input is a numpy array of a given dimensionality.
Arguments:
- array: The input numpy array or torch tensor.
- ndim: The requested dimensionality.
- dtype: The dtype of the output array.
Returns:
A numpy array of the requested dimensionality and data type.
def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels do not have the patch shape they will be padded.

    Args:
        raw: The input raw data.
        labels: The input labels.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The raw data, or the raw data and the labels if labels were passed.
    """

    def _pad_to_patch(data, have_channels):
        # The spatial shape excludes the leading channel axis for channel-first data.
        spatial_shape = data.shape[1:] if (have_channels and channel_first) else data.shape
        if all(sh >= psh for sh, psh in zip(spatial_shape, patch_shape)):
            return data

        # Pad only at the end of each spatial dimension, never the channel axis.
        spatial_pad = [(0, max(0, psh - sh)) for sh, psh in zip(spatial_shape, patch_shape)]
        if have_channels and channel_first:
            pad_width = [(0, 0)] + spatial_pad
        elif have_channels:
            pad_width = spatial_pad + [(0, 0)]
        else:
            pad_width = spatial_pad
        return np.pad(array=data, pad_width=pad_width)

    raw = _pad_to_patch(raw, have_raw_channels)
    if labels is None:
        return raw
    return raw, _pad_to_patch(labels, have_label_channels)
Ensure that the raw data and labels have at least the requested patch shape.
If either raw data or labels do not have the patch shape they will be padded.
Arguments:
- raw: The input raw data.
- labels: The input labels.
- patch_shape: The required minimal patch shape.
- have_raw_channels: Whether the raw data has channels.
- have_label_channels: Whether the label data has channels.
- channel_first: Whether the channel axis is the first or last axis.
Returns:
The raw data. The labels.
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint.
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    if isinstance(checkpoint, str):
        # Deserialize the trainer from the given checkpoint path.
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    else:
        # A trainer object was passed through directly.
        trainer = checkpoint
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer
Load trainer from a checkpoint.
Arguments:
- checkpoint: The path to the checkpoint.
- name: The name of the checkpoint.
- device: The device to use for loading the checkpoint.
Returns:
The trainer.
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: Optional[str] = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load a model from a trainer checkpoint or a serialized torch model.

    This function can either load the model directly (`model` is not passed),
    or deserialize the model state and then load it (`model` is passed).

    The `checkpoint` argument must either point to the checkpoint directory of a torch_em trainer
    or to a serialized torch model.

    Args:
        checkpoint: The path to the checkpoint folder or serialized torch model.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint.
        state_key: The name of the model state to load. Set to None if the model state is stored top-level.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None:
        if os.path.isdir(checkpoint):
            # Load both the model and its state from a torch_em checkpoint.
            model = get_trainer(checkpoint, name=name, device=device).model
        else:
            # Load a fully serialized model.
            model = torch.load(checkpoint, map_location=device, weights_only=False)
    else:
        # Only load the model state, either from a torch_em checkpoint directory
        # or from a directly serialized state file.
        ckpt_path = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint

        state = torch.load(ckpt_path, map_location=device, weights_only=False)
        if state_key is not None:
            state = state[state_key]

        # Strip the prefix that torch.compile adds, to support loading compiled models.
        compiled_prefix = "_orig_mod."
        state = OrderedDict(
            (k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()
        )
        model.load_state_dict(state)

    if device is not None:
        model.to(device)

    return model
Load a model from a trainer checkpoint or a serialized torch model.
This function can either load the model directly (model is not passed),
or deserialize the model state and then load it (model is passed).
The checkpoint argument must either point to the checkpoint directory of a torch_em trainer
or to a serialized torch model.
Arguments:
- checkpoint: The path to the checkpoint folder or serialized torch model.
- model: The model for which the state should be loaded. If it is not passed, the model class and parameters will also be loaded from the trainer.
- name: The name of the checkpoint.
- state_key: The name of the model state to load. Set to None if the model state is stored top-level.
- device: The device on which to load the model.
Returns:
The model.
def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    label_ids = np.unique(labels)
    # Background (label 0), if present, is always mapped to black.
    color_list = [[0, 0, 0]] if 0 in label_ids else []
    color_list += np.random.rand(len(label_ids), 3).tolist()
    return colors.ListedColormap(color_list)
Generate a random color map for a label image.
Arguments:
- labels: The labels.
Returns:
The color map.