torch_em.util.util
import os
import warnings
from collections import OrderedDict
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch_em
from matplotlib import colors
from numpy.typing import ArrayLike

try:
    from torch._dynamo.eval_frame import OptimizedModule
except ImportError:
    OptimizedModule = None

# torch doesn't support most unsigned types,
# so we map them to their signed equivalent.
# NOTE(review): this is a same-width cast, so values beyond the signed range (e.g. uint16 values > 32767)
# wrap around to negative numbers. Confirm that callers never pass such values.
DTYPE_MAP = {
    np.dtype("uint16"): np.int16,
    np.dtype("uint32"): np.int32,
    np.dtype("uint64"): np.int64
}
"""@private
"""


# This is a fairly brittle way to check if a module is compiled.
# Would be good to find a better solution, ideally something like model.is_compiled().
def is_compiled(model):
    """@private
    """
    if OptimizedModule is None:
        return False
    return isinstance(model, OptimizedModule)


def auto_compile(
    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
) -> torch.nn.Module:
    """Automatically compile a model for pytorch >= 2.

    Args:
        model: The model.
        compile_model: Whether to compile the model.
            If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior
            specified by 'default_compile' will be used (a model that is already compiled is not compiled again).
            If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model)).
        default_compile: Whether to use the default compilation behavior for torch 2.

    Returns:
        The compiled model.

    Raises:
        RuntimeError: If compilation is requested for pytorch < 2.
    """
    torch_major = int(torch.__version__.split(".")[0])

    if compile_model is None:
        if torch_major < 2:
            compile_model = False
        elif is_compiled(model):  # model is already compiled
            compile_model = False
        else:
            compile_model = default_compile

    if compile_model:
        if torch_major < 2:
            raise RuntimeError("Model compilation is only supported for pytorch 2")

        print("Compiling pytorch model ...")
        if isinstance(compile_model, str):
            model = torch.compile(model, mode=compile_model)
        else:
            model = torch.compile(model)

    return model


def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
    """Ensure that the input is a torch tensor, by converting it if necessary.

    Numpy arrays with a non-native byte order are converted to the native byte order (preserving
    their values), and the unsigned integer types listed in `DTYPE_MAP` are cast to their signed
    counterpart of the same width.

    Args:
        tensor: The input object, either a torch tensor or a numpy-array like object.
        dtype: The required data type for the output tensor.

    Returns:
        The input, converted to a torch tensor if necessary.
    """
    if isinstance(tensor, np.ndarray):
        # torch.from_numpy rejects non-native byte order. Convert the values with 'astype';
        # only re-viewing the buffer with a swapped dtype would scramble them.
        if not tensor.dtype.isnative:
            tensor = tensor.astype(tensor.dtype.newbyteorder("="))
        if tensor.dtype in DTYPE_MAP:
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        try:
            tensor = torch.from_numpy(tensor)
        except ValueError:
            # E.g. arrays with negative strides (flipped views) cannot be wrapped directly.
            tensor = torch.from_numpy(np.ascontiguousarray(tensor))

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    return tensor


def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Ensure that the input is a torch tensor of a given dimensionality with channels.

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor.
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    if ndim == 2:
        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 2:
            tensor = tensor[None]
        elif tensor.ndim == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
        elif tensor.ndim == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            tensor = tensor[0, 0]
    elif ndim == 3:
        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 3:
            tensor = tensor[None]
        elif tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    else:
        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
        if tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    return tensor


def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: Optional[str] = None) -> np.ndarray:
    """Ensure that the input is a numpy array, by converting it if necessary.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array.

    Returns:
        The input converted to a numpy array if necessary.
    """
    if torch.is_tensor(array):
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    if dtype is not None:
        array = np.require(array, dtype=dtype)
    return array


def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: Optional[str] = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality.
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    if ndim == 2:
        assert array.ndim in (2, 3, 4, 5), str(array.ndim)
        if array.ndim == 3:
            assert array.shape[0] == 1
            array = array[0]
        elif array.ndim == 4:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
        elif array.ndim == 5:
            assert array.shape[:3] == (1, 1, 1)
            array = array[0, 0, 0]
    else:
        assert array.ndim in (3, 4, 5), str(array.ndim)
        if array.ndim == 4:
            assert array.shape[0] == 1, f"{array.shape}"
            array = array[0]
        elif array.ndim == 5:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
    return array


def _pad_to_patch_shape(data, spatial_shape, patch_shape, have_channels, channel_first):
    """Zero-pad `data` at the end of each spatial axis to at least `patch_shape`; returns `data` unchanged if large enough."""
    if not any(sh < psh for sh, psh in zip(spatial_shape, patch_shape)):
        return data

    pw = [(0, max(0, psh - sh)) for sh, psh in zip(spatial_shape, patch_shape)]
    if have_channels and channel_first:
        pad_width = [(0, 0), *pw]
    elif have_channels:  # channels last
        pad_width = [*pw, (0, 0)]
    else:
        pad_width = pw
    return np.pad(array=data, pad_width=pad_width)


def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels do not have the patch shape they will be padded.

    Args:
        raw: The input raw data.
        labels: The input labels.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The raw data and the labels, padded where necessary.
        If `labels` is None, only the raw data is returned.
    """
    # For channels-first data the leading channel axis is excluded from the spatial shape.
    # For channels-last data, zip() with patch_shape drops the trailing channel axis.
    # IMPORTANT: for ImageCollectionDataset
    raw_shape = raw.shape[1:] if (have_raw_channels and channel_first) else raw.shape
    raw = _pad_to_patch_shape(raw, raw_shape, patch_shape, have_raw_channels, channel_first)

    if labels is None:
        return raw

    labels_shape = labels.shape[1:] if (have_label_channels and channel_first) else labels.shape
    labels = _pad_to_patch_shape(labels, labels_shape, patch_shape, have_label_channels, channel_first)
    return raw, labels


def get_constructor_arguments(obj):
    """@private
    """
    # All relevant torch_em classes have 'init_kwargs' to directly recover the init call.
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    def _get_args(obj, param_names):
        return {name: getattr(obj, name) for name in param_names}

    # In pytorch >= 2 the scheduler base class is the public 'LRScheduler', and the built-in
    # schedulers derive from it rather than from the legacy '_LRScheduler'.
    lr_scheduler_base = getattr(
        torch.optim.lr_scheduler, "LRScheduler", torch.optim.lr_scheduler._LRScheduler
    )

    # We don't need to find the constructor arguments for optimizers/schedulers because we deserialize the state later.
    if isinstance(
        obj, (
            torch.optim.Optimizer,
            lr_scheduler_base,
            # ReduceLROnPlateau does not inherit from the scheduler base class in older pytorch versions.
            torch.optim.lr_scheduler.ReduceLROnPlateau
        )
    ):
        return {}

    # recover the arguments for torch dataloader
    elif isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em
        return _get_args(
            obj, [
                "batch_size", "shuffle", "num_workers", "pin_memory", "drop_last",
                "persistent_workers", "prefetch_factor", "timeout"
            ]
        )

    # TODO support common torch losses (e.g. CrossEntropy, BCE)
    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
        "For this object, empty constructor arguments will be used.\n" +
        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}


def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load a trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint.
            If a non-string object is passed it is treated as an already loaded trainer and only type-checked.
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    # try to load from file
    if isinstance(checkpoint, str):
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    else:
        trainer = checkpoint
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer


def get_normalizer(trainer):
    """@private
    """
    # Unwrap (possibly nested) concat datasets and subsets until the first underlying dataset is reached.
    dataset = trainer.train_loader.dataset
    while True:
        if isinstance(
            dataset, (torch_em.data.concat_dataset.ConcatDataset, torch.utils.data.dataset.ConcatDataset)
        ):
            dataset = dataset.datasets[0]
        elif isinstance(dataset, torch.utils.data.dataset.Subset):
            dataset = dataset.dataset
        else:
            break

    preprocessor = dataset.raw_transform

    if hasattr(preprocessor, "normalizer"):
        return preprocessor.normalizer
    else:
        return preprocessor


def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: Optional[str] = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load a model from a trainer checkpoint or a serialized torch model.

    This function can either load the model directly (`model` is not passed),
    or deserialize the model state and then load it (`model` is passed).

    The `checkpoint` argument must either point to the checkpoint directory of a torch_em trainer
    or to a serialized torch model.

    Args:
        checkpoint: The path to the checkpoint folder or serialized torch model.
        model: The model for which the state should be loaded.
            If it is not passed, the model (class and parameters) is loaded from the trainer checkpoint
            or from the serialized model.
        name: The name of the checkpoint.
        state_key: The name of the model state to load. Set to None if the model state is stored top-level.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    # NOTE: torch.load(weights_only=False) unpickles arbitrary objects; only load checkpoints from trusted sources.
    if model is None and os.path.isdir(checkpoint):  # Load the model and its state from a torch_em checkpoint.
        model = get_trainer(checkpoint, name=name, device=device).model

    elif model is None:  # Load the model from a serialized model.
        model = torch.load(checkpoint, map_location=device, weights_only=False)

    else:  # Load the model state from a checkpoint.
        if os.path.isdir(checkpoint):  # From a torch_em checkpoint.
            ckpt = os.path.join(checkpoint, f"{name}.pt")
        else:  # From a serialized path.
            ckpt = checkpoint

        state = torch.load(ckpt, map_location=device, weights_only=False)
        if state_key is not None:
            state = state[state_key]

        # To enable loading compiled models: torch.compile'd modules prefix their state keys with "_orig_mod.".
        compiled_prefix = "_orig_mod."
        state = OrderedDict(
            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()]
        )

        model.load_state_dict(state)
        if device is not None:
            model.to(device)

    return model


def model_is_equal(model1, model2):
    """@private
    """
    # Compare all parameters pairwise; a different number of parameters means the models differ.
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        return False
    # torch.equal is False for mismatching shapes as well as differing values.
    return all(torch.equal(p1.data, p2.data) for p1, p2 in zip(params1, params2))


def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    The map has one entry per unique label value. If the label 0 is present it gets
    the first entry (black), and all other labels get random RGB colors.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    unique_labels = np.unique(labels)
    have_zero = 0 in unique_labels
    # Label 0 already has its black entry, so it must not get an additional random color.
    n_random = len(unique_labels) - int(have_zero)
    cmap = [[0, 0, 0]] if have_zero else []
    cmap += np.random.rand(n_random, 3).tolist()
    cmap = colors.ListedColormap(cmap)
    return cmap
def auto_compile(
    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
) -> torch.nn.Module:
    """Automatically compile a model for pytorch >= 2.

    Args:
        model: The model.
        compile_model: Whether to compile the model.
            If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior
            specified by 'default_compile' will be used (a model that is already compiled is not compiled again).
            If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model)).
        default_compile: Whether to use the default compilation behavior for torch 2.

    Returns:
        The compiled model.

    Raises:
        RuntimeError: If compilation is requested for pytorch < 2.
    """
    torch_major = int(torch.__version__.split(".")[0])

    if compile_model is None:
        if torch_major < 2:
            compile_model = False
        elif is_compiled(model):  # model is already compiled
            compile_model = False
        else:
            compile_model = default_compile

    if compile_model:
        if torch_major < 2:
            raise RuntimeError("Model compilation is only supported for pytorch 2")

        print("Compiling pytorch model ...")
        if isinstance(compile_model, str):
            model = torch.compile(model, mode=compile_model)
        else:
            model = torch.compile(model)

    return model
Automatically compile a model for pytorch >= 2.
Arguments:
- model: The model.
- compile_model: Whether to compile the model. If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior specified by 'default_compile' will be used. If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model)).
- default_compile: Whether to use the default compilation behavior for torch 2.
Returns:
The compiled model.
def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
    """Ensure that the input is a torch tensor, by converting it if necessary.

    Numpy arrays with a non-native byte order are converted to the native byte order (preserving
    their values), and the unsigned integer types listed in `DTYPE_MAP` are cast to their signed
    counterpart of the same width.

    Args:
        tensor: The input object, either a torch tensor or a numpy-array like object.
        dtype: The required data type for the output tensor.

    Returns:
        The input, converted to a torch tensor if necessary.
    """
    if isinstance(tensor, np.ndarray):
        # torch.from_numpy rejects non-native byte order. Convert the values with 'astype';
        # only re-viewing the buffer with a swapped dtype would scramble them.
        if not tensor.dtype.isnative:
            tensor = tensor.astype(tensor.dtype.newbyteorder("="))
        if tensor.dtype in DTYPE_MAP:
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        try:
            tensor = torch.from_numpy(tensor)
        except ValueError:
            # E.g. arrays with negative strides (flipped views) cannot be wrapped directly.
            tensor = torch.from_numpy(np.ascontiguousarray(tensor))

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    return tensor
Ensure that the input is a torch tensor, by converting it if necessary.
Arguments:
- tensor: The input object, either a torch tensor or a numpy-array like object.
- dtype: The required data type for the output tensor.
Returns:
The input, converted to a torch tensor if necessary.
def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Convert the input to a torch tensor and bring it into the layout expected for `ndim`.

    Depending on `ndim`, a leading axis is added or leading singleton axes are removed:
    - ndim=2: 2D inputs get a leading axis, 3D inputs are kept, and 4D / 5D inputs must have
      one / two leading singleton axes, which are removed.
    - ndim=3: 3D inputs get a leading axis, 4D inputs are kept, and 5D inputs must have a leading
      singleton axis, which is removed.
    - ndim=4: 4D inputs are kept, and 5D inputs must have a leading singleton axis, which is removed.

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor.
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    rank = tensor.ndim

    if ndim == 4:
        assert rank in (4, 5), f"{rank}"
        if rank == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            return tensor[0]
        return tensor

    if ndim == 3:
        assert rank in (3, 4, 5), f"{rank}"
        if rank == 3:
            return tensor[None]
        if rank == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            return tensor[0]
        return tensor

    # Remaining case: ndim == 2.
    assert rank in (2, 3, 4, 5), f"{rank}"
    if rank == 2:
        return tensor[None]
    if rank == 4:
        assert tensor.shape[0] == 1, f"{tensor.shape}"
        return tensor[0]
    if rank == 5:
        assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
        return tensor[0, 0]
    return tensor
Ensure that the input is a torch tensor of a given dimensionality with channels.
Arguments:
- tensor: The input tensor or numpy-array like data.
- ndim: The dimensionality of the output tensor.
- dtype: The data type of the output tensor.
Returns:
The input converted to a torch tensor of the requested dimensionality.
def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Return the input as a numpy array, converting a torch tensor if needed.

    Args:
        array: A numpy array or a torch tensor. Tensors are detached and moved to the cpu before conversion.
        dtype: If given, the result is passed through `np.require` with this dtype.

    Returns:
        The data as a numpy array.
    """
    converted = array.detach().cpu().numpy() if torch.is_tensor(array) else array
    assert isinstance(converted, np.ndarray), f"Cannot convert {type(converted)} to numpy"
    return converted if dtype is None else np.require(converted, dtype=dtype)
Ensure that the input is a numpy array, by converting it if necessary.
Arguments:
- array: The input torch tensor or numpy array.
- dtype: The dtype of the output array.
Returns:
The input converted to a numpy array if necessary.
def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Convert the input to a numpy array with exactly `ndim` spatial axes.

    Leading axes beyond `ndim` must all be singletons and are removed
    (up to 3 extra axes for ndim=2, up to 2 extra axes for ndim=3).

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality, either 2 or 3.
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    rank = array.ndim

    if ndim == 3:
        assert rank in (3, 4, 5), str(rank)
        if rank == 4:
            assert array.shape[0] == 1, f"{array.shape}"
            return array[0]
        if rank == 5:
            assert array.shape[:2] == (1, 1)
            return array[0, 0]
        return array

    # Remaining case: ndim == 2.
    assert rank in (2, 3, 4, 5), str(rank)
    if rank == 3:
        assert array.shape[0] == 1
        return array[0]
    if rank == 4:
        assert array.shape[:2] == (1, 1)
        return array[0, 0]
    if rank == 5:
        assert array.shape[:3] == (1, 1, 1)
        return array[0, 0, 0]
    return array
Ensure that the input is a numpy array of a given dimensionality.
Arguments:
- array: The input numpy array or torch tensor.
- ndim: The requested dimensionality.
- dtype: The dtype of the output array.
Returns:
A numpy array of the requested dimensionality and data type.
def _pad_to_patch_shape(data, spatial_shape, patch_shape, have_channels, channel_first):
    """Zero-pad `data` at the end of each spatial axis to at least `patch_shape`; returns `data` unchanged if large enough."""
    if not any(sh < psh for sh, psh in zip(spatial_shape, patch_shape)):
        return data

    pw = [(0, max(0, psh - sh)) for sh, psh in zip(spatial_shape, patch_shape)]
    if have_channels and channel_first:
        pad_width = [(0, 0), *pw]
    elif have_channels:  # channels last
        pad_width = [*pw, (0, 0)]
    else:
        pad_width = pw
    return np.pad(array=data, pad_width=pad_width)


def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels do not have the patch shape they will be padded.

    Args:
        raw: The input raw data.
        labels: The input labels.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The raw data and the labels, padded where necessary.
        If `labels` is None, only the raw data is returned.
    """
    # For channels-first data the leading channel axis is excluded from the spatial shape.
    # For channels-last data, zip() with patch_shape drops the trailing channel axis.
    # IMPORTANT: for ImageCollectionDataset
    raw_shape = raw.shape[1:] if (have_raw_channels and channel_first) else raw.shape
    raw = _pad_to_patch_shape(raw, raw_shape, patch_shape, have_raw_channels, channel_first)

    if labels is None:
        return raw

    labels_shape = labels.shape[1:] if (have_label_channels and channel_first) else labels.shape
    labels = _pad_to_patch_shape(labels, labels_shape, patch_shape, have_label_channels, channel_first)
    return raw, labels
Ensure that the raw data and labels have at least the requested patch shape.
If either raw data or labels do not have the patch shape they will be padded.
Arguments:
- raw: The input raw data.
- labels: The input labels.
- patch_shape: The required minimal patch shape.
- have_raw_channels: Whether the raw data has channels.
- have_label_channels: Whether the label data has channels.
- channel_first: Whether the channel axis is the first or last axis.
Returns:
The raw data and the labels, padded where necessary. If labels is None, only the raw data is returned.
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load a trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint.
            A non-string value is treated as an already loaded trainer and is only type-checked.
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    if not isinstance(checkpoint, str):
        # Already a trainer object: pass it through.
        trainer = checkpoint
    else:
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer
Load a trainer from a checkpoint.
Arguments:
- checkpoint: The path to the checkpoint.
- name: The name of the checkpoint.
- device: The device to use for loading the checkpoint.
Returns:
The trainer.
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: Optional[str] = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load a model from a trainer checkpoint or a serialized torch model.

    Without `model`, the full model is restored, either via the torch_em trainer (if `checkpoint`
    is a directory) or by unpickling a serialized model file. With `model`, only the state dict is
    read (from `<checkpoint>/<name>.pt` for a directory, otherwise from the file itself) and loaded
    into the given model.

    Args:
        checkpoint: The path to the checkpoint folder or serialized torch model.
        model: The model for which the state should be loaded.
            If it is not passed, the model (class and parameters) is restored from the checkpoint.
        name: The name of the checkpoint.
        state_key: The name of the model state to load. Set to None if the model state is stored top-level.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    checkpoint_is_dir = os.path.isdir(checkpoint)

    if model is None:
        if checkpoint_is_dir:
            # The trainer restores both the model architecture and its weights.
            return get_trainer(checkpoint, name=name, device=device).model
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        return torch.load(checkpoint, map_location=device, weights_only=False)

    ckpt_path = os.path.join(checkpoint, f"{name}.pt") if checkpoint_is_dir else checkpoint
    state = torch.load(ckpt_path, map_location=device, weights_only=False)
    if state_key is not None:
        state = state[state_key]

    # State dicts of torch.compile'd models prefix every key with "_orig_mod."; strip it
    # so that the state fits the plain model.
    prefix = "_orig_mod."
    cleaned = OrderedDict()
    for key, value in state.items():
        cleaned[key[len(prefix):] if key.startswith(prefix) else key] = value

    model.load_state_dict(cleaned)
    if device is not None:
        model.to(device)
    return model
Load a model from a trainer checkpoint or a serialized torch model.
This function can either load the model directly (model is not passed),
or deserialize the model state and then load it (model is passed).
The checkpoint argument must either point to the checkpoint directory of a torch_em trainer
or to a serialized torch model.
Arguments:
- checkpoint: The path to the checkpoint folder or serialized torch model.
- model: The model for which the state should be loaded. If it is not passed, the model (class and parameters) is loaded from the trainer checkpoint, or from the serialized model if checkpoint is a file.
- name: The name of the checkpoint.
- state_key: The name of the model state to load. Set to None if the model state is stored top-level.
- device: The device on which to load the model.
Returns:
The model.
def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    The map has one entry per unique label value. If the label 0 is present it gets
    the first entry (black), and all other labels get random RGB colors.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    unique_labels = np.unique(labels)
    have_zero = 0 in unique_labels
    # Label 0 already has its black entry, so it must not get an additional random color.
    n_random = len(unique_labels) - int(have_zero)
    cmap = [[0, 0, 0]] if have_zero else []
    cmap += np.random.rand(n_random, 3).tolist()
    return colors.ListedColormap(cmap)
Generate a random color map for a label image.
Arguments:
- labels: The labels.
Returns:
The color map.