
  1import os
  2import warnings
  3from collections import OrderedDict
  4from typing import Optional, Tuple, Union
  6import numpy as np
  7import torch
  8import torch_em
  9from matplotlib import colors
 10from numpy.typing import ArrayLike
 13    from torch._dynamo.eval_frame import OptimizedModule
 14except ImportError:
 15    OptimizedModule = None
 17# torch doesn't support most unsigned types,
 18# so we map them to their signed equivalent
 19DTYPE_MAP = {
 20    np.dtype("uint16"): np.int16,
 21    np.dtype("uint32"): np.int32,
 22    np.dtype("uint64"): np.int64
 28# This is a fairly brittle way to check if a module is compiled.
 29# Would be good to find a better solution, ideall something like model.is_compiled().
 30def is_compiled(model):
 31    """@private
 32    """
 33    if OptimizedModule is None:
 34        return False
 35    return isinstance(model, OptimizedModule)
 38def auto_compile(
 39    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
 40) -> torch.nn.Module:
 41    """Automatically compile a model for pytorch >= 2.
 43    Args:
 44        model: The model.
 45        compile_model: Whether to comile the model.
 46            If None, it will not be compiled for torch < 2, and for torch > 2 the behavior
 47            specificed by 'default_compile' will be used. If a string is given it will be
 48            intepreted as the compile mode (torch.compile(model, mode=compile_model))
 49        default_compile: Whether to use the default compilation behavior for torch 2.
 51    Returns:
 52        The compiled model.
 53    """
 54    torch_major = int(torch.__version__.split(".")[0])
 56    if compile_model is None:
 57        if torch_major < 2:
 58            compile_model = False
 59        elif is_compiled(model):  # model is already compiled
 60            compile_model = False
 61        else:
 62            compile_model = default_compile
 64    if compile_model:
 65        if torch_major < 2:
 66            raise RuntimeError("Model compilation is only supported for pytorch 2")
 68        print("Compiling pytorch model ...")
 69        if isinstance(compile_model, str):
 70            model = torch.compile(model, mode=compile_model)
 71        else:
 72            model = torch.compile(model)
 74    return model
 77def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
 78    """Ensure that the input is a torch tensor, by converting it if necessary.
 80    Args:
 81        tensor: The input object, either a torch tensor or a numpy-array like object.
 82        dtype: The required data type for the output tensor.
 84    Returns:
 85        The input, converted to a torch tensor if necessary.
 86    """
 87    if isinstance(tensor, np.ndarray):
 88        if np.dtype(tensor.dtype) in DTYPE_MAP:
 89            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 90        # Try to convert the tensor, even if it has wrong byte-order
 91        try:
 92            tensor = torch.from_numpy(tensor)
 93        except ValueError:
 94            tensor = tensor.view(tensor.dtype.newbyteorder())
 95            if np.dtype(tensor.dtype) in DTYPE_MAP:
 96                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 97            tensor = torch.from_numpy(tensor)
 99    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
100    if dtype is not None:
101        tensor =
102    return tensor
def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Convert the input to a torch tensor with a leading channel axis.

    For ndim 2 and 3 the result has ndim + 1 axes: a channel axis is prepended if the
    input has exactly ndim axes, and leading singleton axes are removed if it has more
    than ndim + 1. For ndim 4 the result has 4 axes (a leading singleton axis of a 5D
    input is removed).

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The requested spatial dimensionality (2, 3 or 4).
        dtype: The data type of the output tensor.

    Returns:
        The converted torch tensor.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    actual = tensor.ndim

    if ndim == 2:
        assert actual in (2, 3, 4, 5), f"{actual}"
        if actual == 2:
            return tensor[None]  # prepend the channel axis
        if actual == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            return tensor[0]
        if actual == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            return tensor[0, 0]
        return tensor

    if ndim == 3:
        assert actual in (3, 4, 5), f"{actual}"
        if actual == 3:
            return tensor[None]  # prepend the channel axis
        if actual == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            return tensor[0]
        return tensor

    assert actual in (4, 5), f"{actual}"
    if actual == 5:
        assert tensor.shape[0] == 1, f"{tensor.shape}"
        return tensor[0]
    return tensor
def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Convert a torch tensor or numpy array into a numpy array.

    Tensors are detached from the autograd graph and moved to the CPU before conversion.

    Args:
        array: The input torch tensor or numpy array.
        dtype: Optional dtype of the output array; applied with 'np.require',
            which only converts when necessary.

    Returns:
        The data as a numpy array.
    """
    result = array.detach().cpu().numpy() if torch.is_tensor(array) else array
    assert isinstance(result, np.ndarray), f"Cannot convert {type(result)} to numpy"
    return result if dtype is None else np.require(result, dtype=dtype)
def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Convert the input to a numpy array with exactly 'ndim' spatial axes.

    Leading singleton axes (e.g. batch or channel axes of size 1) are removed.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality (2 or 3).
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    n_axes = array.ndim

    if ndim == 2:
        assert n_axes in (2, 3, 4, 5), str(n_axes)
        if n_axes in (3, 4, 5):
            n_lead = n_axes - 2
            # All axes in front of the two spatial ones must be singletons.
            assert array.shape[:n_lead] == (1,) * n_lead
            array = array[(0,) * n_lead]
        return array

    assert n_axes in (3, 4, 5), str(n_axes)
    if n_axes == 4:
        assert array.shape[0] == 1, f"{array.shape}"
        return array[0]
    if n_axes == 5:
        assert array.shape[:2] == (1, 1)
        return array[0, 0]
    return array
def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Zero-pad raw data and labels so that they are at least as large as the patch shape.

    Padding is appended at the end of each spatial axis that is smaller than the
    corresponding entry of 'patch_shape'; channel axes are never padded. Inputs that
    are already large enough are returned unchanged.

    Args:
        raw: The input raw data.
        labels: The input labels, or None.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The (padded) raw data if 'labels' is None, otherwise a tuple of the (padded)
        raw data and the (padded) labels.
    """
    def _pad_to_patch(data, has_channels):
        # A leading channel axis is excluded from the comparison (needed e.g. for ImageCollectionDataset).
        spatial_shape = data.shape[1:] if (has_channels and channel_first) else data.shape
        if not any(size < target for size, target in zip(spatial_shape, patch_shape)):
            return data
        widths = [(0, max(0, target - size)) for size, target in zip(spatial_shape, patch_shape)]
        if has_channels:
            widths = [(0, 0)] + widths if channel_first else widths + [(0, 0)]
        return np.pad(array=data, pad_width=widths)

    raw = _pad_to_patch(raw, have_raw_channels)
    if labels is None:
        return raw
    return raw, _pad_to_patch(labels, have_label_channels)
def get_constructor_arguments(obj):
    """@private

    Recover the keyword arguments needed to re-create 'obj' when deserializing a trainer.

    Returns 'obj.init_kwargs' if that attribute exists, an empty dict for optimizers and
    learning rate schedulers (their state is deserialized later), the simple constructor
    arguments for torch DataLoaders, and an empty dict with a warning for anything else.
    """
    # All relevant torch_em classes have 'init_kwargs' to directly recover the init call.
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    lr_scheduler = torch.optim.lr_scheduler
    # Since torch 2.0 the public scheduler base class is 'LRScheduler'; '_LRScheduler' is a
    # deprecated subclass of it that the built-in schedulers no longer inherit from.
    scheduler_base = getattr(lr_scheduler, "LRScheduler", lr_scheduler._LRScheduler)

    # We don't need to find the constructor arguments for optimizers/schedulers because we deserialize the state later.
    if isinstance(
        obj, (
            torch.optim.Optimizer,
            scheduler_base,
            # ReduceLROnPlateau does not inherit from the scheduler base class in older torch versions.
            lr_scheduler.ReduceLROnPlateau,
        )
    ):
        return {}

    # Recover the arguments for torch dataloader.
    if isinstance(obj, torch.utils.data.DataLoader):
        # A plain DataLoader does not store 'shuffle' (only the resulting sampler), so fall back
        # to inferring it from the sampler type when no 'shuffle' attribute was attached.
        shuffle = getattr(obj, "shuffle", None)
        if shuffle is None:
            shuffle = isinstance(obj.sampler, torch.utils.data.RandomSampler)
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em.
        return {
            "batch_size": obj.batch_size,
            "shuffle": shuffle,
            "num_workers": obj.num_workers,
            "pin_memory": obj.pin_memory,
            "drop_last": obj.drop_last,
            "persistent_workers": obj.persistent_workers,
            "prefetch_factor": obj.prefetch_factor,
            "timeout": obj.timeout,
        }

    # TODO support common torch losses (e.g. CrossEntropy, BCE)
    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
        "For this object, empty constructor arguments will be used.\n" +
        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Return a DefaultTrainer, loading it from a checkpoint when given a path.

    Args:
        checkpoint: The path to the checkpoint. Any non-string object is used as the
            trainer directly and must already be a 'torch_em.trainer.DefaultTrainer'.
        name: The name of the checkpoint, passed on to 'DefaultTrainer.from_checkpoint'.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    trainer = checkpoint
    if isinstance(checkpoint, str):
        # Only a path triggers loading from disk.
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer
def get_normalizer(trainer):
    """@private

    Return the raw-data normalizer of the trainer's training dataset.

    Concatenated datasets are unwrapped to their first member and subsets to their
    underlying dataset, repeatedly, until a plain dataset is reached. If that dataset's
    'raw_transform' has a 'normalizer' attribute it is returned, otherwise the raw
    transform itself.
    """
    concat_types = (torch_em.data.ConcatDataset, torch.utils.data.ConcatDataset)
    dataset = trainer.train_loader.dataset
    # Wrappers may be nested in any order (e.g. a concatenation of subsets).
    while True:
        if isinstance(dataset, concat_types):
            dataset = dataset.datasets[0]
        elif isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
        else:
            break

    preprocessor = dataset.raw_transform
    return getattr(preprocessor, "normalizer", preprocessor)
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: str = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Args:
        checkpoint: The path to the checkpoint folder or checkpoint file.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint; used to build '<checkpoint>/<name>.pt' when
            'checkpoint' is a folder.
        state_key: The name of the model state to load.
        device: The device on which to load the model.

    Returns:
        The model.

    Raises:
        KeyError: If the checkpoint does not contain 'state_key'.
    """
    if model is None:  # load the model and its state from the checkpoint
        return get_trainer(checkpoint, name=name, device=device).model

    # Load only the model state from the checkpoint.
    ckpt = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    state = torch.load(ckpt, map_location=device, weights_only=False)[state_key]

    # Models wrapped by torch.compile store their parameters under this prefix;
    # strip it so the state can be loaded into the plain (uncompiled) model.
    compiled_prefix = "_orig_mod."
    state = OrderedDict(
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key, value)
        for key, value in state.items()
    )

    if device is not None:
        model.to(device)
    # load_state_dict copies the values into the existing parameters, so one load suffices.
    model.load_state_dict(state)
    return model
def model_is_equal(model1, model2):
    """@private

    Return True if both models have the same number of parameters and each pair of
    corresponding parameters (in 'parameters()' order) has the same shape and values.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip alone would silently ignore trailing parameters of the larger model.
    if len(params1) != len(params2):
        return False
    return all(torch.equal(p1, p2) for p1, p2 in zip(params1, params2))
def get_random_colors(labels: np.ndarray, seed: Optional[int] = None) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    The color map has one color per unique label value. If the labels contain 0,
    the first color is black and the remaining colors are random.

    Args:
        labels: The labels.
        seed: Optional seed for the random colors. If None, numpy's global random state is used.

    Returns:
        The color map.
    """
    unique_labels = np.unique(labels)
    have_zero = 0 in unique_labels
    # Label 0 already gets black, so it must not get an additional random color.
    n_random = len(unique_labels) - int(have_zero)
    if seed is None:
        random_colors = np.random.rand(n_random, 3)
    else:
        random_colors = np.random.default_rng(seed).random((n_random, 3))
    cmap = ([[0.0, 0.0, 0.0]] if have_zero else []) + random_colors.tolist()
    return colors.ListedColormap(cmap)
