torch_em.util.util

  1import os
  2import warnings
  3from collections import OrderedDict
  4from typing import Optional, Tuple, Union
  5
  6import numpy as np
  7import torch
  8import torch_em
  9from matplotlib import colors
 10from numpy.typing import ArrayLike
 11
 12try:
 13    from torch._dynamo.eval_frame import OptimizedModule
 14except ImportError:
 15    OptimizedModule = None
 16
 17# torch doesn't support most unsigned types,
 18# so we map them to their signed equivalent
 19DTYPE_MAP = {
 20    np.dtype("uint16"): np.int16,
 21    np.dtype("uint32"): np.int32,
 22    np.dtype("uint64"): np.int64
 23}
 24"""@private
 25"""
 26
 27
 28# This is a fairly brittle way to check if a module is compiled.
 29# Would be good to find a better solution, ideall something like model.is_compiled().
 30def is_compiled(model):
 31    """@private
 32    """
 33    if OptimizedModule is None:
 34        return False
 35    return isinstance(model, OptimizedModule)
 36
 37
 38def auto_compile(
 39    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
 40) -> torch.nn.Module:
 41    """Automatically compile a model for pytorch >= 2.
 42
 43    Args:
 44        model: The model.
 45        compile_model: Whether to comile the model.
 46            If None, it will not be compiled for torch < 2, and for torch > 2 the behavior
 47            specificed by 'default_compile' will be used. If a string is given it will be
 48            intepreted as the compile mode (torch.compile(model, mode=compile_model))
 49        default_compile: Whether to use the default compilation behavior for torch 2.
 50
 51    Returns:
 52        The compiled model.
 53    """
 54    torch_major = int(torch.__version__.split(".")[0])
 55
 56    if compile_model is None:
 57        if torch_major < 2:
 58            compile_model = False
 59        elif is_compiled(model):  # model is already compiled
 60            compile_model = False
 61        else:
 62            compile_model = default_compile
 63
 64    if compile_model:
 65        if torch_major < 2:
 66            raise RuntimeError("Model compilation is only supported for pytorch 2")
 67
 68        print("Compiling pytorch model ...")
 69        if isinstance(compile_model, str):
 70            model = torch.compile(model, mode=compile_model)
 71        else:
 72            model = torch.compile(model)
 73
 74    return model
 75
 76
 77def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
 78    """Ensure that the input is a torch tensor, by converting it if necessary.
 79
 80    Args:
 81        tensor: The input object, either a torch tensor or a numpy-array like object.
 82        dtype: The required data type for the output tensor.
 83
 84    Returns:
 85        The input, converted to a torch tensor if necessary.
 86    """
 87    if isinstance(tensor, np.ndarray):
 88        if np.dtype(tensor.dtype) in DTYPE_MAP:
 89            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 90        # Try to convert the tensor, even if it has wrong byte-order
 91        try:
 92            tensor = torch.from_numpy(tensor)
 93        except ValueError:
 94            tensor = tensor.view(tensor.dtype.newbyteorder())
 95            if np.dtype(tensor.dtype) in DTYPE_MAP:
 96                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 97            tensor = torch.from_numpy(tensor)
 98
 99    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
100    if dtype is not None:
101        tensor = tensor.to(dtype=dtype)
102    return tensor
103
104
def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Ensure that the input is a torch tensor of a given dimensionality with channels.

    Inputs without a channel axis get a leading singleton channel axis; inputs with extra leading
    singleton axes have them removed (an AssertionError is raised if those axes are not singletons).

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor. One of 2, 3 or 4.
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    assert tensor.ndim in range(ndim, 6), f"{tensor.ndim}"

    # Rank of the returned tensor: ndim plus a channel axis for ndim 2 and 3.
    # NOTE(review): for ndim == 4 the result stays 4D (no channel axis is added) — confirm this is intended.
    target_ndim = ndim if ndim == 4 else ndim + 1
    n_extra = tensor.ndim - target_ndim

    if n_extra < 0:
        # No channel axis yet: prepend a singleton one.
        tensor = tensor[None]
    elif n_extra == 1:
        assert tensor.shape[0] == 1, f"{tensor.shape}"
        tensor = tensor[0]
    elif n_extra == 2:
        assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
        tensor = tensor[0, 0]
    return tensor
143
144
def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array, by converting it if necessary.

    Torch tensors are detached and moved to the CPU before conversion.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array. If None, the dtype is left unchanged.

    Returns:
        The input converted to a numpy array if necessary.
    """
    converted = array.detach().cpu().numpy() if torch.is_tensor(array) else array
    assert isinstance(converted, np.ndarray), f"Cannot convert {type(converted)} to numpy"
    if dtype is None:
        return converted
    return np.require(converted, dtype=dtype)
161
162
def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Leading singleton axes (e.g. batch / channel axes of size 1) are removed until the array has `ndim` axes;
    an AssertionError is raised if any of the removed axes is not a singleton.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality (2 or 3).
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    assert array.ndim in tuple(range(ndim, 6)), str(array.ndim)

    n_singleton = array.ndim - ndim
    if n_singleton == 0:
        return array

    if ndim == 3 and n_singleton == 1:
        assert array.shape[0] == 1, f"{array.shape}"
    else:
        assert array.shape[:n_singleton] == (1,) * n_singleton
    # Index away all leading singleton axes at once.
    return array[(0,) * n_singleton]
196
197
def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels are smaller than the patch shape along a spatial axis, they are
    zero-padded at the end of that axis. The channel axis (if present) is never padded.

    Args:
        raw: The input raw data.
        labels: The input labels. May be None.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The raw data.
        The labels. Only returned if `labels` is not None; otherwise just the raw data is returned.
    """

    def _pad_to_patch_shape(data, has_channels):
        # Spatial shape: drop a leading channel axis. A trailing channel axis is left out
        # automatically because zip stops after len(patch_shape) axes.
        # (Relevant e.g. for ImageCollectionDataset, which has channels-first inputs.)
        spatial_shape = data.shape[1:] if (has_channels and channel_first) else data.shape
        if not any(size < min_size for size, min_size in zip(spatial_shape, patch_shape)):
            return data

        spatial_pad = [(0, max(0, min_size - size)) for size, min_size in zip(spatial_shape, patch_shape)]
        if not has_channels:
            full_pad = spatial_pad
        elif channel_first:
            full_pad = [(0, 0)] + spatial_pad
        else:
            full_pad = spatial_pad + [(0, 0)]
        return np.pad(array=data, pad_width=full_pad)

    raw = _pad_to_patch_shape(raw, have_raw_channels)
    if labels is None:
        return raw
    return raw, _pad_to_patch_shape(labels, have_label_channels)
263
264
def get_constructor_arguments(obj):
    """@private

    Recover the constructor arguments of `obj` so that it can be re-created when deserializing a trainer.

    Returns `obj.init_kwargs` if the object has it, an empty dict for optimizers and learning rate
    schedulers, the simple constructor arguments for a torch DataLoader, and an empty dict
    (with a warning) for any other object.
    """
    # All relevant torch_em classes have 'init_kwargs' to directly recover the init call.
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    lr_scheduler = torch.optim.lr_scheduler
    # In torch >= 2.0 the built-in schedulers derive from the public 'LRScheduler' base class, and
    # '_LRScheduler' is only a backwards-compatibility subclass of it, so checking '_LRScheduler'
    # alone misses them. Fall back to '_LRScheduler' for older torch versions.
    scheduler_base = getattr(lr_scheduler, "LRScheduler", lr_scheduler._LRScheduler)

    # We don't need to find the constructor arguments for optimizers/schedulers because we deserialize the state later.
    if isinstance(
        obj, (
            torch.optim.Optimizer,
            scheduler_base,
            # ReduceLROnPlateau does not inherit from the scheduler base class in older torch versions.
            lr_scheduler.ReduceLROnPlateau,
        )
    ):
        return {}

    # Recover the arguments for a torch dataloader.
    if isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em.
        # torch's DataLoader does not store 'shuffle' as an attribute. Use it if the loader carries it
        # as an extra attribute; otherwise infer it from the RandomSampler that DataLoader creates for shuffle=True.
        if hasattr(obj, "shuffle"):
            shuffle = obj.shuffle
        else:
            shuffle = isinstance(obj.sampler, torch.utils.data.RandomSampler)
        return {
            "batch_size": obj.batch_size,
            "shuffle": shuffle,
            "num_workers": obj.num_workers,
            "pin_memory": obj.pin_memory,
            "drop_last": obj.drop_last,
            "persistent_workers": obj.persistent_workers,
            "prefetch_factor": obj.prefetch_factor,
            "timeout": obj.timeout,
        }

    # TODO support common torch losses (e.g. CrossEntropy, BCE)
    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
        "For this object, empty constructor arguments will be used.\n" +
        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}
305
306
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint. An object that is not a string
            (e.g. an already loaded trainer) is used as the trainer directly.
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    if not isinstance(checkpoint, str):
        # Not a path: treat the input as the trainer itself.
        trainer = checkpoint
    else:
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer
326
327
def get_normalizer(trainer):
    """@private

    Return the raw-data normalizer used by the trainer's training dataset.

    Concat datasets are unwrapped to their first member and a Subset to its underlying dataset.
    If the dataset's `raw_transform` has a `normalizer` attribute it is returned, otherwise the
    `raw_transform` itself.
    """
    dataset = trainer.train_loader.dataset
    concat_types = (torch_em.data.concat_dataset.ConcatDataset, torch.utils.data.dataset.ConcatDataset)
    # Descend into (possibly nested) concat datasets; only the first member is inspected.
    while isinstance(dataset, concat_types):
        dataset = dataset.datasets[0]

    if isinstance(dataset, torch.utils.data.dataset.Subset):
        dataset = dataset.dataset

    preprocessor = dataset.raw_transform
    return getattr(preprocessor, "normalizer", preprocessor)
348
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: str = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Args:
        checkpoint: The path to the checkpoint folder. If a model is passed, this may also be
            the path to the checkpoint file itself.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint. If `checkpoint` is a folder, '{name}.pt' is loaded from it.
        state_key: The name of the model state to load.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None:  # load the model and its state from the checkpoint
        return get_trainer(checkpoint, name=name, device=device).model

    # Load only the model state from the checkpoint.
    ckpt = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint

    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints from trusted sources.
    state = torch.load(ckpt, map_location=device, weights_only=False)[state_key]

    # Models wrapped by torch.compile store their weights under this prefix; strip it to enable loading compiled models.
    compiled_prefix = "_orig_mod."
    state = OrderedDict(
        (k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()
    )

    # Move first, then load the state once (it was previously loaded twice, before and after the move).
    if device is not None:
        model.to(device)
    model.load_state_dict(state)
    return model
393
394
def model_is_equal(model1, model2):
    """@private

    Return True if both models have the same number of parameters and every pair of
    corresponding parameters has the same shape and identical values.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip() would silently stop at the shorter list, so compare the counts explicitly.
    if len(params1) != len(params2):
        return False
    for p1, p2 in zip(params1, params2):
        # Check shapes first; an element-wise comparison of different shapes would broadcast or raise.
        if p1.shape != p2.shape or p1.data.ne(p2.data).any():
            return False
    return True
402
403
def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    One random RGB color is drawn per unique label value. If the label 0 is present,
    black is prepended to the color list.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    label_ids = np.unique(labels)
    random_rgb = np.random.rand(len(label_ids), 3).tolist()
    palette = [[0, 0, 0]] + random_rgb if 0 in label_ids else random_rgb
    return colors.ListedColormap(palette)
def auto_compile( model: torch.nn.modules.module.Module, compile_model: Union[str, bool, NoneType] = None, default_compile: bool = True) -> torch.nn.modules.module.Module:
def auto_compile(
    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
) -> torch.nn.Module:
    """Automatically compile a model for pytorch >= 2.

    Args:
        model: The model.
        compile_model: Whether to compile the model.
            If None, the model is not compiled for torch < 2 or if it is already compiled;
            otherwise (torch >= 2) the behavior specified by 'default_compile' is used.
            If a string is given it is interpreted as the compile mode (torch.compile(model, mode=compile_model)).
        default_compile: Whether to compile the model for torch >= 2 when 'compile_model' is None.

    Returns:
        The compiled model, or the unchanged model if no compilation was done.

    Raises:
        RuntimeError: If compilation is requested explicitly for torch < 2.
    """
    torch_major = int(torch.__version__.split(".")[0])
    supports_compile = torch_major >= 2

    if compile_model is None:
        # Skip compilation on torch 1.x and for models that torch.compile has already wrapped.
        compile_model = default_compile if (supports_compile and not is_compiled(model)) else False

    if not compile_model:
        return model

    if not supports_compile:
        raise RuntimeError("Model compilation is only supported for pytorch 2")

    print("Compiling pytorch model ...")
    # A string selects the torch.compile mode; any other truthy value uses the default mode.
    mode_kwargs = {"mode": compile_model} if isinstance(compile_model, str) else {}
    return torch.compile(model, **mode_kwargs)

Automatically compile a model for pytorch >= 2.

Arguments:
  • model: The model.
  • compile_model: Whether to compile the model. If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior specified by 'default_compile' will be used. If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model)).
  • default_compile: Whether to use the default compilation behavior for torch 2.
Returns:

The compiled model.

def ensure_tensor( tensor: Union[torch.Tensor, Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], dtype: Optional[str] = None) -> torch.Tensor:
 78def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
 79    """Ensure that the input is a torch tensor, by converting it if necessary.
 80
 81    Args:
 82        tensor: The input object, either a torch tensor or a numpy-array like object.
 83        dtype: The required data type for the output tensor.
 84
 85    Returns:
 86        The input, converted to a torch tensor if necessary.
 87    """
 88    if isinstance(tensor, np.ndarray):
 89        if np.dtype(tensor.dtype) in DTYPE_MAP:
 90            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 91        # Try to convert the tensor, even if it has wrong byte-order
 92        try:
 93            tensor = torch.from_numpy(tensor)
 94        except ValueError:
 95            tensor = tensor.view(tensor.dtype.newbyteorder())
 96            if np.dtype(tensor.dtype) in DTYPE_MAP:
 97                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 98            tensor = torch.from_numpy(tensor)
 99
100    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
101    if dtype is not None:
102        tensor = tensor.to(dtype=dtype)
103    return tensor

Ensure that the input is a torch tensor, by converting it if necessary.

Arguments:
  • tensor: The input object, either a torch tensor or a numpy-array like object.
  • dtype: The required data type for the output tensor.
Returns:

The input, converted to a torch tensor if necessary.

def ensure_tensor_with_channels( tensor: Union[torch.Tensor, Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], ndim: int, dtype: Optional[str] = None) -> torch.Tensor:
def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Ensure that the input is a torch tensor of a given dimensionality with channels.

    Inputs without a channel axis get a leading singleton channel axis; inputs with extra leading
    singleton axes have them removed (an AssertionError is raised if those axes are not singletons).

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor. One of 2, 3 or 4.
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    assert tensor.ndim in range(ndim, 6), f"{tensor.ndim}"

    # Rank of the returned tensor: ndim plus a channel axis for ndim 2 and 3.
    # NOTE(review): for ndim == 4 the result stays 4D (no channel axis is added) — confirm this is intended.
    target_ndim = ndim if ndim == 4 else ndim + 1
    n_extra = tensor.ndim - target_ndim

    if n_extra < 0:
        # No channel axis yet: prepend a singleton one.
        tensor = tensor[None]
    elif n_extra == 1:
        assert tensor.shape[0] == 1, f"{tensor.shape}"
        tensor = tensor[0]
    elif n_extra == 2:
        assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
        tensor = tensor[0, 0]
    return tensor

Ensure that the input is a torch tensor of a given dimensionality with channels.

Arguments:
  • tensor: The input tensor or numpy-array like data.
  • ndim: The dimensionality of the output tensor.
  • dtype: The data type of the output tensor.
Returns:

The input converted to a torch tensor of the requested dimensionality.

def ensure_array( array: Union[numpy.ndarray, torch.Tensor], dtype: str = None) -> numpy.ndarray:
def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array, by converting it if necessary.

    Torch tensors are detached and moved to the CPU before conversion.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array. If None, the dtype is left unchanged.

    Returns:
        The input converted to a numpy array if necessary.
    """
    converted = array.detach().cpu().numpy() if torch.is_tensor(array) else array
    assert isinstance(converted, np.ndarray), f"Cannot convert {type(converted)} to numpy"
    if dtype is None:
        return converted
    return np.require(converted, dtype=dtype)

Ensure that the input is a numpy array, by converting it if necessary.

Arguments:
  • array: The input torch tensor or numpy array.
  • dtype: The dtype of the output array.
Returns:

The input converted to a numpy array if necessary.

def ensure_spatial_array( array: Union[numpy.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> numpy.ndarray:
def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Leading singleton axes (e.g. batch / channel axes of size 1) are removed until the array has `ndim` axes;
    an AssertionError is raised if any of the removed axes is not a singleton.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality (2 or 3).
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    assert array.ndim in tuple(range(ndim, 6)), str(array.ndim)

    n_singleton = array.ndim - ndim
    if n_singleton == 0:
        return array

    if ndim == 3 and n_singleton == 1:
        assert array.shape[0] == 1, f"{array.shape}"
    else:
        assert array.shape[:n_singleton] == (1,) * n_singleton
    # Index away all leading singleton axes at once.
    return array[(0,) * n_singleton]

Ensure that the input is a numpy array of a given dimensionality.

Arguments:
  • array: The input numpy array or torch tensor.
  • ndim: The requested dimensionality.
  • dtype: The dtype of the output array.
Returns:

A numpy array of the requested dimensionality and data type.

def ensure_patch_shape( raw: numpy.ndarray, labels: Optional[numpy.ndarray], patch_shape: Tuple[int, ...], have_raw_channels: bool = False, have_label_channels: bool = False, channel_first: bool = True) -> Union[Tuple[numpy.ndarray, numpy.ndarray], numpy.ndarray]:
def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels are smaller than the patch shape along a spatial axis, they are
    zero-padded at the end of that axis. The channel axis (if present) is never padded.

    Args:
        raw: The input raw data.
        labels: The input labels. May be None.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The raw data.
        The labels. Only returned if `labels` is not None; otherwise just the raw data is returned.
    """

    def _pad_to_patch_shape(data, has_channels):
        # Spatial shape: drop a leading channel axis. A trailing channel axis is left out
        # automatically because zip stops after len(patch_shape) axes.
        # (Relevant e.g. for ImageCollectionDataset, which has channels-first inputs.)
        spatial_shape = data.shape[1:] if (has_channels and channel_first) else data.shape
        if not any(size < min_size for size, min_size in zip(spatial_shape, patch_shape)):
            return data

        spatial_pad = [(0, max(0, min_size - size)) for size, min_size in zip(spatial_shape, patch_shape)]
        if not has_channels:
            full_pad = spatial_pad
        elif channel_first:
            full_pad = [(0, 0)] + spatial_pad
        else:
            full_pad = spatial_pad + [(0, 0)]
        return np.pad(array=data, pad_width=full_pad)

    raw = _pad_to_patch_shape(raw, have_raw_channels)
    if labels is None:
        return raw
    return raw, _pad_to_patch_shape(labels, have_label_channels)

Ensure that the raw data and labels have at least the requested patch shape.

If either raw data or labels do not have the patch shape they will be padded.

Arguments:
  • raw: The input raw data.
  • labels: The input labels.
  • patch_shape: The required minimal patch shape.
  • have_raw_channels: Whether the raw data has channels.
  • have_label_channels: Whether the label data has channels.
  • channel_first: Whether the channel axis is the first or last axis.
Returns:

The raw data and the labels; if labels is None, only the raw data is returned.

def get_trainer(checkpoint: str, name: str = 'best', device: Optional[str] = None):
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint. An object that is not a string
            (e.g. an already loaded trainer) is used as the trainer directly.
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    if not isinstance(checkpoint, str):
        # Not a path: treat the input as the trainer itself.
        trainer = checkpoint
    else:
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer

Load trainer from a checkpoint.

Arguments:
  • checkpoint: The path to the checkpoint.
  • name: The name of the checkpoint.
  • device: The device to use for loading the checkpoint.
Returns:

The trainer.

def load_model( checkpoint: str, model: Optional[torch.nn.modules.module.Module] = None, name: str = 'best', state_key: str = 'model_state', device: Optional[str] = None) -> torch.nn.modules.module.Module:
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: str = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load model from a trainer checkpoint.

    This function can either load the model directly from the trainer (model is not passed),
    or deserialize the model state from the trainer and load the model state (model is passed).

    Args:
        checkpoint: The path to the checkpoint folder. If a model is passed, this may also be
            the path to the checkpoint file itself.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint. If `checkpoint` is a folder, '{name}.pt' is loaded from it.
        state_key: The name of the model state to load.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None:  # load the model and its state from the checkpoint
        return get_trainer(checkpoint, name=name, device=device).model

    # Load only the model state from the checkpoint.
    ckpt = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint

    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints from trusted sources.
    state = torch.load(ckpt, map_location=device, weights_only=False)[state_key]

    # Models wrapped by torch.compile store their weights under this prefix; strip it to enable loading compiled models.
    compiled_prefix = "_orig_mod."
    state = OrderedDict(
        (k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()
    )

    # Move first, then load the state once (it was previously loaded twice, before and after the move).
    if device is not None:
        model.to(device)
    model.load_state_dict(state)
    return model

Load model from a trainer checkpoint.

This function can either load the model directly from the trainer (model is not passed), or deserialize the model state from the trainer and load the model state (model is passed).

Arguments:
  • checkpoint: The path to the checkpoint folder.
  • model: The model for which the state should be loaded. If it is not passed, the model class and parameters will also be loaded from the trainer.
  • name: The name of the checkpoint.
  • state_key: The name of the model state to load.
  • device: The device on which to load the model.
Returns:

The model.

def get_random_colors(labels: numpy.ndarray) -> matplotlib.colors.ListedColormap:
def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    One random RGB color is drawn per unique label value. If the label 0 is present,
    black is prepended to the color list.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    label_ids = np.unique(labels)
    random_rgb = np.random.rand(len(label_ids), 3).tolist()
    palette = [[0, 0, 0]] + random_rgb if 0 in label_ids else random_rgb
    return colors.ListedColormap(palette)

Generate a random color map for a label image.

Arguments:
  • labels: The labels.
Returns:

The color map.