torch_em.util.util

  1import os
  2import warnings
  3from collections import OrderedDict
  4from typing import Optional, Tuple, Union
  5
  6import numpy as np
  7import torch
  8import torch_em
  9from matplotlib import colors
 10from numpy.typing import ArrayLike
 11
 12try:
 13    from torch._dynamo.eval_frame import OptimizedModule
 14except ImportError:
 15    OptimizedModule = None
 16
 17# torch doesn't support most unsigned types,
 18# so we map them to their signed equivalent
 19DTYPE_MAP = {
 20    np.dtype("uint16"): np.int16,
 21    np.dtype("uint32"): np.int32,
 22    np.dtype("uint64"): np.int64
 23}
 24"""@private
 25"""
 26
 27
 28# This is a fairly brittle way to check if a module is compiled.
 29# Would be good to find a better solution, ideally something like model.is_compiled().
 30def is_compiled(model):
 31    """@private
 32    """
 33    if OptimizedModule is None:
 34        return False
 35    return isinstance(model, OptimizedModule)
 36
 37
 38def auto_compile(
 39    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
 40) -> torch.nn.Module:
 41    """Automatically compile a model for pytorch >= 2.
 42
 43    Args:
 44        model: The model.
 45        compile_model: Whether to compile the model.
 46            If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior
 47            specified by 'default_compile' will be used. If a string is given, it will be
 48            interpreted as the compile mode (torch.compile(model, mode=compile_model)).
 49        default_compile: Whether to use the default compilation behavior for torch 2.
 50
 51    Returns:
 52        The compiled model.
 53    """
 54    torch_major = int(torch.__version__.split(".")[0])
 55
 56    if compile_model is None:
 57        if torch_major < 2:
 58            compile_model = False
 59        elif is_compiled(model):  # model is already compiled
 60            compile_model = False
 61        else:
 62            compile_model = default_compile
 63
 64    if compile_model:
 65        if torch_major < 2:
 66            raise RuntimeError("Model compilation is only supported for pytorch 2")
 67
 68        print("Compiling pytorch model ...")
 69        if isinstance(compile_model, str):
 70            model = torch.compile(model, mode=compile_model)
 71        else:
 72            model = torch.compile(model)
 73
 74    return model
 75
 76
 77def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
 78    """Ensure that the input is a torch tensor, by converting it if necessary.
 79
 80    Args:
 81        tensor: The input object, either a torch tensor or a numpy-array like object.
 82        dtype: The required data type for the output tensor.
 83
 84    Returns:
 85        The input, converted to a torch tensor if necessary.
 86    """
 87    if isinstance(tensor, np.ndarray):
 88        if np.dtype(tensor.dtype) in DTYPE_MAP:
 89            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 90        # Try to convert the tensor, even if it has wrong byte-order
 91        try:
 92            tensor = torch.from_numpy(tensor)
 93        except ValueError:
 94            tensor = tensor.view(tensor.dtype.newbyteorder())
 95            if np.dtype(tensor.dtype) in DTYPE_MAP:
 96                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 97            tensor = torch.from_numpy(tensor)
 98
 99    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
100    if dtype is not None:
101        tensor = tensor.to(dtype=dtype)
102    return tensor
103
104
105def ensure_tensor_with_channels(
106    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
107) -> torch.Tensor:
108    """Ensure that the input is a torch tensor of a given dimensionality with channels.
109
110    Args:
111        tensor: The input tensor or numpy-array like data.
112        ndim: The dimensionality of the output tensor.
113        dtype: The data type of the output tensor.
114
115    Returns:
116        The input converted to a torch tensor of the requested dimensionality.
117    """
118    assert ndim in (2, 3, 4), f"{ndim}"
119    tensor = ensure_tensor(tensor, dtype)
120    if ndim == 2:
121        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
122        if tensor.ndim == 2:
123            tensor = tensor[None]
124        elif tensor.ndim == 4:
125            assert tensor.shape[0] == 1, f"{tensor.shape}"
126            tensor = tensor[0]
127        elif tensor.ndim == 5:
128            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
129            tensor = tensor[0, 0]
130    elif ndim == 3:
131        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
132        if tensor.ndim == 3:
133            tensor = tensor[None]
134        elif tensor.ndim == 5:
135            assert tensor.shape[0] == 1, f"{tensor.shape}"
136            tensor = tensor[0]
137    else:
138        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
139        if tensor.ndim == 5:
140            assert tensor.shape[0] == 1, f"{tensor.shape}"
141            tensor = tensor[0]
142    return tensor
143
144
145def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
146    """Ensure that the input is a numpy array, by converting it if necessary.
147
148    Args:
149        array: The input torch tensor or numpy array.
150        dtype: The dtype of the output array.
151
152    Returns:
153        The input converted to a numpy array if necessary.
154    """
155    if torch.is_tensor(array):
156        array = array.detach().cpu().numpy()
157    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
158    if dtype is not None:
159        array = np.require(array, dtype=dtype)
160    return array
161
162
163def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
164    """Ensure that the input is a numpy array of a given dimensionality.
165
166    Args:
167        array: The input numpy array or torch tensor.
168        ndim: The requested dimensionality.
169        dtype: The dtype of the output array.
170
171    Returns:
172        A numpy array of the requested dimensionality and data type.
173    """
174    assert ndim in (2, 3)
175    array = ensure_array(array, dtype)
176    if ndim == 2:
177        assert array.ndim in (2, 3, 4, 5), str(array.ndim)
178        if array.ndim == 3:
179            assert array.shape[0] == 1
180            array = array[0]
181        elif array.ndim == 4:
182            assert array.shape[:2] == (1, 1)
183            array = array[0, 0]
184        elif array.ndim == 5:
185            assert array.shape[:3] == (1, 1, 1)
186            array = array[0, 0, 0]
187    else:
188        assert array.ndim in (3, 4, 5), str(array.ndim)
189        if array.ndim == 4:
190            assert array.shape[0] == 1, f"{array.shape}"
191            array = array[0]
192        elif array.ndim == 5:
193            assert array.shape[:2] == (1, 1)
194            array = array[0, 0]
195    return array
196
197
198def ensure_patch_shape(
199    raw: np.ndarray,
200    labels: Optional[np.ndarray],
201    patch_shape: Tuple[int, ...],
202    have_raw_channels: bool = False,
203    have_label_channels: bool = False,
204    channel_first: bool = True,
205) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
206    """Ensure that the raw data and labels have at least the requested patch shape.
207
208    If either raw data or labels do not have the patch shape they will be padded.
209
210    Args:
211        raw: The input raw data.
212        labels: The input labels.
213        patch_shape: The required minimal patch shape.
214        have_raw_channels: Whether the raw data has channels.
215        have_label_channels: Whether the label data has channels.
216        channel_first: Whether the channel axis is the first or last axis.
217
218    Returns:
219        The raw data.
220        The labels.
221    """
222    raw_shape = raw.shape
223    if labels is not None:
224        labels_shape = labels.shape
225
226    # In case the inputs have channels and they are channel-first.
227    # IMPORTANT: for ImageCollectionDataset
228    if have_raw_channels and channel_first:
229        raw_shape = raw_shape[1:]
230
231    if have_label_channels and channel_first and labels is not None:
232        labels_shape = labels_shape[1:]
233
234    # Extract the pad_width and pad the raw inputs
235    if any(sh < psh for sh, psh in zip(raw_shape, patch_shape)):
236        pw = [(0, max(0, psh - sh)) for sh, psh in zip(raw_shape, patch_shape)]
237
238        if have_raw_channels and channel_first:
239            pad_width = [(0, 0), *pw]
240        elif have_raw_channels and not channel_first:
241            pad_width = [*pw, (0, 0)]
242        else:
243            pad_width = pw
244
245        raw = np.pad(array=raw, pad_width=pad_width)
246
247    # Extract the pad width and pad the label inputs
248    if labels is not None and any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
249        pw = [(0, max(0, psh - sh)) for sh, psh in zip(labels_shape, patch_shape)]
250
251        if have_label_channels and channel_first:
252            pad_width = [(0, 0), *pw]
253        elif have_label_channels and not channel_first:
254            pad_width = [*pw, (0, 0)]
255        else:
256            pad_width = pw
257
258        labels = np.pad(array=labels, pad_width=pad_width)
259    if labels is None:
260        return raw
261    else:
262        return raw, labels
263
264
265def get_constructor_arguments(obj):
266    """@private
267    """
268    # All relevant torch_em classes have 'init_kwargs' to directly recover the init call.
269    if hasattr(obj, "init_kwargs"):
270        return getattr(obj, "init_kwargs")
271
272    def _get_args(obj, param_names):
273        return {name: getattr(obj, name) for name in param_names}
274
275    # We don't need to find the constructor arguments for optimizers/schedulers because we deserialize the state later.
276    if isinstance(
277        obj, (
278            torch.optim.Optimizer,
279            torch.optim.lr_scheduler._LRScheduler,
280            # ReduceLROnPlateau does not inherit from _LRScheduler
281            torch.optim.lr_scheduler.ReduceLROnPlateau
282        )
283    ):
284        return {}
285
286    # recover the arguments for torch dataloader
287    elif isinstance(obj, torch.utils.data.DataLoader):
288        # These are all the "simple" arguments.
289        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
290        # and generally not used in torch_em
291        return _get_args(
292            obj, [
293                "batch_size", "shuffle", "num_workers", "pin_memory", "drop_last",
294                "persistent_workers", "prefetch_factor", "timeout"
295            ]
296        )
297
298    # TODO support common torch losses (e.g. CrossEntropy, BCE)
299    warnings.warn(
300        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
301        "For this object, empty constructor arguments will be used.\n" +
302        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
303    )
304    return {}
305
306
307def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
308    """Load trainer from a checkpoint.
309
310    Args:
311        checkpoint: The path to the checkpoint.
312        name: The name of the checkpoint.
313        device: The device to use for loading the checkpoint.
314
315    Returns:
316        The trainer.
317    """
318    # try to load from file
319    if isinstance(checkpoint, str):
320        assert os.path.exists(checkpoint), checkpoint
321        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
322    else:
323        trainer = checkpoint
324    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
325    return trainer
326
327
328def get_normalizer(trainer):
329    """@private
330    """
331    dataset = trainer.train_loader.dataset
332    while (
333        isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset) or
334        isinstance(dataset, torch.utils.data.dataset.ConcatDataset)
335    ):
336        dataset = dataset.datasets[0]
337
338    if isinstance(dataset, torch.utils.data.dataset.Subset):
339        dataset = dataset.dataset
340
341    preprocessor = dataset.raw_transform
342
343    if hasattr(preprocessor, "normalizer"):
344        return preprocessor.normalizer
345    else:
346        return preprocessor
347
348
349def load_model(
350    checkpoint: str,
351    model: Optional[torch.nn.Module] = None,
352    name: str = "best",
353    state_key: Optional[str] = "model_state",
354    device: Optional[str] = None,
355) -> torch.nn.Module:
356    """Load a model from a trainer checkpoint or a serialized torch model.
357
358    This function can either load the model directly (`model` is not passed),
359    or deserialize the model state and then load it (`model` is passed).
360
361    The `checkpoint` argument must either point to the checkpoint directory of a torch_em trainer
362    or to a serialized torch model.
363
364    Args:
365        checkpoint: The path to the checkpoint folder or serialized torch model.
366        model: The model for which the state should be loaded.
367            If it is not passed, the model class and parameters will also be loaded from the trainer.
368        name: The name of the checkpoint.
369        state_key: The name of the model state to load. Set to None if the model state is stored top-level.
370        device: The device on which to load the model.
371
372    Returns:
373        The model.
374    """
375    if model is None and os.path.isdir(checkpoint):  # Load the model and its state from a torch_em checkpoint.
376        model = get_trainer(checkpoint, name=name, device=device).model
377
378    elif model is None:  # Load the model from a serialized model.
379        model = torch.load(checkpoint, map_location=device, weights_only=False)
380
381    else:  # Load the model state from a checkpoint.
382        if os.path.isdir(checkpoint):  # From a torch_em checkpoint.
383            ckpt = os.path.join(checkpoint, f"{name}.pt")
384        else:  # From a serialized path.
385            ckpt = checkpoint
386
387        state = torch.load(ckpt, map_location=device, weights_only=False)
388        if state_key is not None:
389            state = state[state_key]
390
391        # To enable loading compiled models.
392        compiled_prefix = "_orig_mod."
393        state = OrderedDict(
394            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()]
395        )
396
397        model.load_state_dict(state)
398        if device is not None:
399            model.to(device)
400
401    return model
402
403
404def model_is_equal(model1, model2):
405    """@private
406    """
407    for p1, p2 in zip(model1.parameters(), model2.parameters()):
408        if p1.data.ne(p2.data).sum() > 0:
409            return False
410    return True
411
412
413def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
414    """Generate a random color map for a label image.
415
416    Args:
417        labels: The labels.
418
419    Returns:
420        The color map.
421    """
422    unique_labels = np.unique(labels)
423    have_zero = 0 in unique_labels
424    cmap = [[0, 0, 0]] if have_zero else []
425    cmap += np.random.rand(len(unique_labels), 3).tolist()
426    cmap = colors.ListedColormap(cmap)
427    return cmap
def auto_compile( model: torch.nn.modules.module.Module, compile_model: Union[str, bool, NoneType] = None, default_compile: bool = True) -> torch.nn.modules.module.Module:

Automatically compile a model for pytorch >= 2.

Arguments:
  • model: The model.
  • compile_model: Whether to compile the model. If None, it will not be compiled for torch < 2, and for torch >= 2 the behavior specified by 'default_compile' will be used. If a string is given, it will be interpreted as the compile mode (torch.compile(model, mode=compile_model)).
  • default_compile: Whether to use the default compilation behavior for torch 2.
Returns:

The compiled model.
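
A minimal usage sketch (not part of the module source): wrap a model with auto_compile and optionally select a torch.compile mode. The small convolutional model used here is just a stand-in.

    import torch
    from torch_em.util.util import auto_compile

    model = torch.nn.Conv2d(1, 8, kernel_size=3, padding=1)

    # Default behavior: compile the model if torch >= 2 and it is not compiled yet.
    model = auto_compile(model)

    # Passing a string selects the compile mode, i.e. torch.compile(model, mode="reduce-overhead").
    # model = auto_compile(model, compile_model="reduce-overhead")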

def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:

Ensure that the input is a torch tensor, by converting it if necessary.

Arguments:
  • tensor: The input object, either a torch tensor or a numpy-array like object.
  • dtype: The required data type for the output tensor.
Returns:

The input, converted to a torch tensor if necessary.
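
A minimal sketch (not part of the module source) of the dtype handling: unsigned integer arrays are mapped to their signed counterparts via DTYPE_MAP before conversion, since torch does not support most unsigned types.

    import numpy as np
    from torch_em.util.util import ensure_tensor

    data = np.arange(12, dtype="uint16").reshape(3, 4)
    tensor = ensure_tensor(data)
    print(tensor.dtype)  # torch.int16, because uint16 is mapped to int16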

def ensure_tensor_with_channels(tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None) -> torch.Tensor:

Ensure that the input is a torch tensor of a given dimensionality with channels.

Arguments:
  • tensor: The input tensor or numpy-array like data.
  • ndim: The dimensionality of the output tensor.
  • dtype: The data type of the output tensor.
Returns:

The input converted to a torch tensor of the requested dimensionality.
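
A minimal sketch (not part of the module source): a 2d image without a channel axis gets a singleton channel prepended when ndim=2 is requested.

    import numpy as np
    from torch_em.util.util import ensure_tensor_with_channels

    image = np.random.rand(64, 64).astype("float32")
    tensor = ensure_tensor_with_channels(image, ndim=2)
    print(tensor.shape)  # torch.Size([1, 64, 64])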

def ensure_array( array: Union[numpy.ndarray, torch.Tensor], dtype: str = None) -> numpy.ndarray:

Ensure that the input is a numpy array, by converting it if necessary.

Arguments:
  • array: The input torch tensor or numpy array.
  • dtype: The dtype of the output array.
Returns:

The input converted to a numpy array if necessary.
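
A minimal sketch (not part of the module source): a torch tensor that requires grad is detached, moved to the CPU, and converted to a numpy array of the requested dtype.

    import torch
    from torch_em.util.util import ensure_array

    tensor = torch.ones(2, 3, requires_grad=True)
    array = ensure_array(tensor, dtype="float64")
    print(type(array), array.dtype)  # <class 'numpy.ndarray'> float64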

def ensure_spatial_array( array: Union[numpy.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> numpy.ndarray:

Ensure that the input is a numpy array of a given dimensionality.

Arguments:
  • array: The input numpy array or torch tensor.
  • ndim: The requested dimensionality.
  • dtype: The dtype of the output array.
Returns:

A numpy array of the requested dimensionality and data type.
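
A minimal sketch (not part of the module source): singleton batch and channel axes are stripped so that only the requested spatial dimensions remain.

    import numpy as np
    from torch_em.util.util import ensure_spatial_array

    batched = np.zeros((1, 1, 128, 128))  # (batch, channel, y, x) with singleton leading axes
    image = ensure_spatial_array(batched, ndim=2)
    print(image.shape)  # (128, 128)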

def ensure_patch_shape( raw: numpy.ndarray, labels: Optional[numpy.ndarray], patch_shape: Tuple[int, ...], have_raw_channels: bool = False, have_label_channels: bool = False, channel_first: bool = True) -> Union[Tuple[numpy.ndarray, numpy.ndarray], numpy.ndarray]:

Ensure that the raw data and labels have at least the requested patch shape.

If either raw data or labels do not have the patch shape they will be padded.

Arguments:
  • raw: The input raw data.
  • labels: The input labels.
  • patch_shape: The required minimal patch shape.
  • have_raw_channels: Whether the raw data has channels.
  • have_label_channels: Whether the label data has channels.
  • channel_first: Whether the channel axis is the first or last axis.
Returns:

The raw data. The labels (only returned if labels were passed).
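
A minimal sketch (not part of the module source): raw data and labels that are smaller than the patch shape are zero-padded at the end of each axis until they fit.

    import numpy as np
    from torch_em.util.util import ensure_patch_shape

    raw = np.random.rand(200, 180)
    labels = np.zeros((200, 180), dtype="int64")
    raw, labels = ensure_patch_shape(raw, labels, patch_shape=(256, 256))
    print(raw.shape, labels.shape)  # (256, 256) (256, 256)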

def get_trainer(checkpoint: str, name: str = 'best', device: Optional[str] = None):

Load trainer from a checkpoint.

Arguments:
  • checkpoint: The path to the checkpoint.
  • name: The name of the checkpoint.
  • device: The device to use for loading the checkpoint.
Returns:

The trainer.
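
A hypothetical usage sketch (not part of the module source); "./checkpoints/my-model" stands in for a checkpoint directory written by a torch_em DefaultTrainer.

    from torch_em.util.util import get_trainer

    trainer = get_trainer("./checkpoints/my-model", name="best", device="cpu")
    model = trainer.model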

def load_model( checkpoint: str, model: Optional[torch.nn.modules.module.Module] = None, name: str = 'best', state_key: Optional[str] = 'model_state', device: Optional[str] = None) -> torch.nn.modules.module.Module:

Load a model from a trainer checkpoint or a serialized torch model.

This function can either load the model directly (model is not passed), or deserialize the model state and then load it (model is passed).

The checkpoint argument must either point to the checkpoint directory of a torch_em trainer or to a serialized torch model.

Arguments:
  • checkpoint: The path to the checkpoint folder or serialized torch model.
  • model: The model for which the state should be loaded. If it is not passed, the model class and parameters will also be loaded from the trainer.
  • name: The name of the checkpoint.
  • state_key: The name of the model state to load. Set to None if the model state is stored top-level.
  • device: The device on which to load the model.
Returns:

The model.
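
A hypothetical usage sketch (not part of the module source); the checkpoint path is a placeholder for a checkpoint folder written by a torch_em trainer, and my_unet stands in for an existing model instance.

    from torch_em.util.util import load_model

    # Load the model class, parameters and state from the trainer checkpoint.
    model = load_model("./checkpoints/my-model", device="cpu")

    # Alternatively, load only the state into an existing model instance.
    # model = load_model("./checkpoints/my-model", model=my_unet, device="cpu")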

def get_random_colors(labels: numpy.ndarray) -> matplotlib.colors.ListedColormap:

Generate a random color map for a label image.

Arguments:
  • labels: The labels.
Returns:

The color map.
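
A minimal sketch (not part of the module source): display a label image with a random color per label id, keeping label 0 (background) black. Requires matplotlib.

    import numpy as np
    import matplotlib.pyplot as plt
    from torch_em.util.util import get_random_colors

    labels = np.random.randint(0, 5, size=(64, 64))
    plt.imshow(labels, cmap=get_random_colors(labels), interpolation="nearest")
    plt.show()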