torch_em.trainer.default_trainer

  1from __future__ import annotations
  2
  3import os
  4import time
  5import inspect
  6import warnings
  7import contextlib
  8from tqdm import tqdm
  9from collections import OrderedDict
 10from functools import partial
 11from importlib import import_module
 12from typing import Any, Callable, Dict, Optional, Union, Literal
 13
 14import numpy as np
 15
 16import torch
 17
 18from .wandb_logger import WandbLogger
 19from .tensorboard_logger import TensorboardLogger
 20from ..util import auto_compile, get_constructor_arguments, is_compiled
 21
 22
 23class DefaultTrainer:
 24    """Trainer class for training a segmentation network.
 25
 26    The trainer class implements the core logic for training a network with pytorch.
 27    It implements a training loop to run training and validation, which is started with `fit`.
  28    The checkpoints and logs from the training run will be saved in the current working directory,
  29    or in the directory specified by `save_root`. Training can be continued from a checkpoint
  30    by passing it to the `load_from_checkpoint` argument of `fit`, as shown in the examples below.
 31
 32    A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
 33    Alternatively, the trainer class can also be instantiated as in this example:
 34    ```python
 35    import torch
 36    from torch_em.loss import DiceLoss
 37    from torch_em.model import UNet2d
 38    from torch_em.data.datasets.light_microscopy import get_dsb_loader
 39    from torch_em.trainer import DefaultTrainer
 40
 41    # The training data will be downloaded to this location.
 42    data_root = "/path/to/save/the/training/data"
 43    patch_shape = (256, 256)
 44
 45    # Create the model and optimizer.
 46    model = UNet2d(in_channels=1, out_channels=1)
 47    optimizer = torch.optim.AdamW(model.parameters())
 48
 49    trainer = DefaultTrainer(
 50        name="unet-training",
 51        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
 52        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
 53        model=model,
 54        loss=DiceLoss(),  # The loss function.
 55        optimizer=optimizer,
 56        metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
 57        device="cuda",  # The device to use for training.
 58    )
  59    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
 60    ```
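
        Training can also be resumed from an earlier run of the same trainer. A minimal sketch,
        assuming a previous run already wrote checkpoints (the string "latest" is the checkpoint
        name, which `load_checkpoint` resolves inside the trainer's checkpoint folder):
        ```python
        trainer.fit(iterations=int(1e4), load_from_checkpoint="latest")  # Train for 10,000 more iterations.
        ```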
 61
 62    Args:
 63        name: The name of the checkpoint that will be created by the trainer.
 64        train_loader: The data loader containing the training data.
 65        val_loader: The data loader containing the validation data.
 66        model: The model to train.
 67        loss: The loss function for training.
 68        optimizer: The optimizer.
 69        metric: The metric for validation.
 70        device: The torch device to use for training. If None, will use a GPU if available.
 71        lr_scheduler: The learning rate scheduler.
 72        log_image_interval: The interval for saving images during logging, in training iterations.
 73        mixed_precision: Whether to train with mixed precision.
 74        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
 75        logger: The logger class. Will be instantiated for logging.
 76            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
 77        logger_kwargs: The keyword arguments for the logger class.
 78        id_: Unique identifier for the trainer. If None then `name` will be used.
 79        save_root: The root folder for saving the checkpoint and logs.
 80        compile_model: Whether to compile the model before training.
 81        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
 82    """
 83    def __init__(
 84        self,
 85        name: Optional[str],
 86        train_loader: torch.utils.data.DataLoader,
 87        val_loader: torch.utils.data.DataLoader,
 88        model: torch.nn.Module,
 89        loss: torch.nn.Module,
 90        optimizer: torch.optim.Optimizer,
 91        metric: Callable,
 92        device: Union[str, torch.device],
 93        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 94        log_image_interval: int = 100,
 95        mixed_precision: bool = True,
 96        early_stopping: Optional[int] = None,
 97        logger=TensorboardLogger,
 98        logger_kwargs: Optional[Dict[str, Any]] = None,
 99        id_: Optional[str] = None,
100        save_root: Optional[str] = None,
101        compile_model: Optional[Union[bool, str]] = None,
102        rank: Optional[int] = None,
103    ):
104        if name is None and not issubclass(logger, WandbLogger):
105            raise TypeError("Name cannot be None if not using the WandbLogger")
106
107        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
108            raise ValueError(f"{self.__class__} requires each dataloader to have a 'shuffle' attribute.")
109
110        self._generate_name = name is None
111        self.name = name
112        self.id_ = id_ or name
113        self.train_loader = train_loader
114        self.val_loader = val_loader
115        self.model = model
116        self.loss = loss
117        self.optimizer = optimizer
118        self.metric = metric
119        self.device = torch.device(device)
120        self.lr_scheduler = lr_scheduler
121        self.log_image_interval = log_image_interval
122        self.save_root = save_root
123        self.compile_model = compile_model
124        self.rank = rank
125
126        self._iteration = 0
127        self._epoch = 0
128        self._best_epoch = 0
129
130        self.mixed_precision = mixed_precision
131        self.early_stopping = early_stopping
132        self.train_time = 0.0
133
134        if mixed_precision:
135            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
136        else:
137            self.scaler = None
138
139        self.logger_class = logger
140        self.logger_kwargs = logger_kwargs
141        self.log_image_interval = log_image_interval
142
143    @property
144    def checkpoint_folder(self):
145        assert self.id_ is not None  # Because the logger may generate and set trainer.id_ in logger.__init__.
146        # save_root enables saving the checkpoints somewhere other than the local folder.
147        # This is handy for filesystems with limited space, where saving the checkpoints
148        # and log files can lead to running out of space.
149        save_root = getattr(self, "save_root", None)
150        return os.path.join("./checkpoints", self.id_) if save_root is None else\
151            os.path.join(save_root, "./checkpoints", self.id_)
152
153    @property
154    def iteration(self):
155        return self._iteration
156
157    @property
158    def epoch(self):
159        return self._epoch
160
161    class Deserializer:
162        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.
163
164        Examples:
165            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
166            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
167
168            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
169            >>> class MyTrainer(DefaultTrainer):
170            >>>     def __init__(self, *args, the_answer: int, **kwargs):
171            >>>         super().__init__(*args, **kwargs)
172            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
173            >>>                                       # see DefaultTrainer.Serializer
174            >>>
175            >>>     class Deserializer(DefaultTrainer.Deserializer):
176            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
177            >>>             generic_answer = self.init_data["the_answer"]
178            >>>             # (device dependent) special deserialization
179            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
180            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
181            >>>             else:
182            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
183
184        Args:
185            init_data: The initialization data of the trainer.
186            save_path: The path where the checkpoint was saved.
187            device: The device.
188        """
189
190        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
191            self.init_data = init_data
192            self.save_path = save_path
193            # Populated with the deserialized trainer kwargs; a passed 'device' overrides the serialized one.
194            self.trainer_kwargs: Dict[str, Any] = dict(
195                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
196            )
197
198        def load(self, kwarg_name: str, optional):
199            """@private
200            """
201            # `optional` is True if the trainer's __init__ specifies a default value for 'kwarg_name'
202            if kwarg_name == "device":
203                pass  # deserialized in __init__
204            elif kwarg_name.endswith("_loader"):
205                self.load_data_loader(kwarg_name, optional)
206            else:
207                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
208                load(kwarg_name, optional=optional)
209
210        def load_data_loader(self, loader_name, optional) -> None:
211            """@private
212            """
213            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
214            if ds is None and optional:
215                return
216
217            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
218            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
219            # monkey patch the shuffle attribute onto the loader
220            loader.shuffle = loader_kwargs.get("shuffle", False)
221            self.trainer_kwargs[loader_name] = loader
222
223        def load_generic(
224            self,
225            kwarg_name: str,
226            *dynamic_args: Dict,
227            optional: bool,
228            only_class: bool = False,
229            dynamic_kwargs: Optional[Dict[str, Any]] = None,
230        ) -> None:
231            """@private
232            """
233            if kwarg_name in self.init_data:
234                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
235                return
236
237            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
238            if this_cls is None:
239                if optional:
240                    return
241                else:
242                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
243
244            assert isinstance(this_cls, str), this_cls
245            assert "." in this_cls, this_cls
246            cls_p, cls_m = this_cls.rsplit(".", 1)
247            this_cls = getattr(import_module(cls_p), cls_m)
248            if only_class:
249                self.trainer_kwargs[kwarg_name] = this_cls
250            else:
251                self.trainer_kwargs[kwarg_name] = this_cls(
252                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
253                )
254
255        def load_name(self, kwarg_name: str, optional: bool):
256            """@private
257            """
258            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
259
260        def load_optimizer(self, kwarg_name: str, optional: bool):
261            """@private
262            """
263            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
264
265        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
266            """@private
267            """
268            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
269
270        # todo: remove and rename kwarg 'logger' to 'logger_class'
271        def load_logger(self, kwarg_name: str, optional: bool):
272            """@private
273            """
274            assert kwarg_name == "logger"
275            self.load_generic("logger", optional=optional, only_class=True)
276
277    @staticmethod
278    def _get_save_dict(save_path, device):
279        if not os.path.exists(save_path):
280            raise ValueError(f"Cannot find checkpoint {save_path}")
281        return torch.load(save_path, map_location=device, weights_only=False)
282
283    @classmethod
284    def from_checkpoint(
285        cls,
286        checkpoint_folder: Union[os.PathLike, str],
287        name: Literal["best", "latest"] = "best",
288        device: Optional[Union[str, torch.device]] = None,
289    ):
290        """@private
291        """
292        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
293        # make sure the correct device is set if we don't have access to CUDA
294        if not torch.cuda.is_available():
295            device = "cpu"
296        save_dict = cls._get_save_dict(save_path, device)
297        deserializer = cls.Deserializer(save_dict["init"], save_path, device)
298
299        has_kwargs = False
300        deserialized = []
301        for name, parameter in inspect.signature(cls).parameters.items():
302            if name == "kwargs":
303                has_kwargs = True
304                continue
305            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
306            deserialized.append(name)
307
308        # to deserialize kwargs we can't rely on inspecting the signature, so we
309        # go through the remaining kwarg names in init data instead
310        if has_kwargs:
311            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
312            for name in kwarg_names:
313                if name.endswith("_kwargs"):
314                    continue
315                elif name.endswith("_dataset"):
316                    deserializer.load(name.replace("dataset", "loader"), optional=False)
317                elif name.endswith("_class"):
318                    deserializer.load(name.replace("_class", ""), optional=False)
319                else:
320                    deserializer.load(name, optional=False)
321
322        trainer = cls(**deserializer.trainer_kwargs)
323        trainer._initialize(0, save_dict)
324        trainer._is_initialized = True
325        return trainer
326
327    class Serializer:
328        """Implements how to serialize trainer kwargs from a trainer instance.
329
330        Examples:
331            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
332            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
333            called by the `dump()` method when appropriate cover most cases already.
334
335            This example adds a `the_answer` kwarg, which requires extra steps when dumping only because we don't
336            keep a 'the_answer' attribute:
337            >>> class MyTrainer(DefaultTrainer):
338            >>>     def __init__(self, *args, the_answer: int, **kwargs):
339            >>>         super().__init__(*args, **kwargs)
340            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
341            >>>         # but let's make things more interesting...
342            >>>         self.the = the_answer // 10
343            >>>         self.answer = the_answer % 10
344            >>>
345            >>>     class Serializer(DefaultTrainer.Serializer):
346            >>>         trainer: MyTrainer
347            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
348            >>>             assert kwarg_name == "the_answer"
349            >>>             # populate self.init_data with the serialized data required by Deserializer
350            >>>             # to restore the trainer kwargs
351            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer
352
353            This example, using both a Serializer and a Deserializer, adds a `the_answer` kwarg
354            while saving it in two separate entries, 'the' and 'answer':
355            >>> class MyTrainer(DefaultTrainer):
356            >>>     def __init__(self, *args, the_answer: int, **kwargs):
357            >>>         super().__init__(*args, **kwargs)
358            >>>         self.the_answer = the_answer
359            >>>
360            >>>     class Serializer(DefaultTrainer.Serializer):
361            >>>         trainer: MyTrainer
362            >>>         def dump_the_answer(self, kwarg_name: str):
363            >>>             assert kwarg_name == "the_answer"
364            >>>             self.init_data.update({
365            >>>                 "the": self.trainer.the_answer // 10,
366            >>>                 "answer": self.trainer.the_answer % 10
367            >>>             })
368            >>>
369            >>>     class Deserializer(DefaultTrainer.Deserializer):
370            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
371            >>>             assert kwarg_name == "the_answer"
372            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
373            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
374
375        Args:
376            trainer: The trainer instance.
377        """
378
379        def __init__(self, trainer: DefaultTrainer):
380            self.trainer = trainer
381            self.init_data = {}  # to be populated during serialization process
382
383        def dump(self, kwarg_name: str) -> None:
384            """@private
385            """
386            dumper = getattr(self, f"dump_{kwarg_name}", None)
387            if dumper is not None:
388                dumper(kwarg_name)
389            elif kwarg_name.endswith("_loader"):
390                self.dump_data_loader(kwarg_name)
391            elif kwarg_name.endswith("_class"):
392                self.dump_generic_class(kwarg_name)
393            elif not hasattr(self.trainer, kwarg_name):
394                raise AttributeError(
395                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
396                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
397                )
398            else:
399                assert hasattr(self.trainer, kwarg_name)
400                obj = getattr(self.trainer, kwarg_name)
401                if obj is None or type(obj) in (
402                    bool,
403                    bytearray,
404                    bytes,
405                    dict,
406                    float,
407                    frozenset,
408                    int,
409                    list,
410                    set,
411                    str,
412                    tuple,
413                ):
414                    self.dump_generic_builtin(kwarg_name)
415                else:
416                    self.dump_generic_instance(kwarg_name)
417
418        def dump_generic_builtin(self, kwarg_name: str) -> None:
419            """@private
420            """
421            assert hasattr(self.trainer, kwarg_name)
422            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)
423
424        def dump_generic_class(self, kwarg_name: str) -> None:
425            """@private
426            """
427            assert hasattr(self.trainer, kwarg_name)
428            assert kwarg_name.endswith("_class")
429            obj = getattr(self.trainer, kwarg_name)
430            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"
431
432        def dump_generic_instance(self, kwarg_name: str) -> None:
433            """@private
434            """
435            assert hasattr(self.trainer, kwarg_name)
436            instance = getattr(self.trainer, kwarg_name)
437            self.init_data.update(
438                {
439                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
440                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
441                }
442            )
443
444        def dump_device(self, kwarg_name: str):
445            """@private
446            """
447            assert hasattr(self.trainer, kwarg_name)
448            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))
449
450        def dump_data_loader(self, kwarg_name: str) -> None:
451            """@private
452            """
453            assert hasattr(self.trainer, kwarg_name)
454            loader = getattr(self.trainer, kwarg_name)
455            if loader is None:
456                return
457            self.init_data.update(
458                {
459                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
460                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
461                }
462            )
463
464        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
465            """@private
466            """
467            self.dump_generic_class(f"{kwarg_name}_class")
468
469        def dump_model(self, kwarg_name: str):
470            """@private
471            """
472            if is_compiled(self.trainer.model):
473                self.init_data.update(
474                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
475                )
476            else:
477                self.dump_generic_instance("model")
478
479    def _build_init(self) -> Dict[str, Any]:
480        serializer = self.Serializer(self)
481        for name in inspect.signature(self.__class__).parameters:
482            # special rules to serialize kwargs
483            # if a trainer class inherits from DefaultTrainer and has **kwargs
484            # they need to be saved in self._kwargs
485            if name == "kwargs":
486                if not hasattr(self, "_kwargs"):
487                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
488                          "Please add self._kwargs to its __init__ function"
489                    raise RuntimeError(msg)
490                kwargs = getattr(self, "_kwargs")
491                for kwarg_name in kwargs:
492                    serializer.dump(kwarg_name)
493                continue
494            serializer.dump(name)
495
496        return serializer.init_data
497
498    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
499        assert self.train_loader is not None
500        assert self.val_loader is not None
501        assert self.model is not None
502        assert self.loss is not None
503        assert self.optimizer is not None
504        assert self.metric is not None
505        assert self.device is not None
506
507        if load_from_checkpoint is not None:
508            self.load_checkpoint(load_from_checkpoint)
509
510        if sum((iterations is not None, epochs is not None)) != 1:
511            raise ValueError(
512                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer. "
513                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}."
514            )
515
516        if epochs is None:
517            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
518        else:
519            iterations = epochs * len(self.train_loader)
520
521        self.max_iteration = self._iteration + iterations
522        self.max_epoch = self._epoch + epochs
523
524        if not getattr(self, "_is_initialized", False):
525            # check if we compile the model (only supported by pytorch 2)
526            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
527            if is_compiled(self.model):
528                warnings.warn(
529                    "You have passed a compiled model to the trainer. "
530                    "It will not be possible to (de)serialize the trainer with it. "
531                    "If you want to be able to do this, please pass the normal model. "
532                    "It can be automatically compiled by setting 'compile_model' to True."
533                )
534            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
535            self._model_kwargs = get_constructor_arguments(self.model)
536            self.model = auto_compile(self.model, self.compile_model)
537
538            self.model.to(self.device)
539            self.loss.to(self.device)
540
541            # this saves all the information that is necessary
542            # to fully load the trainer from the checkpoint
543            self.init_data = self._build_init()
544
545            if self.logger_class is None:
546                self.logger = None
547            else:
548                # may set self.name if self.name is None
549                save_root = getattr(self, "save_root", None)
550                try:
551                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
552                except PermissionError:
553                    warnings.warn(
554                        f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
555                        "The most likely reason for this is that you copied the checkpoint somewhere else, "
556                        "so we skip this error to enable loading the model from this checkpoint."
557                    )
558
559            try:
560                os.makedirs(self.checkpoint_folder, exist_ok=True)
561            except PermissionError:
562                warnings.warn(
563                    f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
564                    "The most likely reason for this is that you copied the checkpoint somewhere else, "
565                    "so we skip this error to enable loading the model from this checkpoint."
566                )
567                pass
568
569        best_metric = np.inf
570        return best_metric
571
572    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
573        """@private
574        """
575        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
576        extra_init_dict = extra_save_dict.pop("init", {})
577        save_dict = {
578            "iteration": self._iteration,
579            "epoch": self._epoch,
580            "best_epoch": self._best_epoch,
581            "best_metric": best_metric,
582            "current_metric": current_metric,
583            "model_state": self.model.state_dict(),
584            "optimizer_state": self.optimizer.state_dict(),
585            "init": self.init_data | extra_init_dict,
586            "train_time": train_time,
587        }
588        save_dict.update(**extra_save_dict)
589        if self.scaler is not None:
590            save_dict.update({"scaler_state": self.scaler.state_dict()})
591        if self.lr_scheduler is not None:
592            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
593
594        rank = getattr(self, "rank", None)
595        if rank is None or rank == 0:
596            torch.save(save_dict, save_path)
597
598    def load_checkpoint(self, checkpoint="best"):
599        """@private
600        """
601        if isinstance(checkpoint, str):
602            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
603            if not os.path.exists(save_path):
604                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
605                return
606            save_dict = torch.load(save_path, weights_only=False)
607        elif isinstance(checkpoint, dict):
608            save_dict = checkpoint
609        else:
610            raise RuntimeError
611
612        self._iteration = save_dict["iteration"]
613        self._epoch = save_dict["epoch"]
614        self._best_epoch = save_dict["best_epoch"]
615        self.best_metric = save_dict["best_metric"]
616        self.current_metric = save_dict["current_metric"]
617        self.train_time = save_dict.get("train_time", 0.0)
618
619        model_state = save_dict["model_state"]
620        # to enable loading compiled models
621        compiled_prefix = "_orig_mod."
622        model_state = OrderedDict(
623            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
624        )
625        self.model.load_state_dict(model_state)
626        # we need to send the network to the device before loading the optimizer state!
627        self.model.to(self.device)
628
629        self.optimizer.load_state_dict(save_dict["optimizer_state"])
630        if self.scaler is not None:
631            self.scaler.load_state_dict(save_dict["scaler_state"])
632        if self.lr_scheduler is not None:
633            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
634
635        return save_dict
636
637    def _verify_if_training_completed(self, checkpoint="latest"):
638        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
639        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
640        if save_dict and self.max_iteration == save_dict.get("iteration"):
641            return True
642        return False
643
644    def fit(
645        self,
646        iterations: Optional[int] = None,
647        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
648        epochs: Optional[int] = None,
649        save_every_kth_epoch: Optional[int] = None,
650        progress=None,
651        overwrite_training: bool = True,
652    ):
653        """Run neural network training.
654
655        Exactly one of 'iterations' or 'epochs' has to be passed.
656
657        Args:
658            iterations: How long to train, specified in iterations.
659            load_from_checkpoint: Path to a checkpoint from which training should be continued.
660            epochs: How long to train, specified in epochs.
661            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
662                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
663            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
664            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
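
           Example (a minimal sketch; the values are illustrative and `trainer` is a `DefaultTrainer`
           instance, e.g. the one from the class docstring):
           ```python
           # Train for 50 epochs and additionally save a checkpoint after every 10th epoch
           # (written as 'epoch-10.pt', 'epoch-20.pt', ... next to 'best.pt' and 'latest.pt').
           trainer.fit(epochs=50, save_every_kth_epoch=10)
           ```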
665        """
666        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
667
668        if not overwrite_training:
669            if load_from_checkpoint is not None:
670                raise ValueError(
671                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
672                )
673
674            if self._verify_if_training_completed():
675                print(
676                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
677                    "and 'overwrite_training' is set to 'False'."
678                )
679                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
680                return
681
682        print(
683            "Start fitting for",
684            self.max_iteration - self._iteration,
685            "iterations / ",
686            self.max_epoch - self._epoch,
687            "epochs",
688        )
689        print("with", len(self.train_loader), "iterations per epoch")
690
691        if self.mixed_precision:
692            train_epoch = self._train_epoch_mixed
693            validate = self._validate_mixed
694            print("Training with mixed precision")
695        else:
696            train_epoch = self._train_epoch
697            validate = self._validate
698            print("Training with single precision")
699
700        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
701        if progress is None:
702            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
703        else:
704            progress.total = total_iterations
705            progress.set_description(f"Epoch {self._epoch}")
706
707        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
708        train_epochs = self.max_epoch - self._epoch
709        t_start = time.time()
710        for epoch in range(train_epochs):
711
712            # Ensure data is shuffled differently at each epoch.
713            try:
714                self.train_loader.sampler.set_epoch(epoch)
715            except AttributeError:
716                pass
717
718            # Run training and validation for this epoch
719            t_per_iter = train_epoch(progress)
720            current_metric = validate()
721
722            # perform all the post-epoch steps:
723
724            # apply the learning rate scheduler
725            if self.lr_scheduler is not None:
726                self.lr_scheduler.step(current_metric)
727
728            # how long did we train in total?
729            total_train_time = (time.time() - t_start) + self.train_time
730
731            # save this checkpoint as the new best checkpoint if
732            # it has the best overall validation metric
733            if current_metric < best_metric:
734                best_metric = current_metric
735                self._best_epoch = self._epoch
736                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
737
738            # save this checkpoint as the latest checkpoint
739            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
740
741            # if we save after every k-th epoch then check if we need to save now
742            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
743                self.save_checkpoint(
744                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
745                )
746
747            # if early stopping has been specified then check if the stopping condition is met
748            if self.early_stopping is not None:
749                epochs_since_best = self._epoch - self._best_epoch
750                if epochs_since_best > self.early_stopping:
751                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
752                    break
753
754            self._epoch += 1
755            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
756
757        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
758        print(f"The best epoch is number {self._best_epoch}.")
759
760        if self._generate_name:
761            self.name = None
762
763        # Update the train time
764        self.train_time = total_train_time
765
766        # TODO save the model to wandb if we have the wandb logger
767        if isinstance(self.logger, WandbLogger):
768            self.logger.get_wandb().finish()
769
770    def _backprop(self, loss):
771        loss.backward()
772        self.optimizer.step()
773
774    def _backprop_mixed(self, loss):
775        self.scaler.scale(loss).backward()
776        self.scaler.step(self.optimizer)
777        self.scaler.update()
778
779    def _train_epoch(self, progress):
780        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
781
782    def _train_epoch_mixed(self, progress):
783        return self._train_epoch_impl(
784            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
785            self._backprop_mixed
786        )
787
788    def _forward_and_loss(self, x, y):
789        pred = self.model(x)
790        if self._iteration % self.log_image_interval == 0:
791            if pred.requires_grad:
792                pred.retain_grad()
793
794        loss = self.loss(pred, y)
795        return pred, loss
796
797    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
798        self.model.train()
799
800        n_iter = 0
801        t_per_iter = time.time()
802        for x, y in self.train_loader:
803            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
804
805            self.optimizer.zero_grad()
806
807            with forward_context():
808                pred, loss = self._forward_and_loss(x, y)
809
810            backprop(loss)
811
812            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
813            if self.logger is not None:
814                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
815
816            self._iteration += 1
817            n_iter += 1
818            if self._iteration >= self.max_iteration:
819                break
820            progress.update(1)
821
822        t_per_iter = (time.time() - t_per_iter) / n_iter
823        return t_per_iter
824
825    def _validate(self):
826        return self._validate_impl(contextlib.nullcontext)
827
828    def _validate_mixed(self):
829        return self._validate_impl(
830            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
831        )
832
833    def _validate_impl(self, forward_context):
834        self.model.eval()
835
836        metric_val = 0.0
837        loss_val = 0.0
838
839        with torch.no_grad():
840            for x, y in self.val_loader:
841                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
842                with forward_context():
843                    pred, loss = self._forward_and_loss(x, y)
844                    metric = self.metric(pred, y)
845
846                loss_val += loss.item()
847                metric_val += metric.item()
848
849        metric_val /= len(self.val_loader)
850        loss_val /= len(self.val_loader)
851        if self.logger is not None:
852            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
853        return metric_val
462                }
463            )
464
465        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
466            """@private
467            """
468            self.dump_generic_class(f"{kwarg_name}_class")
469
470        def dump_model(self, kwarg_name: str):
471            """@private
472            """
473            if is_compiled(self.trainer.model):
474                self.init_data.update(
475                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
476                )
477            else:
478                self.dump_generic_instance("model")
479
480    def _build_init(self) -> Dict[str, Any]:
481        serializer = self.Serializer(self)
482        for name in inspect.signature(self.__class__).parameters:
483            # special rules to serialize kwargs
484            # if a trainer class inherits from DefaultTrainer and has **kwargs
485            # they need to be saved in self._kwargs
486            if name == "kwargs":
487                if not hasattr(self, "_kwargs"):
488                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
489                          "Please add self._kwargs to its __init__ function"
490                    raise RuntimeError(msg)
491                kwargs = getattr(self, "_kwargs")
492                for kwarg_name in kwargs:
493                    serializer.dump(kwarg_name)
494                continue
495            serializer.dump(name)
496
497        return serializer.init_data
498
499    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
500        assert self.train_loader is not None
501        assert self.val_loader is not None
502        assert self.model is not None
503        assert self.loss is not None
504        assert self.optimizer is not None
505        assert self.metric is not None
506        assert self.device is not None
507
508        if load_from_checkpoint is not None:
509            self.load_checkpoint(load_from_checkpoint)
510
511        if sum((iterations is not None, epochs is not None)) != 1:
512            raise ValueError(
513                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer."
514                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
515            )
516
517        if epochs is None:
518            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
519        else:
520            iterations = epochs * len(self.train_loader)
521
522        self.max_iteration = self._iteration + iterations
523        self.max_epoch = self._epoch + epochs
524
525        if not getattr(self, "_is_initialized", False):
526            # check if we compile the model (only supported by pytorch 2)
527            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
528            if is_compiled(self.model):
529                warnings.warn(
530                    "You have passed a compiled model to the trainer."
531                    "It will not be possible to (de)serialize the trainer with it."
532                    "If you want to be able to do this please pass the normal model."
533                    "It can be automatically compiled by setting 'compile_model' to True"
534                )
535            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
536            self._model_kwargs = get_constructor_arguments(self.model)
537            self.model = auto_compile(self.model, self.compile_model)
538
539            self.model.to(self.device)
540            self.loss.to(self.device)
541
542            # this saves all the information that is necessary
543            # to fully load the trainer from the checkpoint
544            self.init_data = self._build_init()
545
546            if self.logger_class is None:
547                self.logger = None
548            else:
549                # may set self.name if self.name is None
550                save_root = getattr(self, "save_root", None)
551                try:
552                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
553                except PermissionError:
554                    warnings.warn(
555                        f"The checkpoint folder at {self.checkpoint_folder} could not be created."
556                        "The most likely reason for this is that you copied the checkpoint somewhere else,"
557                        "so we skip this error to enable loading the model from this checkpoint."
558                    )
559
560            try:
561                os.makedirs(self.checkpoint_folder, exist_ok=True)
562            except PermissionError:
563                warnings.warn(
564                    f"The checkpoint folder at {self.checkpoint_folder} could not be created."
565                    "The most likely reason for this is that you copied the checkpoint somewhere else,"
566                    "so we skip this error to enable loading the model from this checkpoint."
567                )
568                pass
569
570        best_metric = np.inf
571        return best_metric
572
573    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
574        """@private
575        """
576        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
577        extra_init_dict = extra_save_dict.pop("init", {})
578        save_dict = {
579            "iteration": self._iteration,
580            "epoch": self._epoch,
581            "best_epoch": self._best_epoch,
582            "best_metric": best_metric,
583            "current_metric": current_metric,
584            "model_state": self.model.state_dict(),
585            "optimizer_state": self.optimizer.state_dict(),
586            "init": self.init_data | extra_init_dict,
587            "train_time": train_time,
588        }
589        save_dict.update(**extra_save_dict)
590        if self.scaler is not None:
591            save_dict.update({"scaler_state": self.scaler.state_dict()})
592        if self.lr_scheduler is not None:
593            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
594
595        rank = getattr(self, "rank", None)
596        if rank is None or rank == 0:
597            torch.save(save_dict, save_path)
598
599    def load_checkpoint(self, checkpoint="best"):
600        """@private
601        """
602        if isinstance(checkpoint, str):
603            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
604            if not os.path.exists(save_path):
605                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
606                return
607            save_dict = torch.load(save_path, weights_only=False)
608        elif isinstance(checkpoint, dict):
609            save_dict = checkpoint
610        else:
611            raise RuntimeError(f"Invalid checkpoint type: {type(checkpoint)}. Expected a string or a dict.")
612
613        self._iteration = save_dict["iteration"]
614        self._epoch = save_dict["epoch"]
615        self._best_epoch = save_dict["best_epoch"]
616        self.best_metric = save_dict["best_metric"]
617        self.current_metric = save_dict["current_metric"]
618        self.train_time = save_dict.get("train_time", 0.0)
619
620        model_state = save_dict["model_state"]
621        # to enable loading compiled models
622        compiled_prefix = "_orig_mod."
623        model_state = OrderedDict(
624            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
625        )
626        self.model.load_state_dict(model_state)
627        # we need to send the network to the device before loading the optimizer state!
628        self.model.to(self.device)
629
630        self.optimizer.load_state_dict(save_dict["optimizer_state"])
631        if self.scaler is not None:
632            self.scaler.load_state_dict(save_dict["scaler_state"])
633        if self.lr_scheduler is not None:
634            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
635
636        return save_dict
637
638    def _verify_if_training_completed(self, checkpoint="latest"):
639        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
640        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
641        if save_dict and self.max_iteration == save_dict.get("iteration"):
642            return True
643        return False
644
645    def fit(
646        self,
647        iterations: Optional[int] = None,
648        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
649        epochs: Optional[int] = None,
650        save_every_kth_epoch: Optional[int] = None,
651        progress=None,
652        overwrite_training: bool = True,
653    ):
654        """Run neural network training.
655
656        Exactly one of 'iterations' or 'epochs' has to be passed.
657
658        Args:
659            iterations: How long to train, specified in iterations.
660            load_from_checkpoint: The checkpoint from which training should be continued.
661            epochs: How long to train, specified in epochs.
662            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
663                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
664            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
665            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
666        """
667        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
668
669        if not overwrite_training:
670            if load_from_checkpoint is not None:
671                raise ValueError(
672                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
673                )
674
675            if self._verify_if_training_completed():
676                print(
677                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
678                    "and 'overwrite_training' is set to 'False'."
679                )
680                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
681                return
682
683        print(
684            "Start fitting for",
685            self.max_iteration - self._iteration,
686            "iterations / ",
687            self.max_epoch - self._epoch,
688            "epochs",
689        )
690        print("with", len(self.train_loader), "iterations per epoch")
691
692        if self.mixed_precision:
693            train_epoch = self._train_epoch_mixed
694            validate = self._validate_mixed
695            print("Training with mixed precision")
696        else:
697            train_epoch = self._train_epoch
698            validate = self._validate
699            print("Training with single precision")
700
701        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
702        if progress is None:
703            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
704        else:
705            progress.total = total_iterations
706            progress.set_description(f"Epoch {self._epoch}")
707
708        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
709        train_epochs = self.max_epoch - self._epoch
710        t_start = time.time()
711        for epoch in range(train_epochs):
712
713            # Ensure data is shuffled differently at each epoch.
714            try:
715                self.train_loader.sampler.set_epoch(epoch)
716            except AttributeError:
717                pass
718
719            # Run training and validation for this epoch
720            t_per_iter = train_epoch(progress)
721            current_metric = validate()
722
723            # perform all the post-epoch steps:
724
725            # apply the learning rate scheduler
726            if self.lr_scheduler is not None:
727                self.lr_scheduler.step(current_metric)
728
729            # how long did we train in total?
730            total_train_time = (time.time() - t_start) + self.train_time
731
732            # save this checkpoint as the new best checkpoint if
733            # it has the best overall validation metric
734            if current_metric < best_metric:
735                best_metric = current_metric
736                self._best_epoch = self._epoch
737                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
738
739            # save this checkpoint as the latest checkpoint
740            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
741
742            # if we save after every k-th epoch then check if we need to save now
743            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
744                self.save_checkpoint(
745                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
746                )
747
748            # if early stopping has been specified then check if the stopping condition is met
749            if self.early_stopping is not None:
750                epochs_since_best = self._epoch - self._best_epoch
751                if epochs_since_best > self.early_stopping:
752                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
753                    break
754
755            self._epoch += 1
756            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
757
758        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
759        print(f"The best epoch is number {self._best_epoch}.")
760
761        if self._generate_name:
762            self.name = None
763
764        # Update the train time
765        self.train_time = total_train_time
766
767        # TODO save the model to wandb if we have the wandb logger
768        if isinstance(self.logger, WandbLogger):
769            self.logger.get_wandb().finish()
770
771    def _backprop(self, loss):
772        loss.backward()
773        self.optimizer.step()
774
775    def _backprop_mixed(self, loss):
776        self.scaler.scale(loss).backward()
777        self.scaler.step(self.optimizer)
778        self.scaler.update()
779
780    def _train_epoch(self, progress):
781        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
782
783    def _train_epoch_mixed(self, progress):
784        return self._train_epoch_impl(
785            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
786            self._backprop_mixed
787        )
788
789    def _forward_and_loss(self, x, y):
790        pred = self.model(x)
791        if self._iteration % self.log_image_interval == 0:
792            if pred.requires_grad:
793                pred.retain_grad()
794
795        loss = self.loss(pred, y)
796        return pred, loss
797
798    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
799        self.model.train()
800
801        n_iter = 0
802        t_per_iter = time.time()
803        for x, y in self.train_loader:
804            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
805
806            self.optimizer.zero_grad()
807
808            with forward_context():
809                pred, loss = self._forward_and_loss(x, y)
810
811            backprop(loss)
812
813            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
814            if self.logger is not None:
815                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
816
817            self._iteration += 1
818            n_iter += 1
819            if self._iteration >= self.max_iteration:
820                break
821            progress.update(1)
822
823        t_per_iter = (time.time() - t_per_iter) / n_iter
824        return t_per_iter
825
826    def _validate(self):
827        return self._validate_impl(contextlib.nullcontext)
828
829    def _validate_mixed(self):
830        return self._validate_impl(
831            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
832        )
833
834    def _validate_impl(self, forward_context):
835        self.model.eval()
836
837        metric_val = 0.0
838        loss_val = 0.0
839
840        with torch.no_grad():
841            for x, y in self.val_loader:
842                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
843                with forward_context():
844                    pred, loss = self._forward_and_loss(x, y)
845                    metric = self.metric(pred, y)
846
847                loss_val += loss.item()
848                metric_val += metric.item()
849
850        metric_val /= len(self.val_loader)
851        loss_val /= len(self.val_loader)
852        if self.logger is not None:
853            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
854        return metric_val

Trainer class for training a segmentation network.

The trainer class implements the core logic for training a network with pytorch. It implements a training loop to run training and validation, which is started with fit. The checkpoints and logs from the training run will be saved in the current working directory, or in the directory specified by save_root. Training can be continued from a checkpoint by passing its location to the load_from_checkpoint argument of fit.

A pre-configured instance of the trainer can be obtained from torch_em.default_segmentation_trainer. Alternatively, the trainer class can also be instantiated as in this example:

import torch
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
from torch_em.trainer import DefaultTrainer

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)

# Create the model and optimizer.
model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.AdamW(model.parameters())

trainer = DefaultTrainer(
    name="unet-training",
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    model=model,
    loss=DiceLoss(),  # The loss function.
    optimizer=optimizer,
    metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
    device="cuda",  # The device to use for training.
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
Arguments:
  • name: The name of the checkpoint that will be created by the trainer.
  • train_loader: The data loader containing the training data.
  • val_loader: The data loader containing the validation data.
  • model: The model to train.
  • loss: The loss function for training.
  • optimizer: The optimizer.
  • metric: The metric for validation.
  • device: The torch device to use for training. If None, will use a GPU if available.
  • lr_scheduler: The learning rate scheduler.
  • log_image_interval: The interval for saving images during logging, in training iterations.
  • mixed_precision: Whether to train with mixed precision.
  • early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
  • logger: The logger class. Will be instantiated for logging. By default uses torch_em.training.tensorboard_logger.TensorboardLogger.
  • logger_kwargs: The keyword arguments for the logger class.
  • id_: Unique identifier for the trainer. If None then name will be used.
  • save_root: The root folder for saving the checkpoint and logs.
  • compile_model: Whether to compile the model before training.
  • rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
DefaultTrainer( name: Optional[str], train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, model: torch.nn.modules.module.Module, loss: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, metric: Callable, device: Union[str, torch.device], lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, log_image_interval: int = 100, mixed_precision: bool = True, early_stopping: Optional[int] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None, rank: Optional[int] = None)
 84    def __init__(
 85        self,
 86        name: Optional[str],
 87        train_loader: torch.utils.data.DataLoader,
 88        val_loader: torch.utils.data.DataLoader,
 89        model: torch.nn.Module,
 90        loss: torch.nn.Module,
 91        optimizer: torch.optim.Optimizer,
 92        metric: Callable,
 93        device: Union[str, torch.device],
 94        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 95        log_image_interval: int = 100,
 96        mixed_precision: bool = True,
 97        early_stopping: Optional[int] = None,
 98        logger=TensorboardLogger,
 99        logger_kwargs: Optional[Dict[str, Any]] = None,
100        id_: Optional[str] = None,
101        save_root: Optional[str] = None,
102        compile_model: Optional[Union[bool, str]] = None,
103        rank: Optional[int] = None,
104    ):
105        if name is None and not issubclass(logger, WandbLogger):
106            raise TypeError("Name cannot be None if not using the WandbLogger")
107
108        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
109            raise ValueError(f"{self.__class__} requires each dataloader to have 'shuffle' attribute.")
110
111        self._generate_name = name is None
112        self.name = name
113        self.id_ = id_ or name
114        self.train_loader = train_loader
115        self.val_loader = val_loader
116        self.model = model
117        self.loss = loss
118        self.optimizer = optimizer
119        self.metric = metric
120        self.device = torch.device(device)
121        self.lr_scheduler = lr_scheduler
122        self.log_image_interval = log_image_interval
123        self.save_root = save_root
124        self.compile_model = compile_model
125        self.rank = rank
126
127        self._iteration = 0
128        self._epoch = 0
129        self._best_epoch = 0
130
131        self.mixed_precision = mixed_precision
132        self.early_stopping = early_stopping
133        self.train_time = 0.0
134
135        if mixed_precision:
136            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
137        else:
138            self.scaler = None
139
140        self.logger_class = logger
141        self.logger_kwargs = logger_kwargs
142        self.log_image_interval = log_image_interval
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
144    @property
145    def checkpoint_folder(self):
146        assert self.id_ is not None  # Because the logger may generate and set trainer.id_ in logger.__init__.
147        # 'save_root' enables saving the checkpoints somewhere other than in the local folder.
148        # This is handy for filesystems with limited space, where saving the checkpoints
149        # and log files locally could lead to running out of space.
150        save_root = getattr(self, "save_root", None)
151        return os.path.join("./checkpoints", self.id_) if save_root is None else\
152            os.path.join(save_root, "./checkpoints", self.id_)
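
In other words, checkpoints end up in './checkpoints/<id_>' relative to the working directory, or under save_root if it is given. A minimal sketch of the resulting paths (the identifier and the save_root path below are placeholders):

import os

id_ = "unet-training"  # placeholder trainer id / name

# Without save_root the checkpoints are stored relative to the working directory.
print(os.path.join("./checkpoints", id_))  # -> ./checkpoints/unet-training

# With save_root they are redirected, e.g. to a filesystem with more space.
save_root = "/scratch/torch-em-runs"  # placeholder path
print(os.path.join(save_root, "checkpoints", id_))  # -> /scratch/torch-em-runs/checkpoints/unet-training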
iteration
154    @property
155    def iteration(self):
156        return self._iteration
epoch
158    @property
159    def epoch(self):
160        return self._epoch
def fit( self, iterations: Optional[int] = None, load_from_checkpoint: Union[str, os.PathLike, NoneType] = None, epochs: Optional[int] = None, save_every_kth_epoch: Optional[int] = None, progress=None, overwrite_training: bool = True):

Run neural network training.

Exactly one of 'iterations' or 'epochs' has to be passed.

Arguments:
  • iterations: How long to train, specified in iterations.
  • load_from_checkpoint: The checkpoint from which training should be continued.
  • epochs: How long to train, specified in epochs.
  • save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
  • progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
  • overwrite_training: Whether to overwrite existing checkpoints in the save directory.
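
For example, a run set up as in the class example above can also be trained by epochs, keep periodic checkpoints, and be resumed from a previous state. A minimal sketch, assuming a trainer instance constructed as shown earlier; note that, as implemented in load_checkpoint in the source above, the checkpoint is resolved as '<checkpoint_folder>/<name>.pt', so 'latest' or 'best' are typical values:

# Continue the training run for 10 more epochs.
trainer.fit(
    epochs=10,                      # train for 10 additional epochs instead of passing 'iterations'
    save_every_kth_epoch=2,         # additionally keep checkpoints named 'epoch-{epoch}.pt' every second epoch
    load_from_checkpoint="latest",  # resume from '<checkpoint_folder>/latest.pt'
)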
class DefaultTrainer.Deserializer:
162    class Deserializer:
163        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.
164
165        Examples:
166            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
167            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
168
169            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
170            >>> class MyTrainer(DefaultTrainer):
171            >>>     def __init__(self, *args, the_answer: int, **kwargs):
172            >>>         super().__init__(*args, **kwargs)
173            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
174            >>>                                       # see DefaultTrainer.Serializer
175            >>>
176            >>>     class Deserializer(DefaultTrainer.Deserializer):
177            >>>         def load_the_answer(self):
178            >>>             generic_answer = self.init_data["the_answer"]
179            >>>             # (device dependent) special deserialization
180            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
181            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
182            >>>             else:
183            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
184
185        Args:
186            init_data: The initialization data of the trainer.
187            save_path: The path where the checkpoint was saved.
188            device: The device.
189        """
190
191        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
192            self.init_data = init_data
193            self.save_path = save_path
194            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
195            self.trainer_kwargs: Dict[str, Any] = dict(
196                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
197            )
198
199        def load(self, kwarg_name: str, optional):
200            """@private
201            """
202            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
203            if kwarg_name == "device":
204                pass  # deserialized in __init__
205            elif kwarg_name.endswith("_loader"):
206                self.load_data_loader(kwarg_name, optional)
207            else:
208                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
209                load(kwarg_name, optional=optional)
210
211        def load_data_loader(self, loader_name, optional) -> None:
212            """@private
213            """
214            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
215            if ds is None and optional:
216                return
217
218            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
219            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
220            # monkey-patch the 'shuffle' attribute onto the loader
221            loader.shuffle = loader_kwargs.get("shuffle", False)
222            self.trainer_kwargs[loader_name] = loader
223
224        def load_generic(
225            self,
226            kwarg_name: str,
227            *dynamic_args: Dict,
228            optional: bool,
229            only_class: bool = False,
230            dynamic_kwargs: Optional[Dict[str, Any]] = None,
231        ) -> None:
232            """@private
233            """
234            if kwarg_name in self.init_data:
235                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
236                return
237
238            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
239            if this_cls is None:
240                if optional:
241                    return
242                else:
243                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
244
245            assert isinstance(this_cls, str), this_cls
246            assert "." in this_cls, this_cls
247            cls_p, cls_m = this_cls.rsplit(".", 1)
248            this_cls = getattr(import_module(cls_p), cls_m)
249            if only_class:
250                self.trainer_kwargs[kwarg_name] = this_cls
251            else:
252                self.trainer_kwargs[kwarg_name] = this_cls(
253                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
254                )
255
256        def load_name(self, kwarg_name: str, optional: bool):
257            """@private
258            """
259            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
260
261        def load_optimizer(self, kwarg_name: str, optional: bool):
262            """@private
263            """
264            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
265
266        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
267            """@private
268            """
269            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
270
271        # todo: remove and rename kwarg 'logger' to 'logger_class'
272        def load_logger(self, kwarg_name: str, optional: bool):
273            """@private
274            """
275            assert kwarg_name == "logger"
276            self.load_generic("logger", optional=optional, only_class=True)

Determines how to deserialize the trainer kwargs from serialized 'init_data'.

Examples:

To extend the initialization process you can inherit from this Deserializer in a derived Trainer class. Note that DefaultTrainer.Deserializer.load_generic() covers most cases already.

This example adds the_answer kwarg, which requires 'calculations' upon initialization:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
>>>                                       # see DefaultTrainer.Serializer
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self):
>>>             generic_answer = self.init_data["the_answer"]
>>>             # (device dependent) special deserialization
>>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
>>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
>>>             else:
>>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
Arguments:
  • init_data: The initialization data of the trainer.
  • save_path: The path where the checkpoint was saved.
  • device: The device.
DefaultTrainer.Deserializer(init_data: Dict, save_path: str, device: Union[str, torch.device])
191        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
192            self.init_data = init_data
193            self.save_path = save_path
194            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
195            self.trainer_kwargs: Dict[str, Any] = dict(
196                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
197            )
init_data
save_path
trainer_kwargs: Dict[str, Any]
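
The Deserializer is normally not used directly: it is driven by the checkpoint-loading classmethod whose tail is shown at the top of this source listing (lines 317-326), which loads each constructor argument and then instantiates the trainer. A minimal usage sketch, assuming that entry point is exposed as DefaultTrainer.from_checkpoint taking a checkpoint folder and a checkpoint name (the exact signature is not shown on this page, and the path below is a placeholder):

from torch_em.trainer import DefaultTrainer

# Restore the trainer, including model, optimizer, loss and data loaders,
# from the 'best' checkpoint of a previous run.
trainer = DefaultTrainer.from_checkpoint("./checkpoints/unet-training", name="best")
trainer.fit(iterations=int(1e4))  # continue training for 10,000 more iterations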
class DefaultTrainer.Serializer:

Implements how to serialize trainer kwargs from a trainer instance.

Examples:

To extend the serialization process you can inherit from this Serializer in a derived Trainer class. Note that the methods dump_generic_builtin(), dump_generic_class() and dump_generic_instance(), which are called by the dump() method when appropriate, already cover most cases.

This example adds the_answer kwarg, which requires extra steps on dumping only because we don't keep a 'the_answer' attribute:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
>>>         # but let's make things more interesting...
>>>         self.the = the_answer // 10
>>>         self.answer = the_answer % 10
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
>>>             assert kwarg_name == "the_answer"
>>>             # populate self.init_data with the serialized data required by Deserializer
>>>             # to restore the trainer kwargs
>>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

This example, with both a Serializer and a Deserializer, adds the the_answer kwarg while saving it in two separate entries 'the' and 'answer':

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str):
>>>             assert kwarg_name == "the_answer"
>>>             self.init_data.update({
>>>                 "the": self.trainer.the_answer // 10,
>>>                 "answer": self.trainer.the_answer % 10
>>>             })
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self, kwarg_name: str, optional: bool):
>>>             assert kwarg_name == "the_answer"
>>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
Arguments:
  • trainer: The trainer instance.
DefaultTrainer.Serializer(trainer: DefaultTrainer)
380        def __init__(self, trainer: DefaultTrainer):
381            self.trainer = trainer
382            self.init_data = {}  # to be populated during serialization process
trainer
init_data