torch_em.util.util

  1import os
  2import warnings
  3from collections import OrderedDict
  4from typing import Optional, Tuple, Union
  5
  6import numpy as np
  7import torch
  8import torch_em
  9from matplotlib import colors
 10from numpy.typing import ArrayLike
 11
 12try:
 13    from torch._dynamo.eval_frame import OptimizedModule
 14except ImportError:
 15    OptimizedModule = None
 16
 17# torch doesn't support most unsigned types,
 18# so we map them to their signed equivalent
 19DTYPE_MAP = {
 20    np.dtype("uint16"): np.int16,
 21    np.dtype("uint32"): np.int32,
 22    np.dtype("uint64"): np.int64
 23}
 24"""@private
 25"""
 26
 27
 28# This is a fairly brittle way to check if a module is compiled.
 29# Would be good to find a better solution, ideall something like model.is_compiled().
 30def is_compiled(model):
 31    """@private
 32    """
 33    if OptimizedModule is None:
 34        return False
 35    return isinstance(model, OptimizedModule)
 36
 37
 38def auto_compile(
 39    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
 40) -> torch.nn.Module:
 41    """Automatically compile a model for pytorch >= 2.
 42
 43    Args:
 44        model: The model.
 45        compile_model: Whether to comile the model.
 46            If None, it will not be compiled for torch < 2, and for torch > 2 the behavior
 47            specificed by 'default_compile' will be used. If a string is given it will be
 48            intepreted as the compile mode (torch.compile(model, mode=compile_model))
 49        default_compile: Whether to use the default compilation behavior for torch 2.
 50
 51    Returns:
 52        The compiled model.
 53    """
 54    torch_major = int(torch.__version__.split(".")[0])
 55
 56    if compile_model is None:
 57        if torch_major < 2:
 58            compile_model = False
 59        elif is_compiled(model):  # model is already compiled
 60            compile_model = False
 61        else:
 62            compile_model = default_compile
 63
 64    if compile_model:
 65        if torch_major < 2:
 66            raise RuntimeError("Model compilation is only supported for pytorch 2")
 67
 68        print("Compiling pytorch model ...")
 69        if isinstance(compile_model, str):
 70            model = torch.compile(model, mode=compile_model)
 71        else:
 72            model = torch.compile(model)
 73
 74    return model
 75
 76
 77def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
 78    """Ensure that the input is a torch tensor, by converting it if necessary.
 79
 80    Args:
 81        tensor: The input object, either a torch tensor or a numpy-array like object.
 82        dtype: The required data type for the output tensor.
 83
 84    Returns:
 85        The input, converted to a torch tensor if necessary.
 86    """
 87    if isinstance(tensor, np.ndarray):
 88        if np.dtype(tensor.dtype) in DTYPE_MAP:
 89            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 90        # Try to convert the tensor, even if it has wrong byte-order
 91        try:
 92            tensor = torch.from_numpy(tensor if tensor.flags.writeable else tensor.copy())
 93        except ValueError:
 94            tensor = tensor.view(tensor.dtype.newbyteorder())
 95            if np.dtype(tensor.dtype) in DTYPE_MAP:
 96                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 97            tensor = torch.from_numpy(tensor)
 98
 99    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
100    if dtype is not None:
101        tensor = tensor.to(dtype=dtype)
102    return tensor
103
104
def validate_roi(roi, shape, patch_shape=None):
    """Normalize an ROI to explicit slices and validate that it is non-empty."""
    if roi is None:
        return None
    if isinstance(roi, slice):
        # A single slice is interpreted as an ROI for the first dimension.
        roi = (roi,)
    if not isinstance(roi, tuple):
        raise TypeError(f"Invalid roi type: {type(roi)}")
    if len(roi) > len(shape):
        raise ValueError(f"Invalid roi {roi} for data shape {shape}: too many dimensions")

    explicit = []
    for entry, extent in zip(roi, shape):
        if not isinstance(entry, slice):
            raise TypeError(f"Invalid roi entry: {entry}. Only slices are supported")
        step = 1 if entry.step is None else entry.step
        if step != 1:
            raise ValueError(f"Invalid roi {roi}: slice steps other than 1 are not supported")
        # slice.indices resolves None / negative bounds against the dimension's extent.
        begin, end, _ = entry.indices(extent)
        explicit.append(slice(begin, end))

    # Dimensions not covered by the roi default to the full extent.
    explicit.extend(slice(0, extent) for extent in shape[len(roi):])

    if any(sl.stop - sl.start <= 0 for sl in explicit):
        msg = f"Invalid roi {roi} for data shape {shape}: it results in an empty region"
        if patch_shape is not None:
            msg += f" for patch_shape {patch_shape}"
        raise ValueError(msg)

    return tuple(explicit)
137
138
def ensure_tensor_with_channels(
    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
) -> torch.Tensor:
    """Ensure that the input is a torch tensor of a given dimensionality with channels.

    Args:
        tensor: The input tensor or numpy-array like data.
        ndim: The dimensionality of the output tensor (2, 3 or 4).
        dtype: The data type of the output tensor.

    Returns:
        The input converted to a torch tensor of the requested dimensionality.
    """
    assert ndim in (2, 3, 4), f"{ndim}"
    tensor = ensure_tensor(tensor, dtype)
    if ndim == 2:
        # 2d data: add a channel axis to plain 2d inputs, strip singleton leading axes from 4d/5d inputs.
        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
        elif tensor.ndim == 4:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor.squeeze(0)
        elif tensor.ndim == 5:
            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
            tensor = tensor.squeeze(0).squeeze(0)
    elif ndim == 3:
        # 3d data: add a channel axis to plain 3d inputs, strip a singleton leading axis from 5d inputs.
        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
        if tensor.ndim == 3:
            tensor = tensor.unsqueeze(0)
        elif tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor.squeeze(0)
    else:
        # 4d data: only strip a singleton leading axis from 5d inputs.
        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
        if tensor.ndim == 5:
            assert tensor.shape[0] == 1, f"{tensor.shape}"
            tensor = tensor.squeeze(0)
    return tensor
177
178
def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: Optional[str] = None) -> np.ndarray:
    """Ensure that the input is a numpy array, by converting it if necessary.

    Args:
        array: The input torch tensor or numpy array.
        dtype: The dtype of the output array.

    Returns:
        The input converted to a numpy array if necessary.
    """
    if torch.is_tensor(array):
        # Detach from the autograd graph and move to the CPU before conversion.
        array = array.detach().cpu().numpy()
    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
    if dtype is not None:
        array = np.require(array, dtype=dtype)
    return array
195
196
def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: Optional[str] = None) -> np.ndarray:
    """Ensure that the input is a numpy array of a given dimensionality.

    Singleton leading axes are stripped to reach the requested dimensionality.

    Args:
        array: The input numpy array or torch tensor.
        ndim: The requested dimensionality (2 or 3).
        dtype: The dtype of the output array.

    Returns:
        A numpy array of the requested dimensionality and data type.
    """
    assert ndim in (2, 3)
    array = ensure_array(array, dtype)
    if ndim == 2:
        assert array.ndim in (2, 3, 4, 5), str(array.ndim)
        if array.ndim == 3:
            assert array.shape[0] == 1
            array = array[0]
        elif array.ndim == 4:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
        elif array.ndim == 5:
            assert array.shape[:3] == (1, 1, 1)
            array = array[0, 0, 0]
    else:
        assert array.ndim in (3, 4, 5), str(array.ndim)
        if array.ndim == 4:
            assert array.shape[0] == 1, f"{array.shape}"
            array = array[0]
        elif array.ndim == 5:
            assert array.shape[:2] == (1, 1)
            array = array[0, 0]
    return array
230
231
def ensure_patch_shape(
    raw: np.ndarray,
    labels: Optional[np.ndarray],
    patch_shape: Tuple[int, ...],
    have_raw_channels: bool = False,
    have_label_channels: bool = False,
    channel_first: bool = True,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Ensure that the raw data and labels have at least the requested patch shape.

    If either raw data or labels do not have the patch shape they will be padded.

    Args:
        raw: The input raw data.
        labels: The input labels.
        patch_shape: The required minimal patch shape.
        have_raw_channels: Whether the raw data has channels.
        have_label_channels: Whether the label data has channels.
        channel_first: Whether the channel axis is the first or last axis.

    Returns:
        The raw data.
        The labels.
    """
    def _pad_to_patch_shape(data, have_channels):
        # Pad a single input so that its spatial shape is at least the patch shape.
        # A leading channel axis (channel-first layout) is excluded from the spatial shape.
        # IMPORTANT: for ImageCollectionDataset.
        spatial_shape = data.shape[1:] if (have_channels and channel_first) else data.shape
        if all(sh >= psh for sh, psh in zip(spatial_shape, patch_shape)):
            return data

        # Pad at the end of each spatial axis up to the requested patch shape.
        pad_width = [(0, max(0, psh - sh)) for sh, psh in zip(spatial_shape, patch_shape)]
        if have_channels and channel_first:
            pad_width = [(0, 0)] + pad_width
        elif have_channels and not channel_first:
            pad_width = pad_width + [(0, 0)]
        return np.pad(array=data, pad_width=pad_width)

    raw = _pad_to_patch_shape(raw, have_raw_channels)
    if labels is None:
        return raw
    return raw, _pad_to_patch_shape(labels, have_label_channels)
297
298
def get_constructor_arguments(obj):
    """@private
    """
    # All relevant torch_em classes have 'init_kwargs' to directly recover the init call.
    if hasattr(obj, "init_kwargs"):
        return getattr(obj, "init_kwargs")

    # We don't need to find the constructor arguments for optimizers/schedulers because we deserialize the state later.
    optimizer_like = (
        torch.optim.Optimizer,
        torch.optim.lr_scheduler._LRScheduler,
        # ReduceLROnPlateau does not inherit from _LRScheduler
        torch.optim.lr_scheduler.ReduceLROnPlateau,
    )
    if isinstance(obj, optimizer_like):
        return {}

    # Recover the arguments for a torch dataloader.
    if isinstance(obj, torch.utils.data.DataLoader):
        # These are all the "simple" arguments.
        # "sampler", "batch_sampler" and "worker_init_fn" are more complicated
        # and generally not used in torch_em.
        sampler = getattr(obj, "sampler", None)
        standard_samplers = (
            torch.utils.data.RandomSampler,
            torch.utils.data.SequentialSampler,
            torch.utils.data.SubsetRandomSampler,
        )
        if sampler is not None and not isinstance(sampler, standard_samplers):
            warnings.warn(
                f"DataLoader uses sampler {type(sampler).__name__}, but only its effective `shuffle` setting "
                "is serialized. `DefaultTrainer.from_checkpoint` will recreate the loader without the original "
                "sampler, so sampling behavior may change."
            )

        # Derive the effective shuffle setting: first from the loader, then from the
        # sampler, and finally from the sampler type itself.
        shuffle = getattr(obj, "shuffle", None)
        if shuffle is None:
            shuffle = getattr(sampler, "shuffle", None)
        if shuffle is None:
            # Only randomized samplers map to shuffle=True. SequentialSampler is handled
            # by the default fallback of shuffle=False and does not need a special case.
            shuffle = isinstance(sampler, (torch.utils.data.RandomSampler, torch.utils.data.SubsetRandomSampler))

        simple_arg_names = [
            "batch_size", "num_workers", "pin_memory", "drop_last",
            "persistent_workers", "prefetch_factor", "timeout",
        ]
        loader_kwargs = {arg_name: getattr(obj, arg_name) for arg_name in simple_arg_names}
        loader_kwargs["shuffle"] = shuffle
        return loader_kwargs

    # TODO support common torch losses (e.g. CrossEntropy, BCE)
    warnings.warn(
        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
        "For this object, empty constructor arguments will be used.\n" +
        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
    )
    return {}
364
365
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint, or an already instantiated trainer.
        name: The name of the checkpoint.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.
    """
    if isinstance(checkpoint, str):
        # Deserialize the trainer from the checkpoint path.
        assert os.path.exists(checkpoint), checkpoint
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    else:
        # The input is already a trainer object, pass it through.
        trainer = checkpoint
    assert isinstance(trainer, torch_em.trainer.DefaultTrainer)
    return trainer
385
386
def get_normalizer(trainer):
    """@private
    """
    # Unwrap nested concat datasets until a leaf dataset is reached.
    dataset = trainer.train_loader.dataset
    concat_types = (torch_em.data.concat_dataset.ConcatDataset, torch.utils.data.dataset.ConcatDataset)
    while isinstance(dataset, concat_types):
        dataset = dataset.datasets[0]

    # Unwrap a subset to get to the underlying dataset.
    if isinstance(dataset, torch.utils.data.dataset.Subset):
        dataset = dataset.dataset

    # The raw transform may wrap the normalizer; otherwise it is used as the normalizer itself.
    preprocessor = dataset.raw_transform
    return getattr(preprocessor, "normalizer", preprocessor)
406
407
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: Optional[str] = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load a model from a trainer checkpoint or a serialized torch model.

    This function can either load the model directly (`model` is not passed),
    or deserialize the model state and then load it (`model` is passed).

    The `checkpoint` argument must either point to the checkpoint directory of a torch_em trainer
    or to a serialized torch model.

    Args:
        checkpoint: The path to the checkpoint folder or serialized torch model.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint.
        state_key: The name of the model state to load. Set to None if the model state is stored top-level.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None:
        if os.path.isdir(checkpoint):
            # Load the model and its state from a torch_em checkpoint.
            return get_trainer(checkpoint, name=name, device=device).model
        # Load the model from a serialized model file.
        return torch.load(checkpoint, map_location=device, weights_only=False)

    # Otherwise load only the model state, either from a torch_em checkpoint
    # directory or directly from a serialized state path.
    ckpt = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint

    state = torch.load(ckpt, map_location=device, weights_only=False)
    if state_key is not None:
        state = state[state_key]

    # Strip the prefix that torch.compile adds to parameter names,
    # to enable loading states saved from compiled models.
    compiled_prefix = "_orig_mod."
    state = OrderedDict(
        [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in state.items()]
    )

    model.load_state_dict(state)
    if device is not None:
        model.to(device)

    return model
461
462
def model_is_equal(model1, model2):
    """@private

    Check whether two models have identical parameter values.
    """
    from itertools import zip_longest

    # A plain zip would stop at the shorter parameter list, so models with a
    # different number of parameters could be falsely reported as equal.
    for p1, p2 in zip_longest(model1.parameters(), model2.parameters(), fillvalue=None):
        if p1 is None or p2 is None:  # The models have a different number of parameters.
            return False
        if p1.data.shape != p2.data.shape:  # Mismatching shapes cannot be compared element-wise.
            return False
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True
470
471
def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    unique_labels = np.unique(labels)
    # One random RGB color per unique label value.
    color_list = np.random.rand(len(unique_labels), 3).tolist()
    # If the background label 0 is present, prepend black for it.
    if 0 in unique_labels:
        color_list = [[0, 0, 0]] + color_list
    return colors.ListedColormap(color_list)
def auto_compile( model: torch.nn.modules.module.Module, compile_model: Union[str, bool, NoneType] = None, default_compile: bool = True) -> torch.nn.modules.module.Module:
39def auto_compile(
40    model: torch.nn.Module, compile_model: Optional[Union[str, bool]] = None, default_compile: bool = True
41) -> torch.nn.Module:
42    """Automatically compile a model for pytorch >= 2.
43
44    Args:
45        model: The model.
46        compile_model: Whether to comile the model.
47            If None, it will not be compiled for torch < 2, and for torch > 2 the behavior
48            specificed by 'default_compile' will be used. If a string is given it will be
49            intepreted as the compile mode (torch.compile(model, mode=compile_model))
50        default_compile: Whether to use the default compilation behavior for torch 2.
51
52    Returns:
53        The compiled model.
54    """
55    torch_major = int(torch.__version__.split(".")[0])
56
57    if compile_model is None:
58        if torch_major < 2:
59            compile_model = False
60        elif is_compiled(model):  # model is already compiled
61            compile_model = False
62        else:
63            compile_model = default_compile
64
65    if compile_model:
66        if torch_major < 2:
67            raise RuntimeError("Model compilation is only supported for pytorch 2")
68
69        print("Compiling pytorch model ...")
70        if isinstance(compile_model, str):
71            model = torch.compile(model, mode=compile_model)
72        else:
73            model = torch.compile(model)
74
75    return model

Automatically compile a model for pytorch >= 2.

Arguments:
  • model: The model.
  • compile_model: Whether to compile the model. If None, it will not be compiled for torch < 2, and for torch > 2 the behavior specified by 'default_compile' will be used. If a string is given it will be interpreted as the compile mode (torch.compile(model, mode=compile_model))
  • default_compile: Whether to use the default compilation behavior for torch 2.
Returns:

The compiled model.

def ensure_tensor( tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
 78def ensure_tensor(tensor: Union[torch.Tensor, ArrayLike], dtype: Optional[str] = None) -> torch.Tensor:
 79    """Ensure that the input is a torch tensor, by converting it if necessary.
 80
 81    Args:
 82        tensor: The input object, either a torch tensor or a numpy-array like object.
 83        dtype: The required data type for the output tensor.
 84
 85    Returns:
 86        The input, converted to a torch tensor if necessary.
 87    """
 88    if isinstance(tensor, np.ndarray):
 89        if np.dtype(tensor.dtype) in DTYPE_MAP:
 90            tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 91        # Try to convert the tensor, even if it has wrong byte-order
 92        try:
 93            tensor = torch.from_numpy(tensor if tensor.flags.writeable else tensor.copy())
 94        except ValueError:
 95            tensor = tensor.view(tensor.dtype.newbyteorder())
 96            if np.dtype(tensor.dtype) in DTYPE_MAP:
 97                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
 98            tensor = torch.from_numpy(tensor)
 99
100    assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
101    if dtype is not None:
102        tensor = tensor.to(dtype=dtype)
103    return tensor

Ensure that the input is a torch tensor, by converting it if necessary.

Arguments:
  • tensor: The input object, either a torch tensor or a numpy-array like object.
  • dtype: The required data type for the output tensor.
Returns:

The input, converted to a torch tensor if necessary.

def validate_roi(roi, shape, patch_shape=None):
106def validate_roi(roi, shape, patch_shape=None):
107    """Normalize an ROI to explicit slices and validate that it is non-empty."""
108    if roi is None:
109        return None
110    if isinstance(roi, slice):
111        roi = (roi,)
112    if not isinstance(roi, tuple):
113        raise TypeError(f"Invalid roi type: {type(roi)}")
114    if len(roi) > len(shape):
115        raise ValueError(f"Invalid roi {roi} for data shape {shape}: too many dimensions")
116
117    normalized_roi = []
118    for this_roi, dim in zip(roi, shape):
119        if not isinstance(this_roi, slice):
120            raise TypeError(f"Invalid roi entry: {this_roi}. Only slices are supported")
121        step = 1 if this_roi.step is None else this_roi.step
122        if step != 1:
123            raise ValueError(f"Invalid roi {roi}: slice steps other than 1 are not supported")
124        start, stop, _ = this_roi.indices(dim)
125        normalized_roi.append(slice(start, stop))
126
127    if len(roi) < len(shape):
128        normalized_roi.extend(slice(0, dim) for dim in shape[len(roi):])
129
130    roi_shape = tuple(sl.stop - sl.start for sl in normalized_roi)
131    if any(sh <= 0 for sh in roi_shape):
132        msg = f"Invalid roi {roi} for data shape {shape}: it results in an empty region"
133        if patch_shape is not None:
134            msg += f" for patch_shape {patch_shape}"
135        raise ValueError(msg)
136
137    return tuple(normalized_roi)

Normalize an ROI to explicit slices and validate that it is non-empty.

def ensure_tensor_with_channels( tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None) -> torch.Tensor:
140def ensure_tensor_with_channels(
141    tensor: Union[torch.Tensor, ArrayLike], ndim: int, dtype: Optional[str] = None
142) -> torch.Tensor:
143    """Ensure that the input is a torch tensor of a given dimensionality with channels.
144
145    Args:
146        tensor: The input tensor or numpy-array like data.
147        ndim: The dimensionality of the output tensor.
148        dtype: The data type of the output tensor.
149
150    Returns:
151        The input converted to a torch tensor of the requested dimensionality.
152    """
153    assert ndim in (2, 3, 4), f"{ndim}"
154    tensor = ensure_tensor(tensor, dtype)
155    if ndim == 2:
156        assert tensor.ndim in (2, 3, 4, 5), f"{tensor.ndim}"
157        if tensor.ndim == 2:
158            tensor = tensor[None]
159        elif tensor.ndim == 4:
160            assert tensor.shape[0] == 1, f"{tensor.shape}"
161            tensor = tensor[0]
162        elif tensor.ndim == 5:
163            assert tensor.shape[:2] == (1, 1), f"{tensor.shape}"
164            tensor = tensor[0, 0]
165    elif ndim == 3:
166        assert tensor.ndim in (3, 4, 5), f"{tensor.ndim}"
167        if tensor.ndim == 3:
168            tensor = tensor[None]
169        elif tensor.ndim == 5:
170            assert tensor.shape[0] == 1, f"{tensor.shape}"
171            tensor = tensor[0]
172    else:
173        assert tensor.ndim in (4, 5), f"{tensor.ndim}"
174        if tensor.ndim == 5:
175            assert tensor.shape[0] == 1, f"{tensor.shape}"
176            tensor = tensor[0]
177    return tensor

Ensure that the input is a torch tensor of a given dimensionality with channels.

Arguments:
  • tensor: The input tensor or numpy-array like data.
  • ndim: The dimensionality of the output tensor.
  • dtype: The data type of the output tensor.
Returns:

The input converted to a torch tensor of the requested dimensionality.

def ensure_array( array: Union[numpy.ndarray, torch.Tensor], dtype: str = None) -> numpy.ndarray:
180def ensure_array(array: Union[np.ndarray, torch.Tensor], dtype: str = None) -> np.ndarray:
181    """Ensure that the input is a numpy array, by converting it if necessary.
182
183    Args:
184        array: The input torch tensor or numpy array.
185        dtype: The dtype of the ouptut array.
186
187    Returns:
188        The input converted to a numpy array if necessary.
189    """
190    if torch.is_tensor(array):
191        array = array.detach().cpu().numpy()
192    assert isinstance(array, np.ndarray), f"Cannot convert {type(array)} to numpy"
193    if dtype is not None:
194        array = np.require(array, dtype=dtype)
195    return array

Ensure that the input is a numpy array, by converting it if necessary.

Arguments:
  • array: The input torch tensor or numpy array.
  • dtype: The dtype of the output array.
Returns:

The input converted to a numpy array if necessary.

def ensure_spatial_array( array: Union[numpy.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> numpy.ndarray:
198def ensure_spatial_array(array: Union[np.ndarray, torch.Tensor], ndim: int, dtype: str = None) -> np.ndarray:
199    """Ensure that the input is a numpy array of a given dimensionality.
200
201    Args:
202        array: The input numpy array or torch tensor.
203        ndim: The requested dimensionality.
204        dtype: The dtype of the output array.
205
206    Returns:
207        A numpy array of the requested dimensionality and data type.
208    """
209    assert ndim in (2, 3)
210    array = ensure_array(array, dtype)
211    if ndim == 2:
212        assert array.ndim in (2, 3, 4, 5), str(array.ndim)
213        if array.ndim == 3:
214            assert array.shape[0] == 1
215            array = array[0]
216        elif array.ndim == 4:
217            assert array.shape[:2] == (1, 1)
218            array = array[0, 0]
219        elif array.ndim == 5:
220            assert array.shape[:3] == (1, 1, 1)
221            array = array[0, 0, 0]
222    else:
223        assert array.ndim in (3, 4, 5), str(array.ndim)
224        if array.ndim == 4:
225            assert array.shape[0] == 1, f"{array.shape}"
226            array = array[0]
227        elif array.ndim == 5:
228            assert array.shape[:2] == (1, 1)
229            array = array[0, 0]
230    return array

Ensure that the input is a numpy array of a given dimensionality.

Arguments:
  • array: The input numpy array or torch tensor.
  • ndim: The requested dimensionality.
  • dtype: The dtype of the output array.
Returns:

A numpy array of the requested dimensionality and data type.

def ensure_patch_shape( raw: numpy.ndarray, labels: Optional[numpy.ndarray], patch_shape: Tuple[int, ...], have_raw_channels: bool = False, have_label_channels: bool = False, channel_first: bool = True) -> Union[Tuple[numpy.ndarray, numpy.ndarray], numpy.ndarray]:
233def ensure_patch_shape(
234    raw: np.ndarray,
235    labels: Optional[np.ndarray],
236    patch_shape: Tuple[int, ...],
237    have_raw_channels: bool = False,
238    have_label_channels: bool = False,
239    channel_first: bool = True,
240) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
241    """Ensure that the raw data and labels have at least the requested patch shape.
242
243    If either raw data or labels do not have the patch shape they will be padded.
244
245    Args:
246        raw: The input raw data.
247        labels: The input labels.
248        patch_shape: The required minimal patch shape.
249        have_raw_channels: Whether the raw data has channels.
250        have_label_channels: Whether the label data has channels.
251        channel_first: Whether the channel axis is the first or last axis.
252
253    Returns:
254        The raw data.
255        The labels.
256    """
257    raw_shape = raw.shape
258    if labels is not None:
259        labels_shape = labels.shape
260
261    # In case the inputs has channels and they are channels first
262    # IMPORTANT: for ImageCollectionDataset
263    if have_raw_channels and channel_first:
264        raw_shape = raw_shape[1:]
265
266    if have_label_channels and channel_first and labels is not None:
267        labels_shape = labels_shape[1:]
268
269    # Extract the pad_width and pad the raw inputs
270    if any(sh < psh for sh, psh in zip(raw_shape, patch_shape)):
271        pw = [(0, max(0, psh - sh)) for sh, psh in zip(raw_shape, patch_shape)]
272
273        if have_raw_channels and channel_first:
274            pad_width = [(0, 0), *pw]
275        elif have_raw_channels and not channel_first:
276            pad_width = [*pw, (0, 0)]
277        else:
278            pad_width = pw
279
280        raw = np.pad(array=raw, pad_width=pad_width)
281
282    # Extract the pad width and pad the label inputs
283    if labels is not None and any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
284        pw = [(0, max(0, psh - sh)) for sh, psh in zip(labels_shape, patch_shape)]
285
286        if have_label_channels and channel_first:
287            pad_width = [(0, 0), *pw]
288        elif have_label_channels and not channel_first:
289            pad_width = [*pw, (0, 0)]
290        else:
291            pad_width = pw
292
293        labels = np.pad(array=labels, pad_width=pad_width)
294    if labels is None:
295        return raw
296    else:
297        return raw, labels

Ensure that the raw data and labels have at least the requested patch shape.

If either raw data or labels do not have the patch shape they will be padded.

Arguments:
  • raw: The input raw data.
  • labels: The input labels.
  • patch_shape: The required minimal patch shape.
  • have_raw_channels: Whether the raw data has channels.
  • have_label_channels: Whether the label data has channels.
  • channel_first: Whether the channel axis is the first or last axis.
Returns:

The raw data. The labels.

def get_trainer(checkpoint: str, name: str = 'best', device: Optional[str] = None):
def get_trainer(checkpoint: str, name: str = "best", device: Optional[str] = None):
    """Load a trainer from a checkpoint.

    Args:
        checkpoint: The path to the checkpoint directory, or an already constructed
            trainer, which is then passed through unchanged.
        name: The name of the checkpoint to load.
        device: The device to use for loading the checkpoint.

    Returns:
        The trainer.

    Raises:
        FileNotFoundError: If 'checkpoint' is a path that does not exist.
        TypeError: If the resulting object is not a torch_em DefaultTrainer.
    """
    if isinstance(checkpoint, str):
        # Load the trainer from the checkpoint directory on disk.
        # Raise explicitly instead of using assert, which is stripped under 'python -O'.
        if not os.path.exists(checkpoint):
            raise FileNotFoundError(f"The checkpoint does not exist: {checkpoint}")
        trainer = torch_em.trainer.DefaultTrainer.from_checkpoint(checkpoint, name=name, device=device)
    else:
        # Allow passing an already instantiated trainer through unchanged.
        trainer = checkpoint
    if not isinstance(trainer, torch_em.trainer.DefaultTrainer):
        raise TypeError(f"Expected a torch_em DefaultTrainer, got {type(trainer)}")
    return trainer

Load trainer from a checkpoint.

Arguments:
  • checkpoint: The path to the checkpoint.
  • name: The name of the checkpoint.
  • device: The device to use for loading the checkpoint.
Returns:

The trainer.

def load_model( checkpoint: str, model: Optional[torch.nn.modules.module.Module] = None, name: str = 'best', state_key: Optional[str] = 'model_state', device: Optional[str] = None) -> torch.nn.modules.module.Module:
def load_model(
    checkpoint: str,
    model: Optional[torch.nn.Module] = None,
    name: str = "best",
    state_key: Optional[str] = "model_state",
    device: Optional[str] = None,
) -> torch.nn.Module:
    """Load a model from a trainer checkpoint or a serialized torch model.

    If `model` is not passed, the model object itself is loaded: either via the
    torch_em trainer (when `checkpoint` is a checkpoint directory) or by
    deserializing the file at `checkpoint` directly. If `model` is passed, only
    the model state is loaded from the checkpoint and applied to it.

    Args:
        checkpoint: The path to the checkpoint folder or serialized torch model.
        model: The model for which the state should be loaded.
            If it is not passed, the model class and parameters will also be loaded from the trainer.
        name: The name of the checkpoint.
        state_key: The name of the model state to load. Set to None if the model state is stored top-level.
        device: The device on which to load the model.

    Returns:
        The model.
    """
    if model is None:
        if os.path.isdir(checkpoint):
            # A torch_em checkpoint directory: recover model and weights via the trainer.
            return get_trainer(checkpoint, name=name, device=device).model
        # Otherwise the path points at a fully serialized model.
        # NOTE: weights_only=False deserializes arbitrary pickled objects;
        # only load checkpoints from trusted sources.
        return torch.load(checkpoint, map_location=device, weights_only=False)

    # A model instance was given, so only its state is loaded from the checkpoint,
    # which is either a torch_em checkpoint directory or a serialized state path.
    ckpt_path = os.path.join(checkpoint, f"{name}.pt") if os.path.isdir(checkpoint) else checkpoint
    state = torch.load(ckpt_path, map_location=device, weights_only=False)
    if state_key is not None:
        state = state[state_key]

    # Strip the prefix added by torch.compile, so that states saved from
    # compiled models can be loaded into plain ones.
    prefix = "_orig_mod."
    state = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value) for key, value in state.items()
    )

    model.load_state_dict(state)
    if device is not None:
        model.to(device)
    return model

Load a model from a trainer checkpoint or a serialized torch model.

This function can either load the model directly (model is not passed), or deserialize the model state and then load it (model is passed).

The checkpoint argument must either point to the checkpoint directory of a torch_em trainer or to a serialized torch model.

Arguments:
  • checkpoint: The path to the checkpoint folder or serialized torch model.
  • model: The model for which the state should be loaded. If it is not passed, the model class and parameters will also be loaded from the trainer.
  • name: The name of the checkpoint.
  • state_key: The name of the model state to load. Set to None if the model state is stored top-level.
  • device: The device on which to load the model.
Returns:

The model.

def get_random_colors(labels: numpy.ndarray) -> matplotlib.colors.ListedColormap:
def get_random_colors(labels: np.ndarray) -> colors.ListedColormap:
    """Generate a random color map for a label image.

    Label 0 (usually background) is mapped to black; every other unique label
    value gets a random RGB color, so that the colormap has exactly one entry
    per unique label value.

    Args:
        labels: The labels.

    Returns:
        The color map.
    """
    unique_labels = np.unique(labels)
    have_zero = 0 in unique_labels
    # One color per unique label: black for 0, random colors for the rest.
    # (Previously an extra random color was generated when 0 was present,
    # yielding one more colormap entry than there are labels.)
    n_random = len(unique_labels) - 1 if have_zero else len(unique_labels)
    cmap = [[0, 0, 0]] if have_zero else []
    cmap += np.random.rand(n_random, 3).tolist()
    return colors.ListedColormap(cmap)

Generate a random color map for a label image.

Arguments:
  • labels: The labels.
Returns:

The color map.