torch_em.util.util
import os
import warnings
from collections import OrderedDict

import numpy as np
import torch
import torch_em
import matplotlib.pyplot as plt
from matplotlib import colors

# this is a fairly brittle way to check if a module is compiled.
# would be good to find a better solution, ideally something like
# model.is_compiled()
try:
    from torch._dynamo.eval_frame import OptimizedModule
except ImportError:
    OptimizedModule = None

# torch doesn't support most unsigned types,
# so we map them to their signed equivalent
DTYPE_MAP = {
    np.dtype("uint16"): np.int16,
    np.dtype("uint32"): np.int32,
    np.dtype("uint64"): np.int64
}


def is_compiled(model):
    if OptimizedModule is None:
        return False
    return isinstance(model, OptimizedModule)


def auto_compile(model, compile_model, default_compile=True):
    """Model compilation for pytorch >= 2

    Parameters:
        model [torch.nn.Module] - the model
        compile_model [None, bool, str] - whether to compile the model.
            If None, the model will not be compiled for torch < 2; for torch >= 2 the behavior
            specified by 'default_compile' is used. If a string is given, it is
            interpreted as the compile mode (torch.compile(model, mode=compile_model)). (default: None)
        default_compile [bool] - the default compilation behavior for torch 2
    """
    torch_major = int(torch.__version__.split(".")[0])

    if compile_model is None:
        if torch_major < 2:
            compile_model = False
        elif is_compiled(model):  # model is already compiled
            compile_model = False
        else:
            compile_model = default_compile

    if compile_model:
        if torch_major < 2:
            raise RuntimeError("Model compilation is only supported for pytorch 2")
        print("Compiling pytorch model ...")
        if isinstance(compile_model, str):
            model = torch.compile(model, mode=compile_model)
        else:
            model = torch.compile(model)

    return model


def ensure_tensor(tensor, dtype=None):
    if isinstance(tensor, np.ndarray):
        if np.dtype(tensor.dtype) in DTYPE_MAP:
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        tensor = torch.from_numpy(tensor)

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    return tensor


def ensure_tensor_with_channels(tensor, ndim, dtype=None):
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    if ndim == 2:
        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 2:
            tensor = tensor[None]
        elif tensor.ndim == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
        elif tensor.ndim == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            tensor = tensor[0, 0]
    elif ndim == 3:
        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 3:
            tensor = tensor[None]
        elif tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    else:
        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
        if tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor[0]
    return tensor


def ensure_array(array, dtype=None):
    if torch.is_tensor(array):
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    if dtype is not None:
        array = np.require(array, dtype=dtype)
    return array


def ensure_spatial_array(array, ndim, dtype=None):
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    if ndim == 2:
        assert array.ndim in (2, 3, 4, 5), str(array.ndim)
        if array.ndim == 3:
            assert array.shape[0] == 1
            array = array[0]
        elif array.ndim == 4:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
        elif array.ndim == 5:
            assert array.shape[:3] == (1, 1, 1)
            array = array[0, 0, 0]
    else:
        assert array.ndim in (3, 4, 5), str(array.ndim)
        if array.ndim == 4:
            assert array.shape[0] == 1, f"{array.shape}"
            array = array[0]
        elif array.ndim == 5:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
    return array


def get_constructor_arguments(obj):
    # all relevant torch_em classes have 'init_kwargs' to
    # directly recover the init call
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    def _get_args(obj, param_names):
        return {name: getattr(obj, name) for name in param_names}

    # we don't need to find the constructor arguments for optimizers or schedulers
    # because we deserialize the state later
    if isinstance(obj, (torch.optim.Optimizer,
                        torch.optim.lr_scheduler._LRScheduler,
                        # ReduceLROnPlateau does not inherit from _LRScheduler
                        torch.optim.lr_scheduler.ReduceLROnPlateau)):
        return {}

    # recover the arguments for torch dataloader
    elif isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em.
        return _get_args(obj, ["batch_size", "shuffle", "num_workers",
                               "pin_memory", "drop_last", "persistent_workers",
                               "prefetch_factor", "timeout"])

    # TODO support common torch losses (e.g. CrossEntropy, BCE)

    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced. "
        "For this object, empty constructor arguments will be used. "
        "Hence, the trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}


def get_trainer(checkpoint, name="best", device=None):
    """Load trainer from a checkpoint.
    """
    # try to load from file
    if isinstance(checkpoint, str):
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    else:
        trainer = checkpoint
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer


def get_normalizer(trainer):
    dataset = trainer.train_loader.dataset
    while (
        isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset) or
        isinstance(dataset, torch.utils.data.dataset.ConcatDataset)
    ):
        dataset = dataset.datasets[0]

    if isinstance(dataset, torch.utils.data.dataset.Subset):
        dataset = dataset.dataset

    preprocessor = dataset.raw_transform

    if hasattr(preprocessor, "normalizer"):
        return preprocessor.normalizer
    else:
        return preprocessor


def load_model(checkpoint, model=None, name="best", state_key="model_state", device=None):
    """Convenience function to load a model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Parameters:
        checkpoint [str] - path to the checkpoint folder.
        model [torch.nn.Module] - the model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer. (default: None)
        name [str] - the name of the checkpoint. (default: "best")
        state_key [str] - the name of the model state to load. (default: "model_state")
        device [torch.device] - the device on which to load the model. (default: None)
    """
    if model is None:  # load the model and its state from the checkpoint
        model = get_trainer(checkpoint, name=name, device=device).model

    else:  # load the model state from the checkpoint
        ckpt = os.path.join(checkpoint, f"{name}.pt")
        state = torch.load(ckpt, map_location=device)[state_key]
        # strip the prefix that torch.compile adds to the state dict keys,
        # to enable loading compiled models
        compiled_prefix = "_orig_mod."
        state = OrderedDict(
            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()]
        )
        model.load_state_dict(state)
        if device is not None:
            model.to(device)

    return model


def model_is_equal(model1, model2):
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True


def get_random_colors(labels):
    """Function to generate a random color map for a label image.
    """
    n_labels = len(np.unique(labels)) - 1
    cmap = [[0, 0, 0]] + np.random.rand(n_labels, 3).tolist()
    cmap = colors.ListedColormap(cmap)
    return cmap
DTYPE_MAP = {np.dtype("uint16"): np.int16, np.dtype("uint32"): np.int32, np.dtype("uint64"): np.int64}
def is_compiled(model):

Check whether the model has been compiled with torch.compile, i.e. whether it is a torch._dynamo OptimizedModule.
def auto_compile(model, compile_model, default_compile=True):
Model compilation for pytorch >= 2
Arguments:
- model [torch.nn.Module] - the model
- compile_model [None, bool, str] - whether to compile the model. If None, the model will not be compiled for torch < 2; for torch >= 2 the behavior specified by 'default_compile' is used. If a string is given, it is interpreted as the compile mode (torch.compile(model, mode=compile_model)). (default: None)
- default_compile [bool] - the default compilation behavior for torch 2
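A minimal usage sketch; the Conv2d model here is just a stand-in:

import torch
from torch_em.util.util import auto_compile

model = torch.nn.Conv2d(1, 8, kernel_size=3)
model = auto_compile(model, compile_model=None)  # compiled only when running torch >= 2
# a string selects the compile mode instead, e.g.:
# model = auto_compile(model, compile_model="max-autotune")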
def ensure_tensor(tensor, dtype=None):

Convert the input to a torch tensor. Numpy arrays with unsigned integer dtypes are first cast to their signed equivalents (see DTYPE_MAP), since torch does not support most unsigned types.
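For example (a sketch):

import numpy as np
import torch
from torch_em.util.util import ensure_tensor

arr = np.arange(4, dtype="uint16").reshape(2, 2)
t = ensure_tensor(arr, dtype=torch.float32)  # uint16 is cast to int16 before conversion
print(t.dtype)  # torch.float32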
def ensure_tensor_with_channels(tensor, ndim, dtype=None):

Like ensure_tensor, but normalizes the result to have a leading channel axis for the given spatial dimensionality ndim (2, 3 or 4), squeezing singleton batch axes where necessary.
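A short sketch of the shape handling:

import numpy as np
from torch_em.util.util import ensure_tensor_with_channels

image = np.random.rand(64, 64)                    # 2d image without a channel axis
t2 = ensure_tensor_with_channels(image, ndim=2)   # -> tensor of shape (1, 64, 64)

volume = np.random.rand(32, 64, 64)               # 3d volume without a channel axis
t3 = ensure_tensor_with_channels(volume, ndim=3)  # -> tensor of shape (1, 32, 64, 64)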
def ensure_array(array, dtype=None):

Convert the input to a numpy array; torch tensors are detached and moved to the CPU first.
def ensure_spatial_array(array, ndim, dtype=None):

Like ensure_array, but squeezes singleton batch and channel axes so that the result has exactly ndim (2 or 3) spatial dimensions.
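For example (a sketch):

import torch
from torch_em.util.util import ensure_spatial_array

batch = torch.rand(1, 1, 64, 64)             # singleton batch and channel axes
image = ensure_spatial_array(batch, ndim=2)  # -> numpy array of shape (64, 64)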
def get_constructor_arguments(obj):

Recover the constructor arguments of an object so that the trainer can later be deserialized via 'DefaultTrainer.from_checkpoint'. torch_em classes expose them directly via 'init_kwargs'; optimizers, schedulers and dataloaders are handled explicitly, and an empty dict with a warning is returned for anything else.
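For instance, optimizers need no constructor arguments, since their state is deserialized separately (a sketch):

import torch
from torch_em.util.util import get_constructor_arguments

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
assert get_constructor_arguments(optimizer) == {}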
def get_trainer(checkpoint, name="best", device=None):
Load a trainer from a checkpoint path, or pass an existing DefaultTrainer through unchanged.
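A usage sketch; the checkpoint path is hypothetical:

from torch_em.util.util import get_trainer

trainer = get_trainer("./checkpoints/my-model", name="best", device="cpu")
model = trainer.model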
def get_normalizer(trainer):

Retrieve the normalizer from the trainer's training dataset: concatenated datasets and subsets are unwrapped, then the raw_transform's normalizer is returned (or the raw_transform itself if it does not have one).
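A sketch, assuming a checkpoint at a hypothetical path:

from torch_em.util.util import get_normalizer, get_trainer

trainer = get_trainer("./checkpoints/my-model")
normalizer = get_normalizer(trainer)
raw_normalized = normalizer(raw)  # apply the training-time normalization to new raw data 'raw'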
def load_model(checkpoint, model=None, name="best", state_key="model_state", device=None):
Convenience function to load a model from a trainer checkpoint.
This function can either load the model directly from the trainer (model is not passed), or deserialize the model state from the trainer and load the model state (model is passed).
Arguments:
- checkpoint [str] - path to the checkpoint folder.
- model [torch.nn.Module] - the model for which the state should be loaded. If it is not passed, the model class and parameters will also be loaded from the trainer. (default: None)
- name [str] - the name of the checkpoint. (default: "best")
- state_key [str] - the name of the model state to load. (default: "model_state")
- device [torch.device] - the device on which to load the model. (default: None)
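Both modes, as a sketch; the checkpoint path is hypothetical:

from torch_em.util.util import load_model

# load the model class and its weights from the checkpoint folder
model = load_model("./checkpoints/my-model", device="cpu")

# or load only the weights into an existing model instance
model = load_model("./checkpoints/my-model", model=model, device="cpu")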
def model_is_equal(model1, model2):

Check whether two models have identical parameter values (parameters are compared pairwise).
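For example (a sketch, assuming 'model' is any torch.nn.Module):

import copy
from torch_em.util.util import model_is_equal

model_copy = copy.deepcopy(model)
assert model_is_equal(model, model_copy)  # an exact copy has identical parameters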
def get_random_colors(labels):
Generate a random color map for a label image; label id 0 (the background) is mapped to black.
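A usage sketch with a toy label image:

import numpy as np
import matplotlib.pyplot as plt
from torch_em.util.util import get_random_colors

labels = np.random.randint(0, 5, size=(64, 64))  # toy label image with background id 0
plt.imshow(labels, cmap=get_random_colors(labels), interpolation="nearest")
plt.show()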