torch_em.util.util
import os
import warnings
from collections import OrderedDict
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch_em
from matplotlib import colors
from numpy.typing import ArrayLike

try:
    from torch._dynamo.eval_frame import OptimizedModule
except ImportError:
    OptimizedModule = None

# torch doesn't support most unsigned types,
# so we map them to their signed equivalent.
# NOTE: this is a plain numpy cast, so values above the signed maximum (e.g. uint16 > 32767) wrap around.
DTYPE_MAP = {
    np.dtype("uint16"): np.int16,
    np.dtype("uint32"): np.int32,
    np.dtype("uint64"): np.int64
}
"""@private
"""


# This is a fairly brittle way to check if a module is compiled.
# Would be good to find a better solution, ideally something like model.is_compiled().
def is_compiled(model):
    """@private
    """
    if OptimizedModule is None:
        return False
    return isinstance(model, OptimizedModule)


def auto_compile(
    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
) -> torch.nn.Module:
    """Automatically compile a model for pytorch >= 2.

    Args:
        model: The model.
        compile_model: Whether to compile the model.
            If None, it will not be compiled for torch < 2 or if it is already compiled, and for torch >= 2
            the behavior specified by 'default_compile' will be used. If a string is given it will be
            interpreted as the compile mode (torch.compile(model, mode=compile_model)).
        default_compile: Whether to compile by default (i.e. if compile_model is None) for torch >= 2.

    Returns:
        The compiled model, or the input model if it was not compiled.

    Raises:
        RuntimeError: If compilation is requested for torch < 2.
    """
    torch_major = int(torch.__version__.split(".")[0])

    if compile_model is None:
        if torch_major < 2:
            compile_model = False
        elif is_compiled(model):  # model is already compiled
            compile_model = False
        else:
            compile_model = default_compile

    if compile_model:
        if torch_major < 2:
            raise RuntimeError("Model compilation is only supported for pytorch 2")

        print("Compiling pytorch model ...")
        if isinstance(compile_model, str):
            model = torch.compile(model, mode=compile_model)
        else:
            model = torch.compile(model)

    return model


def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
    """Ensure that the input is a torch tensor, by converting it if necessary.

    Unsigned numpy dtypes listed in DTYPE_MAP are cast to their signed counterpart,
    and numpy arrays in non-native byte order are converted to native byte order.

    Args:
        tensor: The input object, either a torch tensor or a numpy-array like object.
        dtype: The required data type for the output tensor.

    Returns:
        The input, converted to a torch tensor if necessary.
    """
    if isinstance(tensor, np.ndarray):
        if np.dtype(tensor.dtype) in DTYPE_MAP:
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        try:
            tensor = torch.from_numpy(tensor)
        except ValueError:
            # torch.from_numpy raises a ValueError e.g. for arrays with non-native byte order.
            # Convert the values to native byte order with astype; a .view with swapped byte order
            # would only reinterpret the raw bytes and silently corrupt the values.
            tensor = tensor.astype(tensor.dtype.newbyteorder("="))
            if np.dtype(tensor.dtype) in DTYPE_MAP:
                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
            tensor = torch.from_numpy(tensor)

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    return tensor


def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Ensure that the input is a torch tensor of a given dimensionality with channels.

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor.
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    if ndim == 2:
        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 2:
            tensor = tensor[None]
        elif tensor.ndim == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
        elif tensor.ndim == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            tensor = tensor[0, 0]
    elif ndim == 3:
        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 3:
            tensor = tensor[None]
        elif tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    else:
        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
        if tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    return tensor


def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array, by converting it if necessary.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array.

    Returns:
        The input converted to a numpy array if necessary.
    """
    if torch.is_tensor(array):
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    if dtype is not None:
        array = np.require(array, dtype=dtype)
    return array


def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Leading axes of size one are removed until the array has `ndim` dimensions.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality (2 or 3).
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3), f"{ndim}"
    array = ensure_array(array, dtype)
    assert ndim <= array.ndim <= 5, str(array.ndim)
    n_leading = array.ndim - ndim
    if n_leading > 0:
        # Only singleton axes may be dropped, otherwise data would be lost.
        assert array.shape[:n_leading] == (1,) * n_leading, f"{array.shape}"
        array = array[(0,) * n_leading]
    return array


def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels do not have the patch shape they will be zero-padded at the end of each axis.

    Args:
        raw: The input raw data.
        labels: The input labels.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The (padded) raw data.
        The (padded) labels. Only returned if labels is not None, otherwise only the raw data is returned.
    """
    raw_shape = raw.shape
    if labels is not None:
        labels_shape = labels.shape

    # In case the inputs has channels and they are channels first
    # IMPORTANT: for ImageCollectionDataset
    if have_raw_channels and channel_first:
        raw_shape = raw_shape[1:]

    if have_label_channels and channel_first and labels is not None:
        labels_shape = labels_shape[1:]

    # Extract the pad_width and pad the raw inputs
    if any(sh < psh for sh, psh in zip(raw_shape, patch_shape)):
        pw = [(0, max(0, psh - sh)) for sh, psh in zip(raw_shape, patch_shape)]

        if have_raw_channels and channel_first:
            pad_width = [(0, 0), *pw]
        elif have_raw_channels and not channel_first:
            pad_width = [*pw, (0, 0)]
        else:
            pad_width = pw

        raw = np.pad(array=raw, pad_width=pad_width)

    # Extract the pad width and pad the label inputs
    if labels is not None and any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
        pw = [(0, max(0, psh - sh)) for sh, psh in zip(labels_shape, patch_shape)]

        if have_label_channels and channel_first:
            pad_width = [(0, 0), *pw]
        elif have_label_channels and not channel_first:
            pad_width = [*pw, (0, 0)]
        else:
            pad_width = pw

        labels = np.pad(array=labels, pad_width=pad_width)
    if labels is None:
        return raw
    else:
        return raw, labels


def get_constructor_arguments(obj):
    """@private
    """
    # All relevant torch_em classes have 'init_kwargs' to directly recover the init call.
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    def _get_args(obj, param_names):
        return {name: getattr(obj, name) for name in param_names}

    # We don't need to find the constructor arguments for optimizers/schedulers because we deserialize the state later.
    if isinstance(
        obj, (
            torch.optim.Optimizer,
            torch.optim.lr_scheduler._LRScheduler,
            # ReduceLROnPlateau does not inherit from _LRScheduler
            torch.optim.lr_scheduler.ReduceLROnPlateau
        )
    ):
        return {}

    # recover the arguments for torch dataloader
    elif isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em
        return _get_args(
            obj, [
                "batch_size", "shuffle", "num_workers", "pin_memory", "drop_last",
                "persistent_workers", "prefetch_factor", "timeout"
            ]
        )

    # TODO support common torch losses (e.g. CrossEntropy, BCE)
    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
        "For this object, empty constructor arguments will be used.\n" +
        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}


def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint, or an already loaded trainer (which is returned as is).
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    # try to load from file
    if isinstance(checkpoint, str):
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    else:
        trainer = checkpoint
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer


def get_normalizer(trainer):
    """@private
    """
    dataset = trainer.train_loader.dataset
    while (
        isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset) or
        isinstance(dataset, torch.utils.data.dataset.ConcatDataset)
    ):
        dataset = dataset.datasets[0]

    if isinstance(dataset, torch.utils.data.dataset.Subset):
        dataset = dataset.dataset

    preprocessor = dataset.raw_transform

    if hasattr(preprocessor, "normalizer"):
        return preprocessor.normalizer
    else:
        return preprocessor


def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: str = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Args:
        checkpoint: The path to the checkpoint folder. If a model is passed, this may also be
            the path to the checkpoint file itself.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint.
        state_key: The name of the model state to load.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None:  # load the model and its state from the checkpoint
        return get_trainer(checkpoint, name=name, device=device).model

    # Load only the model state from the checkpoint.
    ckpt = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint
    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints from trusted sources.
    state = torch.load(ckpt, map_location=device, weights_only=False)[state_key]

    # Models compiled with torch.compile prefix their parameter names with "_orig_mod.";
    # strip the prefix so that the state can be loaded into an uncompiled model.
    compiled_prefix = "_orig_mod."
    state = OrderedDict(
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key, value)
        for key, value in state.items()
    )
    model.load_state_dict(state)
    if device is not None:
        model.to(device)

    return model


def model_is_equal(model1, model2):
    """@private
    """
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True


def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    Args:
        labels: The labels.

    Returns:
        The color map. If the labels contain 0, its first color is black.
    """
    unique_labels = np.unique(labels)
    have_zero = 0 in unique_labels
    cmap = [[0, 0, 0]] if have_zero else []
    cmap += np.random.rand(len(unique_labels), 3).tolist()
    cmap = colors.ListedColormap(cmap)
    return cmap
def auto_compile(
    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
) -> torch.nn.Module:
    """Automatically compile a model for pytorch >= 2.

    Args:
        model: The model.
        compile_model: Whether to compile the model. If None, the model is not compiled for torch < 2
            or if it is already compiled; otherwise 'default_compile' decides. If a string is given,
            it is passed on as the compile mode (torch.compile(model, mode=compile_model)).
        default_compile: Whether to compile by default (i.e. if compile_model is None) for torch >= 2.

    Returns:
        The compiled model, or the input model if it was not compiled.

    Raises:
        RuntimeError: If compilation is requested for torch < 2.
    """
    major_version = int(torch.__version__.split(".")[0])
    supports_compile = major_version >= 2

    if compile_model is None:
        # Decide automatically: never for old torch versions or for models that are already compiled.
        compile_model = default_compile if (supports_compile and not is_compiled(model)) else False

    if not compile_model:
        return model
    if not supports_compile:
        raise RuntimeError("Model compilation is only supported for pytorch 2")

    print("Compiling pytorch model ...")
    # A string selects the compile mode; any other truthy value uses the default mode.
    compile_kwargs = {"mode": compile_model} if isinstance(compile_model, str) else {}
    return torch.compile(model, **compile_kwargs)
Automatically compile a model for pytorch >= 2.
Arguments:
- model: The model.
- compile_model: Whether to compile the model. If None, it will not be compiled for torch < 2 or if it is already compiled, and for torch >= 2 the behavior specified by 'default_compile' will be used. If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model)).
- default_compile: Whether to use the default compilation behavior for torch 2.
Returns:
The compiled model.
def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
    """Ensure that the input is a torch tensor, by converting it if necessary.

    Unsigned numpy dtypes listed in DTYPE_MAP are cast to their signed counterpart,
    and numpy arrays in non-native byte order are converted to native byte order.

    Args:
        tensor: The input object, either a torch tensor or a numpy-array like object.
        dtype: The required data type for the output tensor.

    Returns:
        The input, converted to a torch tensor if necessary.
    """
    if isinstance(tensor, np.ndarray):
        if np.dtype(tensor.dtype) in DTYPE_MAP:
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        try:
            tensor = torch.from_numpy(tensor)
        except ValueError:
            # torch.from_numpy raises a ValueError e.g. for arrays with non-native byte order.
            # Convert the values to native byte order with astype; a .view with swapped byte order
            # would only reinterpret the raw bytes and silently corrupt the values.
            tensor = tensor.astype(tensor.dtype.newbyteorder("="))
            if np.dtype(tensor.dtype) in DTYPE_MAP:
                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
            tensor = torch.from_numpy(tensor)

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    return tensor
Ensure that the input is a torch tensor, by converting it if necessary.
Arguments:
- tensor: The input object, either a torch tensor or a numpy-array like object.
- dtype: The required data type for the output tensor.
Returns:
The input, converted to a torch tensor if necessary.
def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Convert the input to a torch tensor with `ndim` spatial axes plus one leading channel axis.

    Inputs without a channel axis get one prepended; surplus leading axes of size one are dropped.

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor (2, 3 or 4).
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    assert ndim <= tensor.ndim <= 5, f"{tensor.ndim}"

    channel_rank = ndim + 1
    if tensor.ndim == ndim:
        # No channel axis yet: add a singleton one.
        tensor = tensor[None]
    elif tensor.ndim > channel_rank:
        # Drop leading singleton (e.g. batch) axes.
        surplus = tensor.ndim - channel_rank
        assert tensor.shape[:surplus] == (1,) * surplus, f"{tensor.shape}"
        tensor = tensor[(0,) * surplus]
    return tensor
Ensure that the input is a torch tensor of a given dimensionality with channels.
Arguments:
- tensor: The input tensor or numpy-array like data.
- ndim: The dimensionality of the output tensor.
- dtype: The data type of the output tensor.
Returns:
The input converted to a torch tensor of the requested dimensionality.
def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Return the input as a numpy array, converting torch tensors if necessary.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array. If None, the dtype is left unchanged.

    Returns:
        The input converted to a numpy array if necessary.
    """
    # Tensors are detached and moved to the CPU before conversion.
    result = array.detach().cpu().numpy() if torch.is_tensor(array) else array
    assert isinstance(result, np.ndarray), f"Cannot convert {type(result)} to numpy"
    return result if dtype is None else np.require(result, dtype=dtype)
Ensure that the input is a numpy array, by converting it if necessary.
Arguments:
- array: The input torch tensor or numpy array.
- dtype: The dtype of the output array.
Returns:
The input converted to a numpy array if necessary.
def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Leading axes of size one are removed until the array has `ndim` dimensions.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality (2 or 3).
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3), f"{ndim}"
    array = ensure_array(array, dtype)
    assert ndim <= array.ndim <= 5, str(array.ndim)
    n_leading = array.ndim - ndim
    if n_leading > 0:
        # Only singleton axes may be dropped, otherwise data would be lost.
        assert array.shape[:n_leading] == (1,) * n_leading, f"{array.shape}"
        array = array[(0,) * n_leading]
    return array
Ensure that the input is a numpy array of a given dimensionality.
Arguments:
- array: The input numpy array or torch tensor.
- ndim: The requested dimensionality.
- dtype: The dtype of the output array.
Returns:
A numpy array of the requested dimensionality and data type.
def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Zero-pad raw data and labels at the end of each spatial axis so they are at least `patch_shape`.

    Args:
        raw: The input raw data.
        labels: The input labels.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The (padded) raw data.
        The (padded) labels. Only returned if labels is not None, otherwise only the raw data is returned.
    """
    def _pad_to_patch(data, has_channels):
        # Pad a single array; the channel axis (if any) is never padded.
        # A trailing channel axis is excluded implicitly because zip stops at len(patch_shape).
        spatial_shape = data.shape[1:] if (has_channels and channel_first) else data.shape
        if not any(size < target for size, target in zip(spatial_shape, patch_shape)):
            return data
        spatial_pad = [(0, max(0, target - size)) for size, target in zip(spatial_shape, patch_shape)]
        if not has_channels:
            return np.pad(array=data, pad_width=spatial_pad)
        channel_pad = [(0, 0)]
        full_pad = channel_pad + spatial_pad if channel_first else spatial_pad + channel_pad
        return np.pad(array=data, pad_width=full_pad)

    raw = _pad_to_patch(raw, have_raw_channels)
    if labels is None:
        return raw
    return raw, _pad_to_patch(labels, have_label_channels)
Ensure that the raw data and labels have at least the requested patch shape.
If either raw data or labels do not have the patch shape they will be padded.
Arguments:
- raw: The input raw data.
- labels: The input labels.
- patch_shape: The required minimal patch shape.
- have_raw_channels: Whether the raw data has channels.
- have_label_channels: Whether the label data has channels.
- channel_first: Whether the channel axis is the first or last axis.
Returns:
The (padded) raw data and the (padded) labels. If labels is None, only the raw data is returned.
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint, or an already loaded trainer (which is returned as is).
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    if not isinstance(checkpoint, str):
        # Anything that is not a path is expected to be a trainer instance already.
        trainer = checkpoint
    else:
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer
Load trainer from a checkpoint.
Arguments:
- checkpoint: The path to the checkpoint, or an already loaded trainer (which is returned as is).
- name: The name of the checkpoint.
- device: The device to use for loading the checkpoint.
Returns:
The trainer.
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: str = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Args:
        checkpoint: The path to the checkpoint folder. If a model is passed, this may also be
            the path to the checkpoint file itself.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint.
        state_key: The name of the model state to load.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None:  # load the model and its state from the checkpoint
        return get_trainer(checkpoint, name=name, device=device).model

    # Load only the model state from the checkpoint.
    ckpt = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint
    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints from trusted sources.
    state = torch.load(ckpt, map_location=device, weights_only=False)[state_key]

    # Models compiled with torch.compile prefix their parameter names with "_orig_mod.";
    # strip the prefix so that the state can be loaded into an uncompiled model.
    compiled_prefix = "_orig_mod."
    state = OrderedDict(
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key, value)
        for key, value in state.items()
    )
    model.load_state_dict(state)
    if device is not None:
        model.to(device)

    return model
Load model from a trainer checkpoint.
This function can either load the model directly from the trainer (model is not passed), or deserialize the model state from the trainer and load the model state (model is passed).
Arguments:
- checkpoint: The path to the checkpoint folder. If a model is passed, this may also be the path to the checkpoint file itself.
- model: The model for which the state should be loaded. If it is not passed, the model class and parameters will also be loaded from the trainer.
- name: The name of the checkpoint.
- state_key: The name of the model state to load.
- device: The device on which to load the model.
Returns:
The model.
def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Build a color map with random RGB colors for a label image.

    Args:
        labels: The labels.

    Returns:
        The color map. If the labels contain 0, its first color is black.
    """
    label_ids = np.unique(labels)
    # The background label 0, if present, is always drawn in black.
    background = [[0, 0, 0]] if 0 in label_ids else []
    random_rgb = np.random.rand(len(label_ids), 3).tolist()
    return colors.ListedColormap(background + random_rgb)
Generate a random color map for a label image.
Arguments:
- labels: The labels.
Returns:
The color map.