torch_em.util.util

  1import os
  2import warnings
  3from collections import OrderedDict
  4
  5import numpy as np
  6import torch
  7import torch_em
  8import matplotlib.pyplot as plt
  9from matplotlib import colors
 10
 11# this is a fairly brittle way to check if a module is compiled.
 12# would be good to find a better solution, ideall something like
 13# model.is_compiled()
 14try:
 15    from torch._dynamo.eval_frame import OptimizedModule
 16except ImportError:
 17    OptimizedModule = None
 18
 19# torch doesn't support most unsigned types,
 20# so we map them to their signed equivalent
 21DTYPE_MAP = {
 22    np.dtype("uint16"): np.int16,
 23    np.dtype("uint32"): np.int32,
 24    np.dtype("uint64"): np.int64
 25}
 26
 27
 28def is_compiled(model):
 29    if OptimizedModule is None:
 30        return False
 31    return isinstance(model, OptimizedModule)
 32
 33
 34def auto_compile(model, compile_model, default_compile=True):
 35    """Model compilation for pytorch >= 2
 36
 37    Parameters:
 38        model [torch.nn.Module] - the model
 39        compile_model [None, bool, str] - whether to comile the model.
 40            If None, it will not be compiled for torch < 2, and for torch > 2 the behavior
 41            specificed by 'default_compile' will be used. If a string is given it will be
 42            intepreted as the compile mode (torch.compile(model, mode=compile_model)) (default: None)
 43        default_compile [bool] - the default compilation behavior for torch 2
 44    """
 45    torch_major = int(torch.__version__.split(".")[0])
 46
 47    if compile_model is None:
 48        if torch_major < 2:
 49            compile_model = False
 50        elif is_compiled(model):  # model is already compiled
 51            compile_model = False
 52        else:
 53            compile_model = default_compile
 54
 55    if compile_model:
 56        if torch_major < 2:
 57            raise RuntimeError("Model compilation is only supported for pytorch 2")
 58        print("Compiling pytorch model ...")
 59        if isinstance(compile_model, str):
 60            model = torch.compile(model, mode=compile_model)
 61        else:
 62            model = torch.compile(model)
 63
 64    return model
 65
 66
 67def ensure_tensor(tensor, dtype=None):
 68
 69    if isinstance(tensor, np.ndarray):
 70        if np.dtype(tensor.dtype) in DTYPE_MAP:
 71            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 72        tensor = torch.from_numpy(tensor)
 73
 74    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
 75    if dtype is not None:
 76        tensor = tensor.to(dtype=dtype)
 77    return tensor
 78
 79
 80def ensure_tensor_with_channels(tensor, ndim, dtype=None):
 81    assert ndim in (2, 3, 4), f"{ndim}"
 82    tensor = ensure_tensor(tensor, dtype)
 83    if ndim == 2:
 84        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
 85        if tensor.ndim == 2:
 86            tensor = tensor[None]
 87        elif tensor.ndim == 4:
 88            assert tensor.shape[0] == 1, f"{tensor.shape}"
 89            tensor = tensor[0]
 90        elif tensor.ndim == 5:
 91            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
 92            tensor = tensor[0, 0]
 93    elif ndim == 3:
 94        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
 95        if tensor.ndim == 3:
 96            tensor = tensor[None]
 97        elif tensor.ndim == 5:
 98            assert tensor.shape[0] == 1, f"{tensor.shape}"
 99            tensor = tensor[0]
100    else:
101        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
102        if tensor.ndim == 5:
103            assert tensor.shape[0] == 1, f"{tensor.shape}"
104            tensor = tensor[0]
105    return tensor
106
107
def ensure_array(array, dtype=None):
    """Convert a torch tensor or numpy array to a numpy array.

    Parameters:
        array [np.ndarray, torch.Tensor] - the input data. Tensors are detached
            and moved to the CPU before conversion.
        dtype [str, np.dtype] - optional dtype, applied via np.require. (default: None)

    Returns:
        np.ndarray - the converted array; a numpy input is returned as-is if no dtype is given.

    Raises:
        AssertionError - if the input is neither a tensor nor a numpy array.
    """
    if torch.is_tensor(array):
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    return array if dtype is None else np.require(array, dtype=dtype)
115
116
def ensure_spatial_array(array, ndim, dtype=None):
    """Convert the input to a numpy array with exactly `ndim` dimensions.

    Parameters:
        array [np.ndarray, torch.Tensor] - the input data.
        ndim [int] - the number of dimensions of the result; must be 2 or 3.
        dtype [str, np.dtype] - optional dtype passed on to 'ensure_array'. (default: None)

    Returns:
        np.ndarray - the array with all surplus leading axes removed.
            These axes must have size 1.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    rank = array.ndim

    if ndim == 3:
        assert rank in (3, 4, 5), str(rank)
        if rank == 4:
            assert array.shape[0] == 1, f"{array.shape}"
            return array[0]
        if rank == 5:
            assert array.shape[:2] == (1, 1)
            return array[0, 0]
        return array

    # remaining case: ndim == 2
    assert rank in (2, 3, 4, 5), str(rank)
    if rank == 3:
        assert array.shape[0] == 1
        return array[0]
    if rank == 4:
        assert array.shape[:2] == (1, 1)
        return array[0, 0]
    if rank == 5:
        assert array.shape[:3] == (1, 1, 1)
        return array[0, 0, 0]
    return array
140
141
def get_constructor_arguments(obj):
    """Recover the constructor arguments of an object, for serializing a trainer.

    Parameters:
        obj [object] - the object whose constructor arguments should be recovered.

    Returns:
        dict - the constructor arguments:
            - objects with an 'init_kwargs' attribute return it directly;
            - optimizers and learning rate schedulers return an empty dict,
              because their state is deserialized later;
            - torch DataLoaders return their "simple" arguments.
            For any other object a warning is emitted and an empty dict is returned.
    """
    # all relevant torch_em classes have 'init_kwargs' to
    # directly recover the init call
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    lr_scheduler = torch.optim.lr_scheduler
    # Since torch 2.0 the built-in schedulers derive from the public 'LRScheduler' and
    # '_LRScheduler' is only a legacy subclass of it, so checking '_LRScheduler' alone
    # misses them. Check every scheduler base that exists in the installed torch.
    # ReduceLROnPlateau does not inherit from _LRScheduler, hence it is listed explicitly.
    scheduler_types = tuple(
        getattr(lr_scheduler, cls_name)
        for cls_name in ("LRScheduler", "_LRScheduler", "ReduceLROnPlateau")
        if hasattr(lr_scheduler, cls_name)
    )

    # we don't need to find the constructor arguments for optimizers or schedulers
    # because we deserialize the state later
    if isinstance(obj, (torch.optim.Optimizer,) + scheduler_types):
        return {}

    # recover the arguments for torch dataloader
    if isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em
        param_names = ["batch_size", "shuffle", "num_workers",
                       "pin_memory", "drop_last", "persistent_workers",
                       "prefetch_factor", "timeout"]
        args = {}
        for name in param_names:
            if name == "shuffle" and not hasattr(obj, "shuffle"):
                # A plain torch DataLoader does not store 'shuffle' as an attribute.
                # shuffle=True makes it use a RandomSampler, so infer the flag from that.
                args[name] = isinstance(obj.sampler, torch.utils.data.RandomSampler)
            else:
                args[name] = getattr(obj, name)
        return args

    # TODO support common torch losses (e.g. CrossEntropy, BCE)

    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced. "
        "For this object, empty constructor arguments will be used. "
        "Hence, the trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}
177
178
def get_trainer(checkpoint, name="best", device=None):
    """Load trainer from a checkpoint.

    Parameters:
        checkpoint [str, torch_em.trainer.DefaultTrainer] - path to the checkpoint,
            or an existing trainer which is returned unchanged.
        name [str] - the name of the checkpoint. (default: "best")
        device [torch.device] - passed on to 'DefaultTrainer.from_checkpoint'. (default: None)

    Returns:
        torch_em.trainer.DefaultTrainer - the trainer.
    """
    trainer = checkpoint
    if isinstance(checkpoint, str):  # load from file
        assert os.path.exists(checkpoint), checkpoint
        trainer_cls = torch_em.trainer.DefaultTrainer
        trainer = trainer_cls.from_checkpoint(checkpoint, name=name, device=device)
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer
192
193
def get_normalizer(trainer):
    """Get the input normalization used for the trainer's training data.

    The training dataset is unwrapped through concatenated datasets and subsets,
    in any nesting order; for concatenated datasets the first sub-dataset is used.

    Parameters:
        trainer [torch_em.trainer.DefaultTrainer] - the trainer.

    Returns:
        The 'normalizer' attribute of the dataset's 'raw_transform' if it has one,
        otherwise the 'raw_transform' itself.
    """
    concat_types = (torch_em.data.concat_dataset.ConcatDataset, torch.utils.data.dataset.ConcatDataset)
    dataset = trainer.train_loader.dataset
    # Unwrap until a plain dataset is reached, so that e.g. a Subset of a
    # ConcatDataset or a ConcatDataset of Subsets are handled as well.
    while True:
        if isinstance(dataset, concat_types):
            dataset = dataset.datasets[0]
        elif isinstance(dataset, torch.utils.data.dataset.Subset):
            dataset = dataset.dataset
        else:
            break

    preprocessor = dataset.raw_transform
    return getattr(preprocessor, "normalizer", preprocessor)
211
212
def load_model(checkpoint, model=None, name="best", state_key="model_state", device=None):
    """Convenience function to load a model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Parameters:
        checkpoint [str] - path to the checkpoint folder.
        model [torch.nn.Module] - the model for which the state should be loaded.
            If it is not passed the model class and parameters will also be loaded from the trainer. (default: None)
        name [str] - the name of the checkpoint. (default: "best")
        state_key [str] - the name of the model state to load. (default: "model_state")
        device [torch.device] - the device on which to load the model. (default: None)

    Returns:
        torch.nn.Module - the loaded model (the passed model object if one was given).
    """
    if model is None:  # load the model and its state from the checkpoint
        model = get_trainer(checkpoint, name=name, device=device).model

    else:  # load the model state from the checkpoint
        ckpt = os.path.join(checkpoint, f"{name}.pt")
        state = torch.load(ckpt, map_location=device)[state_key]
        # To enable loading states saved from compiled models: strip the
        # "_orig_mod." prefix so the keys match the uncompiled module.
        compiled_prefix = "_orig_mod."
        state = OrderedDict(
            (k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()
        )
        # The stripped keys match the uncompiled module; if the passed model is
        # itself compiled, load into the module it wraps (its '_orig_mod').
        target = getattr(model, "_orig_mod", model)
        # Load once: load_state_dict copies into the existing parameters,
        # so moving the model afterwards keeps the loaded values.
        target.load_state_dict(state)
        if device is not None:
            model.to(device)

    return model
244
245
def model_is_equal(model1, model2):
    """Check whether two models have identical parameters.

    Parameters are compared pairwise, in the order returned by 'parameters()'.

    Parameters:
        model1 [torch.nn.Module] - the first model.
        model2 [torch.nn.Module] - the second model.

    Returns:
        bool - True if both models have the same number of parameters and each
            pair has the same shape and values; False otherwise.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip() alone would silently ignore the extra parameters of the larger model.
    if len(params1) != len(params2):
        return False
    for p1, p2 in zip(params1, params2):
        # Differently shaped parameters would broadcast or raise in the comparison.
        if p1.shape != p2.shape:
            return False
        # Move to a common device so models on different devices can be compared.
        if p1.data.ne(p2.data.to(p1.device)).any():
            return False
    return True
251
252
253
def get_random_colors(labels, seed=None):
    """Function to generate a random color map for a label image.

    The first color is black; every further unique label value gets a random RGB color.

    Parameters:
        labels [np.ndarray] - the label image.
        seed [int] - optional seed for reproducible colors. If None, numpy's global
            random state is used. (default: None)

    Returns:
        matplotlib.colors.ListedColormap - the color map.
    """
    # One of the unique values is assumed to be background and gets black.
    # Clamp at 0 so that an empty label array does not request a negative number of colors.
    n_labels = max(len(np.unique(labels)) - 1, 0)
    if seed is None:
        random_colors = np.random.rand(n_labels, 3)
    else:
        random_colors = np.random.default_rng(seed).random((n_labels, 3))
    cmap = [[0, 0, 0]] + random_colors.tolist()
    return colors.ListedColormap(cmap)
DTYPE_MAP = {dtype('uint16'): <class 'numpy.int16'>, dtype('uint32'): <class 'numpy.int32'>, dtype('uint64'): <class 'numpy.int64'>}
def is_compiled(model):
    """Check whether a model is wrapped by torch.compile.

    Parameters:
        model [torch.nn.Module] - the model to check.

    Returns:
        bool - True if the model is an instance of OptimizedModule.
            Always False if OptimizedModule could not be imported.
    """
    return OptimizedModule is not None and isinstance(model, OptimizedModule)
def auto_compile(model, compile_model=None, default_compile=True):
    """Model compilation for pytorch >= 2.

    Parameters:
        model [torch.nn.Module] - the model.
        compile_model [None, bool, str] - whether to compile the model.
            If None, the model is not compiled for torch < 2 or if it is already compiled;
            otherwise (torch >= 2) the behavior specified by 'default_compile' is used.
            If a string is given it is interpreted as the compile mode
            (torch.compile(model, mode=compile_model)). (default: None)
        default_compile [bool] - the default compilation behavior for torch >= 2. (default: True)

    Returns:
        The model, compiled via torch.compile if compilation was requested.

    Raises:
        RuntimeError - if compilation is requested but the pytorch major version is < 2.
    """
    # e.g. "2.1.0+cu118" -> 2
    torch_major = int(torch.__version__.split(".")[0])

    if compile_model is None:
        if torch_major < 2:
            compile_model = False
        elif is_compiled(model):  # model is already compiled
            compile_model = False
        else:
            compile_model = default_compile

    if compile_model:
        if torch_major < 2:
            raise RuntimeError("Model compilation is only supported for pytorch 2")
        print("Compiling pytorch model ...")
        if isinstance(compile_model, str):
            model = torch.compile(model, mode=compile_model)
        else:
            model = torch.compile(model)

    return model

Model compilation for pytorch >= 2

Arguments:
  • model [torch.nn.Module] - the model
  • compile_model [None, bool, str] - whether to compile the model. If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior specified by 'default_compile' will be used. If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model)) (default: None)
  • default_compile [bool] - the default compilation behavior for torch 2
def ensure_tensor(tensor, dtype=None):
    """Convert a numpy array or torch tensor to a torch tensor.

    Parameters:
        tensor [np.ndarray, torch.Tensor] - the input data.
        dtype [torch.dtype] - optional dtype the result is cast to. (default: None)

    Returns:
        torch.Tensor - the converted (and optionally cast) tensor.

    Raises:
        AssertionError - if the input is neither a numpy array nor a torch tensor.
    """
    if isinstance(tensor, np.ndarray):
        # Unsigned types torch can't handle are mapped to their signed equivalent.
        # NOTE(review): values outside the signed range wrap around in this cast.
        if tensor.dtype in DTYPE_MAP:
            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
        # torch.from_numpy rejects arrays with negative strides
        # (e.g. views created by np.flip or [::-1]), so copy those first.
        if any(stride < 0 for stride in tensor.strides):
            tensor = tensor.copy()
        tensor = torch.from_numpy(tensor)

    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    return tensor
def ensure_tensor_with_channels(tensor, ndim, dtype=None):
    """Convert the input to a tensor with a single leading non-spatial axis.

    Parameters:
        tensor [np.ndarray, torch.Tensor] - the input data.
        ndim [int] - the data dimensionality; must be 2, 3 or 4.
        dtype [torch.dtype] - optional dtype passed on to 'ensure_tensor'. (default: None)

    Returns:
        torch.Tensor - a tensor of rank 3 for ndim=2 and of rank 4 for ndim=3 or ndim=4.
            A missing leading axis is added; surplus leading axes must have size 1
            and are removed. NOTE(review): presumably the leading axis is the
            channel axis — confirm against callers.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    rank = tensor.ndim

    if ndim == 2:
        assert rank in (2, 3, 4, 5), f"{rank}"
        if rank == 2:
            return tensor[None]
        if rank == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            return tensor[0]
        if rank == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            return tensor[0, 0]
        return tensor

    if ndim == 3:
        assert rank in (3, 4, 5), f"{rank}"
        if rank == 3:
            return tensor[None]
        if rank == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            return tensor[0]
        return tensor

    # remaining case: ndim == 4
    assert rank in (4, 5), f"{rank}"
    if rank == 5:
        assert tensor.shape[0] == 1, f"{tensor.shape}"
        return tensor[0]
    return tensor
def ensure_array(array, dtype=None):
    """Convert a torch tensor or numpy array to a numpy array.

    Parameters:
        array [np.ndarray, torch.Tensor] - the input data. Tensors are detached
            and moved to the CPU before conversion.
        dtype [str, np.dtype] - optional dtype, applied via np.require. (default: None)

    Returns:
        np.ndarray - the converted array; a numpy input is returned as-is if no dtype is given.

    Raises:
        AssertionError - if the input is neither a tensor nor a numpy array.
    """
    if torch.is_tensor(array):
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    return array if dtype is None else np.require(array, dtype=dtype)
def ensure_spatial_array(array, ndim, dtype=None):
    """Convert the input to a numpy array with exactly `ndim` dimensions.

    Parameters:
        array [np.ndarray, torch.Tensor] - the input data.
        ndim [int] - the number of dimensions of the result; must be 2 or 3.
        dtype [str, np.dtype] - optional dtype passed on to 'ensure_array'. (default: None)

    Returns:
        np.ndarray - the array with all surplus leading axes removed.
            These axes must have size 1.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    rank = array.ndim

    if ndim == 3:
        assert rank in (3, 4, 5), str(rank)
        if rank == 4:
            assert array.shape[0] == 1, f"{array.shape}"
            return array[0]
        if rank == 5:
            assert array.shape[:2] == (1, 1)
            return array[0, 0]
        return array

    # remaining case: ndim == 2
    assert rank in (2, 3, 4, 5), str(rank)
    if rank == 3:
        assert array.shape[0] == 1
        return array[0]
    if rank == 4:
        assert array.shape[:2] == (1, 1)
        return array[0, 0]
    if rank == 5:
        assert array.shape[:3] == (1, 1, 1)
        return array[0, 0, 0]
    return array
def get_constructor_arguments(obj):
    """Recover the constructor arguments of an object, for serializing a trainer.

    Parameters:
        obj [object] - the object whose constructor arguments should be recovered.

    Returns:
        dict - the constructor arguments:
            - objects with an 'init_kwargs' attribute return it directly;
            - optimizers and learning rate schedulers return an empty dict,
              because their state is deserialized later;
            - torch DataLoaders return their "simple" arguments.
            For any other object a warning is emitted and an empty dict is returned.
    """
    # all relevant torch_em classes have 'init_kwargs' to
    # directly recover the init call
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    lr_scheduler = torch.optim.lr_scheduler
    # Since torch 2.0 the built-in schedulers derive from the public 'LRScheduler' and
    # '_LRScheduler' is only a legacy subclass of it, so checking '_LRScheduler' alone
    # misses them. Check every scheduler base that exists in the installed torch.
    # ReduceLROnPlateau does not inherit from _LRScheduler, hence it is listed explicitly.
    scheduler_types = tuple(
        getattr(lr_scheduler, cls_name)
        for cls_name in ("LRScheduler", "_LRScheduler", "ReduceLROnPlateau")
        if hasattr(lr_scheduler, cls_name)
    )

    # we don't need to find the constructor arguments for optimizers or schedulers
    # because we deserialize the state later
    if isinstance(obj, (torch.optim.Optimizer,) + scheduler_types):
        return {}

    # recover the arguments for torch dataloader
    if isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em
        param_names = ["batch_size", "shuffle", "num_workers",
                       "pin_memory", "drop_last", "persistent_workers",
                       "prefetch_factor", "timeout"]
        args = {}
        for name in param_names:
            if name == "shuffle" and not hasattr(obj, "shuffle"):
                # A plain torch DataLoader does not store 'shuffle' as an attribute.
                # shuffle=True makes it use a RandomSampler, so infer the flag from that.
                args[name] = isinstance(obj.sampler, torch.utils.data.RandomSampler)
            else:
                args[name] = getattr(obj, name)
        return args

    # TODO support common torch losses (e.g. CrossEntropy, BCE)

    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced. "
        "For this object, empty constructor arguments will be used. "
        "Hence, the trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}
def get_trainer(checkpoint, name="best", device=None):
    """Load trainer from a checkpoint.

    Parameters:
        checkpoint [str, torch_em.trainer.DefaultTrainer] - path to the checkpoint,
            or an existing trainer which is returned unchanged.
        name [str] - the name of the checkpoint. (default: "best")
        device [torch.device] - passed on to 'DefaultTrainer.from_checkpoint'. (default: None)

    Returns:
        torch_em.trainer.DefaultTrainer - the trainer.
    """
    trainer = checkpoint
    if isinstance(checkpoint, str):  # load from file
        assert os.path.exists(checkpoint), checkpoint
        trainer_cls = torch_em.trainer.DefaultTrainer
        trainer = trainer_cls.from_checkpoint(checkpoint, name=name, device=device)
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer

Load trainer from a checkpoint.

def get_normalizer(trainer):
    """Get the input normalization used for the trainer's training data.

    The training dataset is unwrapped through concatenated datasets and subsets,
    in any nesting order; for concatenated datasets the first sub-dataset is used.

    Parameters:
        trainer [torch_em.trainer.DefaultTrainer] - the trainer.

    Returns:
        The 'normalizer' attribute of the dataset's 'raw_transform' if it has one,
        otherwise the 'raw_transform' itself.
    """
    concat_types = (torch_em.data.concat_dataset.ConcatDataset, torch.utils.data.dataset.ConcatDataset)
    dataset = trainer.train_loader.dataset
    # Unwrap until a plain dataset is reached, so that e.g. a Subset of a
    # ConcatDataset or a ConcatDataset of Subsets are handled as well.
    while True:
        if isinstance(dataset, concat_types):
            dataset = dataset.datasets[0]
        elif isinstance(dataset, torch.utils.data.dataset.Subset):
            dataset = dataset.dataset
        else:
            break

    preprocessor = dataset.raw_transform
    return getattr(preprocessor, "normalizer", preprocessor)
def load_model(checkpoint, model=None, name="best", state_key="model_state", device=None):
    """Convenience function to load a model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Parameters:
        checkpoint [str] - path to the checkpoint folder.
        model [torch.nn.Module] - the model for which the state should be loaded.
            If it is not passed the model class and parameters will also be loaded from the trainer. (default: None)
        name [str] - the name of the checkpoint. (default: "best")
        state_key [str] - the name of the model state to load. (default: "model_state")
        device [torch.device] - the device on which to load the model. (default: None)

    Returns:
        torch.nn.Module - the loaded model (the passed model object if one was given).
    """
    if model is None:  # load the model and its state from the checkpoint
        model = get_trainer(checkpoint, name=name, device=device).model

    else:  # load the model state from the checkpoint
        ckpt = os.path.join(checkpoint, f"{name}.pt")
        state = torch.load(ckpt, map_location=device)[state_key]
        # To enable loading states saved from compiled models: strip the
        # "_orig_mod." prefix so the keys match the uncompiled module.
        compiled_prefix = "_orig_mod."
        state = OrderedDict(
            (k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()
        )
        # The stripped keys match the uncompiled module; if the passed model is
        # itself compiled, load into the module it wraps (its '_orig_mod').
        target = getattr(model, "_orig_mod", model)
        # Load once: load_state_dict copies into the existing parameters,
        # so moving the model afterwards keeps the loaded values.
        target.load_state_dict(state)
        if device is not None:
            model.to(device)

    return model

Convenience function to load a model from a trainer checkpoint.

This function can either load the model directly from the trainer (model is not passed), or deserialize the model state from the trainer and load the model state (model is passed).

Arguments:
  • checkpoint [str] - path to the checkpoint folder.
  • model [torch.nn.Module] - the model for which the state should be loaded. If it is not passed the model class and parameters will also be loaded from the trainer. (default: None)
  • name [str] - the name of the checkpoint. (default: "best")
  • state_key [str] - the name of the model state to load. (default: "model_state")
  • device [torch.device] - the device on which to load the model. (default: None)
def model_is_equal(model1, model2):
    """Check whether two models have identical parameters.

    Parameters are compared pairwise, in the order returned by 'parameters()'.

    Parameters:
        model1 [torch.nn.Module] - the first model.
        model2 [torch.nn.Module] - the second model.

    Returns:
        bool - True if both models have the same number of parameters and each
            pair has the same shape and values; False otherwise.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip() alone would silently ignore the extra parameters of the larger model.
    if len(params1) != len(params2):
        return False
    for p1, p2 in zip(params1, params2):
        # Differently shaped parameters would broadcast or raise in the comparison.
        if p1.shape != p2.shape:
            return False
        # Move to a common device so models on different devices can be compared.
        if p1.data.ne(p2.data.to(p1.device)).any():
            return False
    return True
def get_random_colors(labels, seed=None):
    """Function to generate a random color map for a label image.

    The first color is black; every further unique label value gets a random RGB color.

    Parameters:
        labels [np.ndarray] - the label image.
        seed [int] - optional seed for reproducible colors. If None, numpy's global
            random state is used. (default: None)

    Returns:
        matplotlib.colors.ListedColormap - the color map.
    """
    # One of the unique values is assumed to be background and gets black.
    # Clamp at 0 so that an empty label array does not request a negative number of colors.
    n_labels = max(len(np.unique(labels)) - 1, 0)
    if seed is None:
        random_colors = np.random.rand(n_labels, 3)
    else:
        random_colors = np.random.default_rng(seed).random((n_labels, 3))
    cmap = [[0, 0, 0]] + random_colors.tolist()
    return colors.ListedColormap(cmap)

Function to generate a random color map for a label image