torch_em.trainer.default_trainer

  1from __future__ import annotations
  2
  3import os
  4import time
  5import inspect
  6import warnings
  7import contextlib
  8from tqdm import tqdm
  9from functools import partial
 10from datetime import datetime
 11from collections import OrderedDict
 12from importlib import import_module
 13from typing import Any, Callable, Dict, Optional, Union, Literal
 14
 15import numpy as np
 16
 17import torch
 18
 19from .wandb_logger import WandbLogger
 20from .tensorboard_logger import TensorboardLogger
 21from ..util import auto_compile, get_constructor_arguments, is_compiled
 22
 23
 24class DefaultTrainer:
 25    """Trainer class for training a segmentation network.
 26
 27    The trainer class implements the core logic for training a network with pytorch.
 28    It implements a training loop to run training and validation, which is started with `fit`.
 29    The checkpoints and logs from the training run will be saved in the current working directory,
    or in the directory specified by `save_root`. Training can be continued from a checkpoint
    by passing its location to the `load_from_checkpoint` argument of `fit`.
 32
 33    A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
 34    Alternatively, the trainer class can also be instantiated as in this example:
 35    ```python
 36    import torch
 37    from torch_em.loss import DiceLoss
 38    from torch_em.model import UNet2d
 39    from torch_em.data.datasets.light_microscopy import get_dsb_loader
 40    from torch_em.trainer import DefaultTrainer
 41
 42    # The training data will be downloaded to this location.
 43    data_root = "/path/to/save/the/training/data"
 44    patch_shape = (256, 256)
 45
 46    # Create the model and optimizer.
 47    model = UNet2d(in_channels=1, out_channels=1)
 48    optimizer = torch.optim.AdamW(model.parameters())
 49
 50    trainer = DefaultTrainer(
 51        name="unet-training",
 52        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
 53        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
 54        model=model,
 55        loss=DiceLoss(),  # The loss function.
 56        optimizer=optimizer,
 57        metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
 58        device="cuda",  # The device to use for training.
 59    )
 60    trainer.fit(iterations=int(2.5e4))  # Train for 25.000 iterations.
 61    ```
 62
 63    Args:
 64        name: The name of the checkpoint that will be created by the trainer.
 65        train_loader: The data loader containing the training data.
 66        val_loader: The data loader containing the validation data.
 67        model: The model to train.
 68        loss: The loss function for training.
 69        optimizer: The optimizer.
 70        metric: The metric for validation.
 71        device: The torch device to use for training. If None, will use a GPU if available.
 72        lr_scheduler: The learning rate scheduler.
 73        log_image_interval: The interval for saving images during logging, in training iterations.
 74        mixed_precision: Whether to train with mixed precision.
 75        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
 76        logger: The logger class. Will be instantiated for logging.
 77            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
 78        logger_kwargs: The keyword arguments for the logger class.
 79        id_: Unique identifier for the trainer. If None then `name` will be used.
 80        save_root: The root folder for saving the checkpoint and logs.
 81        compile_model: Whether to compile the model before training.
 82        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
 83    """
 84    def __init__(
 85        self,
 86        name: Optional[str],
 87        train_loader: torch.utils.data.DataLoader,
 88        val_loader: torch.utils.data.DataLoader,
 89        model: torch.nn.Module,
 90        loss: torch.nn.Module,
 91        optimizer: torch.optim.Optimizer,
 92        metric: Callable,
 93        device: Union[str, torch.device],
 94        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 95        log_image_interval: int = 100,
 96        mixed_precision: bool = True,
 97        early_stopping: Optional[int] = None,
 98        logger=TensorboardLogger,
 99        logger_kwargs: Optional[Dict[str, Any]] = None,
100        id_: Optional[str] = None,
101        save_root: Optional[str] = None,
102        compile_model: Optional[Union[bool, str]] = None,
103        rank: Optional[int] = None,
104    ):
105        if name is None and not issubclass(logger, WandbLogger):
106            raise TypeError("Name cannot be None if not using the WandbLogger")
107
108        self._generate_name = name is None
109        self.name = name
110        self.id_ = id_ or name
111        self.train_loader = train_loader
112        self.val_loader = val_loader
113        self.model = model
114        self.loss = loss
115        self.optimizer = optimizer
116        self.metric = metric
117        self.device = torch.device(device)
118        self.lr_scheduler = lr_scheduler
119        self.log_image_interval = log_image_interval
120        self.save_root = save_root
121        self.compile_model = compile_model
122        self.rank = rank
123
124        self._iteration = 0
125        self._epoch = 0
126        self._best_epoch = 0
127
128        self.mixed_precision = mixed_precision
129        self.early_stopping = early_stopping
130        self.train_time = 0.0
131
132        if mixed_precision:
133            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
134        else:
135            self.scaler = None
136
137        self.logger_class = logger
138        self.logger_kwargs = logger_kwargs
139        self.log_image_interval = log_image_interval
140
141    @property
142    def checkpoint_folder(self):
143        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
144        # Save_root enables saving the checkpoints somewhere else than in the local folder.
145        # This is handy for filesystems with limited space, where saving the checkpoints
146        # and log files can lead to running out of space.
147        save_root = getattr(self, "save_root", None)
148        return os.path.join("./checkpoints", self.id_) if save_root is None else\
149            os.path.join(save_root, "./checkpoints", self.id_)
150
    @property
    def iteration(self):
        """The current training iteration."""
        return self._iteration
154
    @property
    def epoch(self):
        """The current training epoch."""
        return self._epoch
158
    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

        Examples:
            To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>                                       # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

        Args:
            init_data: The initialization data of the trainer.
            save_path: The path where the checkpoint was saved.
            device: The device.
        """

        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """@private
            """
            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                # Dispatch to a custom 'load_<kwarg_name>' method if one exists, else fall back to load_generic.
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            """@private
            """
            # Loaders were serialized as '<name>_dataset' + '<name>_kwargs'; see Serializer.dump_data_loader.
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch shuffle loader_name to the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args: Dict,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            """@private
            """
            # Plain builtin values are stored directly under 'kwarg_name'.
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            # Otherwise the value was stored as a fully qualified class name plus constructor kwargs.
            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            # Re-import the class from its qualified name, e.g. 'torch.optim.AdamW'.
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The trainer name is recovered from the checkpoint folder name rather than from init_data.
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The optimizer needs the (already deserialized) model's parameters as first argument.
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The scheduler needs the (already deserialized) optimizer as first argument.
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # todo: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            """@private
            """
            assert kwarg_name == "logger"
            # The logger is passed as a class, not an instance; see DefaultTrainer.__init__.
            self.load_generic("logger", optional=optional, only_class=True)
274
275    @staticmethod
276    def _get_save_dict(save_path, device):
277        if not os.path.exists(save_path):
278            raise ValueError(f"Cannot find checkpoint {save_path}")
279        return torch.load(save_path, map_location=device, weights_only=False)
280
281    @classmethod
282    def from_checkpoint(
283        cls,
284        checkpoint_folder: Union[os.PathLike, str],
285        name: Literal["best", "latest"] = "best",
286        device: Optional[Union[str, torch.device]] = None,
287    ):
288        """@private
289        """
290        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
291        # make sure the correct device is set if we don't have access to CUDA
292        if not torch.cuda.is_available():
293            device = "cpu"
294        save_dict = cls._get_save_dict(save_path, device)
295        deserializer = cls.Deserializer(save_dict["init"], save_path, device)
296
297        has_kwargs = False
298        deserialized = []
299        for name, parameter in inspect.signature(cls).parameters.items():
300            if name == "kwargs":
301                has_kwargs = True
302                continue
303            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
304            deserialized.append(name)
305
306        # to deserialze kwargs we can't rely on inspecting the signature, so we
307        # go through the remaning kwarg names in init data instead
308        if has_kwargs:
309            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
310            for name in kwarg_names:
311                if name.endswith("_kwargs"):
312                    continue
313                elif name.endswith("_dataset"):
314                    deserializer.load(name.replace("dataset", "loader"), optional=False)
315                elif name.endswith("_class"):
316                    deserializer.load(name.replace("_class", ""), optional=False)
317                else:
318                    deserializer.load(name, optional=False)
319
320        trainer = cls(**deserializer.trainer_kwargs)
321        trainer._initialize(0, save_dict)
322        trainer._is_initialized = True
323        return trainer
324
    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance.

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
            called by the `dump()` method when appropriate cover most cases already.

            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
            'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
            >>>         # but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
            >>>             assert kwarg_name == "the_answer"
            >>>             # populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer'
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

        Args:
            trainer: The trainer instance.
        """

        def __init__(self, trainer: DefaultTrainer):
            self.trainer = trainer
            self.init_data = {}  # to be populated during serialization process

        def dump(self, kwarg_name: str) -> None:
            """@private
            """
            # Dispatch order: custom 'dump_<kwarg_name>' method, then data loaders,
            # then classes, then builtin values / generic instances.
            dumper = getattr(self, f"dump_{kwarg_name}", None)
            if dumper is not None:
                dumper(kwarg_name)
            elif kwarg_name.endswith("_loader"):
                self.dump_data_loader(kwarg_name)
            elif kwarg_name.endswith("_class"):
                self.dump_generic_class(kwarg_name)
            elif not hasattr(self.trainer, kwarg_name):
                raise AttributeError(
                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
                )
            else:
                assert hasattr(self.trainer, kwarg_name)
                obj = getattr(self.trainer, kwarg_name)
                # Builtin values can be stored directly; anything else is stored as class + ctor kwargs.
                if obj is None or type(obj) in (
                    bool,
                    bytearray,
                    bytes,
                    dict,
                    float,
                    frozenset,
                    int,
                    list,
                    set,
                    str,
                    tuple,
                ):
                    self.dump_generic_builtin(kwarg_name)
                else:
                    self.dump_generic_instance(kwarg_name)

        def dump_generic_builtin(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

        def dump_generic_class(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            assert kwarg_name.endswith("_class")
            obj = getattr(self.trainer, kwarg_name)
            # Store the fully qualified class name so Deserializer.load_generic can re-import it.
            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

        def dump_generic_instance(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            instance = getattr(self.trainer, kwarg_name)
            # Store the qualified class name plus its constructor arguments so the instance can be rebuilt.
            self.init_data.update(
                {
                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
                }
            )

        def dump_device(self, kwarg_name: str):
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            # Devices are serialized via their string representation, e.g. 'cuda:0'.
            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

        def dump_data_loader(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            loader = getattr(self.trainer, kwarg_name)
            if loader is None:
                return
            # Loaders are serialized as their dataset plus the loader's constructor arguments;
            # see Deserializer.load_data_loader for the inverse.
            self.init_data.update(
                {
                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
                }
            )

        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
            """@private
            """
            self.dump_generic_class(f"{kwarg_name}_class")

        def dump_model(self, kwarg_name: str):
            """@private
            """
            # Compiled models can't be serialized generically; fall back to the original model's
            # class and kwargs, which the trainer recorded before compiling in _initialize.
            if is_compiled(self.trainer.model):
                self.init_data.update(
                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
                )
            else:
                self.dump_generic_instance("model")
476
477    def _build_init(self) -> Dict[str, Any]:
478        serializer = self.Serializer(self)
479        for name in inspect.signature(self.__class__).parameters:
480            # special rules to serialize kwargs
481            # if a trainer class inherits from DefaultTrainer and has **kwargs
482            # they need to be saved in self._kwargs
483            if name == "kwargs":
484                if not hasattr(self, "_kwargs"):
485                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
486                          "Please add self._kwargs to its __init__ function"
487                    raise RuntimeError(msg)
488                kwargs = getattr(self, "_kwargs")
489                for kwarg_name in kwargs:
490                    serializer.dump(kwarg_name)
491                continue
492            serializer.dump(name)
493
494        return serializer.init_data
495
496    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
497        assert self.train_loader is not None
498        assert self.val_loader is not None
499        assert self.model is not None
500        assert self.loss is not None
501        assert self.optimizer is not None
502        assert self.metric is not None
503        assert self.device is not None
504
505        if load_from_checkpoint is not None:
506            self.load_checkpoint(load_from_checkpoint)
507
508        if sum((iterations is not None, epochs is not None)) != 1:
509            raise ValueError(
510                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer."
511                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
512            )
513
514        if epochs is None:
515            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
516        else:
517            iterations = epochs * len(self.train_loader)
518
519        self.max_iteration = self._iteration + iterations
520        self.max_epoch = self._epoch + epochs
521
522        if not getattr(self, "_is_initialized", False):
523            # check if we compile the model (only supported by pytorch 2)
524            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
525            if is_compiled(self.model):
526                warnings.warn(
527                    "You have passed a compiled model to the trainer."
528                    "It will not be possible to (de)serialize the trainer with it."
529                    "If you want to be able to do this please pass the normal model."
530                    "It can be automatically compiled by setting 'compile_model' to True"
531                )
532            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
533            self._model_kwargs = get_constructor_arguments(self.model)
534            self.model = auto_compile(self.model, self.compile_model)
535
536            self.model.to(self.device)
537            self.loss.to(self.device)
538
539            # this saves all the information that is necessary
540            # to fully load the trainer from the checkpoint
541            self.init_data = self._build_init()
542
543            if self.logger_class is None:
544                self.logger = None
545            else:
546                # may set self.name if self.name is None
547                save_root = getattr(self, "save_root", None)
548                try:
549                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
550                except (PermissionError, RuntimeError):
551                    warnings.warn(
552                        f"The checkpoint folder at {self.checkpoint_folder} could not be created."
553                        "The most likely reason for this is that you copied the checkpoint somewhere else,"
554                        "so we skip this error to enable loading the model from this checkpoint."
555                    )
556
557            try:
558                os.makedirs(self.checkpoint_folder, exist_ok=True)
559            except PermissionError:
560                warnings.warn(
561                    f"The checkpoint folder at {self.checkpoint_folder} could not be created."
562                    "The most likely reason for this is that you copied the checkpoint somewhere else,"
563                    "so we skip this error to enable loading the model from this checkpoint."
564                )
565                pass
566
567        best_metric = np.inf
568        return best_metric
569
570    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
571        """@private
572        """
573        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
574        extra_init_dict = extra_save_dict.pop("init", {})
575        save_dict = {
576            "iteration": self._iteration,
577            "epoch": self._epoch,
578            "best_epoch": self._best_epoch,
579            "best_metric": best_metric,
580            "current_metric": current_metric,
581            "model_state": self.model.state_dict(),
582            "optimizer_state": self.optimizer.state_dict(),
583            "init": self.init_data | extra_init_dict,
584            "train_time": train_time,
585            "timestamp": datetime.now().strftime("%d-%m-%Y (%H:%M:%S)"),
586        }
587        save_dict.update(**extra_save_dict)
588        if self.scaler is not None:
589            save_dict.update({"scaler_state": self.scaler.state_dict()})
590        if self.lr_scheduler is not None:
591            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
592
593        rank = getattr(self, "rank", None)
594        if rank is None or rank == 0:
595            torch.save(save_dict, save_path)
596
597    def load_checkpoint(self, checkpoint="best"):
598        """@private
599        """
600        if isinstance(checkpoint, str):
601            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
602            if not os.path.exists(save_path):
603                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
604                return
605            save_dict = torch.load(save_path, weights_only=False)
606        elif isinstance(checkpoint, dict):
607            save_dict = checkpoint
608        else:
609            raise RuntimeError
610
611        self._iteration = save_dict["iteration"]
612        self._epoch = save_dict["epoch"]
613        self._best_epoch = save_dict["best_epoch"]
614        self.best_metric = save_dict["best_metric"]
615        self.current_metric = save_dict["current_metric"]
616        self.train_time = save_dict.get("train_time", 0.0)
617
618        model_state = save_dict["model_state"]
619        # to enable loading compiled models
620        compiled_prefix = "_orig_mod."
621        model_state = OrderedDict(
622            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
623        )
624        self.model.load_state_dict(model_state)
625        # we need to send the network to the device before loading the optimizer state!
626        self.model.to(self.device)
627
628        self.optimizer.load_state_dict(save_dict["optimizer_state"])
629        if self.scaler is not None:
630            self.scaler.load_state_dict(save_dict["scaler_state"])
631        if self.lr_scheduler is not None:
632            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
633
634        return save_dict
635
636    def _verify_if_training_completed(self, checkpoint="latest"):
637        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
638        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
639        if save_dict and self.max_iteration == save_dict.get("iteration"):
640            return True
641        return False
642
643    def fit(
644        self,
645        iterations: Optional[int] = None,
646        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
647        epochs: Optional[int] = None,
648        save_every_kth_epoch: Optional[int] = None,
649        progress=None,
650        overwrite_training: bool = True,
651    ):
652        """Run neural network training.
653
654        Exactly one of 'iterations' or 'epochs' has to be passed.
655
656        Args:
657            iterations: How long to train, specified in iterations.
658            load_from_checkpoint: Path to a checkpoint from where training should be continued .
659            epochs: How long to train, specified in epochs.
660            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
661                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
662            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
663            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
664        """
665        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
666
667        if not overwrite_training:
668            if load_from_checkpoint is not None:
669                raise ValueError(
670                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
671                )
672
673            if self._verify_if_training_completed():
674                print(
675                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
676                    "and 'overwrite_training' is set to 'False'."
677                )
678                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
679                return
680
681        print(
682            "Start fitting for",
683            self.max_iteration - self._iteration,
684            "iterations / ",
685            self.max_epoch - self._epoch,
686            "epochs",
687        )
688        print("with", len(self.train_loader), "iterations per epoch")
689
690        if self.mixed_precision:
691            train_epoch = self._train_epoch_mixed
692            validate = self._validate_mixed
693            print("Training with mixed precision")
694        else:
695            train_epoch = self._train_epoch
696            validate = self._validate
697            print("Training with single precision")
698
699        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
700        if progress is None:
701            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
702        else:
703            progress.total = total_iterations
704            progress.set_description(f"Epoch {self._epoch}")
705
706        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
707        train_epochs = self.max_epoch - self._epoch
708        t_start = time.time()
709        for epoch in range(train_epochs):
710
711            # Ensure data is shuffled differently at each epoch.
712            try:
713                self.train_loader.sampler.set_epoch(epoch)
714            except AttributeError:
715                pass
716
717            # Run training and validation for this epoch
718            t_per_iter = train_epoch(progress)
719            current_metric = validate()
720
721            # perform all the post-epoch steps:
722
723            # apply the learning rate scheduler
724            if self.lr_scheduler is not None:
725                self.lr_scheduler.step(current_metric)
726
727            # how long did we train in total?
728            total_train_time = (time.time() - t_start) + self.train_time
729
730            # save this checkpoint as the new best checkpoint if
731            # it has the best overall validation metric
732            if current_metric < best_metric:
733                best_metric = current_metric
734                self._best_epoch = self._epoch
735                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
736
737            # save this checkpoint as the latest checkpoint
738            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
739
740            # if we save after every k-th epoch then check if we need to save now
741            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
742                self.save_checkpoint(
743                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
744                )
745
746            # if early stopping has been specified then check if the stopping condition is met
747            if self.early_stopping is not None:
748                epochs_since_best = self._epoch - self._best_epoch
749                if epochs_since_best > self.early_stopping:
750                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
751                    break
752
753            self._epoch += 1
754            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
755
756        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
757        print(f"The best epoch is number {self._best_epoch}.")
758
759        if self._generate_name:
760            self.name = None
761
762        # Update the train time
763        self.train_time = total_train_time
764
765        # TODO save the model to wandb if we have the wandb logger
766        if isinstance(self.logger, WandbLogger):
767            self.logger.get_wandb().finish()
768
769    def _backprop(self, loss):
770        loss.backward()
771        self.optimizer.step()
772
773    def _backprop_mixed(self, loss):
774        self.scaler.scale(loss).backward()
775        self.scaler.step(self.optimizer)
776        self.scaler.update()
777
778    def _train_epoch(self, progress):
779        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
780
781    def _train_epoch_mixed(self, progress):
782        return self._train_epoch_impl(
783            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
784            self._backprop_mixed
785        )
786
787    def _forward_and_loss(self, x, y):
788        pred = self.model(x)
789        if self._iteration % self.log_image_interval == 0:
790            if pred.requires_grad:
791                pred.retain_grad()
792
793        loss = self.loss(pred, y)
794        return pred, loss
795
    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
        """Run one training epoch over 'self.train_loader'.

        Args:
            progress: Progress bar following the tqdm interface.
            forward_context: Context manager factory wrapping the forward pass
                (nullcontext for single precision, autocast for mixed precision).
            backprop: Callable performing the backward pass and optimizer update.

        Returns:
            The average wall-clock time per iteration in seconds.
        """
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()

            # Forward pass under the precision context.
            with forward_context():
                pred, loss = self._forward_and_loss(x, y)

            backprop(loss)

            # Report the learning rate of the first parameter group.
            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)

            self._iteration += 1
            n_iter += 1
            # Stop mid-epoch once the overall iteration budget is exhausted.
            # NOTE(review): the final iteration before 'break' is not reflected
            # in the progress bar, since 'update' is only called after this check.
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter
823
824    def _validate(self):
825        return self._validate_impl(contextlib.nullcontext)
826
827    def _validate_mixed(self):
828        return self._validate_impl(
829            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
830        )
831
832    def _validate_impl(self, forward_context):
833        self.model.eval()
834
835        metric_val = 0.0
836        loss_val = 0.0
837
838        with torch.no_grad():
839            for x, y in self.val_loader:
840                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
841                with forward_context():
842                    pred, loss = self._forward_and_loss(x, y)
843                    metric = self.metric(pred, y)
844
845                loss_val += loss.item()
846                metric_val += metric.item()
847
848        metric_val /= len(self.val_loader)
849        loss_val /= len(self.val_loader)
850        if self.logger is not None:
851            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
852        return metric_val
class DefaultTrainer:
    """Trainer class for training a segmentation network.

    The trainer class implements the core logic for training a network with pytorch.
    It implements a training loop to run training and validation, which is started with `fit`.
    The checkpoints and logs from the training run will be saved in the current working directory,
    or in the directory specified by `save_root`. Training can be continued from a checkpoint
    by passing its location to the `load_from_checkpoint` argument of `fit`.

    A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
    Alternatively, the trainer class can also be instantiated as in this example:
    ```python
    import torch
    from torch_em.loss import DiceLoss
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader
    from torch_em.trainer import DefaultTrainer

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)

    # Create the model and optimizer.
    model = UNet2d(in_channels=1, out_channels=1)
    optimizer = torch.optim.AdamW(model.parameters())

    trainer = DefaultTrainer(
        name="unet-training",
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
        model=model,
        loss=DiceLoss(),  # The loss function.
        optimizer=optimizer,
        metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
        device="cuda",  # The device to use for training.
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25.000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        model: The model to train.
        loss: The loss function for training.
        optimizer: The optimizer.
        metric: The metric for validation.
        device: The torch device to use for training. If None, will use a GPU if available.
        lr_scheduler: The learning rate scheduler.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
    """
 85    def __init__(
 86        self,
 87        name: Optional[str],
 88        train_loader: torch.utils.data.DataLoader,
 89        val_loader: torch.utils.data.DataLoader,
 90        model: torch.nn.Module,
 91        loss: torch.nn.Module,
 92        optimizer: torch.optim.Optimizer,
 93        metric: Callable,
 94        device: Union[str, torch.device],
 95        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 96        log_image_interval: int = 100,
 97        mixed_precision: bool = True,
 98        early_stopping: Optional[int] = None,
 99        logger=TensorboardLogger,
100        logger_kwargs: Optional[Dict[str, Any]] = None,
101        id_: Optional[str] = None,
102        save_root: Optional[str] = None,
103        compile_model: Optional[Union[bool, str]] = None,
104        rank: Optional[int] = None,
105    ):
106        if name is None and not issubclass(logger, WandbLogger):
107            raise TypeError("Name cannot be None if not using the WandbLogger")
108
109        self._generate_name = name is None
110        self.name = name
111        self.id_ = id_ or name
112        self.train_loader = train_loader
113        self.val_loader = val_loader
114        self.model = model
115        self.loss = loss
116        self.optimizer = optimizer
117        self.metric = metric
118        self.device = torch.device(device)
119        self.lr_scheduler = lr_scheduler
120        self.log_image_interval = log_image_interval
121        self.save_root = save_root
122        self.compile_model = compile_model
123        self.rank = rank
124
125        self._iteration = 0
126        self._epoch = 0
127        self._best_epoch = 0
128
129        self.mixed_precision = mixed_precision
130        self.early_stopping = early_stopping
131        self.train_time = 0.0
132
133        if mixed_precision:
134            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
135        else:
136            self.scaler = None
137
138        self.logger_class = logger
139        self.logger_kwargs = logger_kwargs
140        self.log_image_interval = log_image_interval
141
142    @property
143    def checkpoint_folder(self):
144        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
145        # Save_root enables saving the checkpoints somewhere else than in the local folder.
146        # This is handy for filesystems with limited space, where saving the checkpoints
147        # and log files can lead to running out of space.
148        save_root = getattr(self, "save_root", None)
149        return os.path.join("./checkpoints", self.id_) if save_root is None else\
150            os.path.join(save_root, "./checkpoints", self.id_)
151
152    @property
153    def iteration(self):
154        return self._iteration
155
156    @property
157    def epoch(self):
158        return self._epoch
159
    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

        Examples:
            To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>                                       # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

        Args:
            init_data: The initialization data of the trainer.
            save_path: The path where the checkpoint was saved.
            device: The device.
        """

        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """@private
            """
            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'.
            # Dispatch order: 'device' (handled in __init__) > data loaders > custom
            # 'load_<kwarg_name>' method > generic deserialization.
            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            """@private
            """
            # The loader is re-created from the serialized dataset and loader kwargs.
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch shuffle loader_name to the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args: Dict,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            """@private
            """
            # The value was stored directly (a builtin type) - use it as is.
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            # Otherwise the value was stored as '<kwarg_name>_class' (+ optional '<kwarg_name>_kwargs').
            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            # Resolve the fully qualified class name and import the class.
            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The trainer name is derived from the checkpoint folder name.
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The optimizer is constructed with the (already deserialized) model parameters.
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The scheduler is constructed with the (already deserialized) optimizer.
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # todo: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            """@private
            """
            assert kwarg_name == "logger"
            # Only the class is restored; it is instantiated later in '_initialize'.
            self.load_generic("logger", optional=optional, only_class=True)
275
276    @staticmethod
277    def _get_save_dict(save_path, device):
278        if not os.path.exists(save_path):
279            raise ValueError(f"Cannot find checkpoint {save_path}")
280        return torch.load(save_path, map_location=device, weights_only=False)
281
    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_folder: Union[os.PathLike, str],
        name: Literal["best", "latest"] = "best",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """@private

        Reconstruct a trainer instance from a serialized checkpoint by
        deserializing every constructor argument from the stored 'init' data.
        """
        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
        # make sure the correct device is set if we don't have access to CUDA
        if not torch.cuda.is_available():
            device = "cpu"
        save_dict = cls._get_save_dict(save_path, device)
        deserializer = cls.Deserializer(save_dict["init"], save_path, device)

        has_kwargs = False
        deserialized = []
        # Deserialize one kwarg per parameter of the trainer's signature.
        # NOTE(review): the loop variable 'name' shadows the 'name' parameter;
        # harmless here, since 'save_path' has already been computed above.
        for name, parameter in inspect.signature(cls).parameters.items():
            if name == "kwargs":
                has_kwargs = True
                continue
            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
            deserialized.append(name)

        # to deserialize kwargs we can't rely on inspecting the signature, so we
        # go through the remaining kwarg names in init data instead
        if has_kwargs:
            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
            for name in kwarg_names:
                if name.endswith("_kwargs"):
                    continue  # consumed together with its '<name>_class' / '<name>_dataset' entry
                elif name.endswith("_dataset"):
                    deserializer.load(name.replace("dataset", "loader"), optional=False)
                elif name.endswith("_class"):
                    deserializer.load(name.replace("_class", ""), optional=False)
                else:
                    deserializer.load(name, optional=False)

        trainer = cls(**deserializer.trainer_kwargs)
        # Passing the save dict restores the training state (iteration, epoch, weights).
        trainer._initialize(0, save_dict)
        trainer._is_initialized = True
        return trainer
325
    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance.

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
            called by the `dump()` method when appropriate cover most cases already.

            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
            'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
            >>>         # but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
            >>>             assert kwarg_name == "the_answer"
            >>>             # populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer'
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

        Args:
            trainer: The trainer instance.
        """

        def __init__(self, trainer: DefaultTrainer):
            self.trainer = trainer
            self.init_data = {}  # to be populated during serialization process

        def dump(self, kwarg_name: str) -> None:
            """@private

            Dispatch order: custom 'dump_<kwarg_name>' method > data loader >
            class kwarg > builtin attribute > generic instance.
            """
            dumper = getattr(self, f"dump_{kwarg_name}", None)
            if dumper is not None:
                dumper(kwarg_name)
            elif kwarg_name.endswith("_loader"):
                self.dump_data_loader(kwarg_name)
            elif kwarg_name.endswith("_class"):
                self.dump_generic_class(kwarg_name)
            elif not hasattr(self.trainer, kwarg_name):
                raise AttributeError(
                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
                )
            else:
                assert hasattr(self.trainer, kwarg_name)
                obj = getattr(self.trainer, kwarg_name)
                # Builtin values can be stored directly, everything else is stored
                # as its class path plus constructor kwargs.
                if obj is None or type(obj) in (
                    bool,
                    bytearray,
                    bytes,
                    dict,
                    float,
                    frozenset,
                    int,
                    list,
                    set,
                    str,
                    tuple,
                ):
                    self.dump_generic_builtin(kwarg_name)
                else:
                    self.dump_generic_instance(kwarg_name)

        def dump_generic_builtin(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

        def dump_generic_class(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            assert kwarg_name.endswith("_class")
            obj = getattr(self.trainer, kwarg_name)
            # Store the fully qualified class name, so the Deserializer can re-import it.
            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

        def dump_generic_instance(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            instance = getattr(self.trainer, kwarg_name)
            # Store class path and constructor kwargs, so the instance can be re-created.
            self.init_data.update(
                {
                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
                }
            )

        def dump_device(self, kwarg_name: str):
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            # torch.device is stored as its string representation.
            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

        def dump_data_loader(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            loader = getattr(self.trainer, kwarg_name)
            if loader is None:
                return
            # The dataset is stored directly, the loader is re-created from its kwargs.
            self.init_data.update(
                {
                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
                }
            )

        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
            """@private
            """
            self.dump_generic_class(f"{kwarg_name}_class")

        def dump_model(self, kwarg_name: str):
            """@private
            """
            # For compiled models the wrapped class cannot be serialized directly,
            # so the class and kwargs recorded before compilation are used instead.
            if is_compiled(self.trainer.model):
                self.init_data.update(
                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
                )
            else:
                self.dump_generic_instance("model")
477
478    def _build_init(self) -> Dict[str, Any]:
479        serializer = self.Serializer(self)
480        for name in inspect.signature(self.__class__).parameters:
481            # special rules to serialize kwargs
482            # if a trainer class inherits from DefaultTrainer and has **kwargs
483            # they need to be saved in self._kwargs
484            if name == "kwargs":
485                if not hasattr(self, "_kwargs"):
486                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
487                          "Please add self._kwargs to its __init__ function"
488                    raise RuntimeError(msg)
489                kwargs = getattr(self, "_kwargs")
490                for kwarg_name in kwargs:
491                    serializer.dump(kwarg_name)
492                continue
493            serializer.dump(name)
494
495        return serializer.init_data
496
497    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
498        assert self.train_loader is not None
499        assert self.val_loader is not None
500        assert self.model is not None
501        assert self.loss is not None
502        assert self.optimizer is not None
503        assert self.metric is not None
504        assert self.device is not None
505
506        if load_from_checkpoint is not None:
507            self.load_checkpoint(load_from_checkpoint)
508
509        if sum((iterations is not None, epochs is not None)) != 1:
510            raise ValueError(
511                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer."
512                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
513            )
514
515        if epochs is None:
516            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
517        else:
518            iterations = epochs * len(self.train_loader)
519
520        self.max_iteration = self._iteration + iterations
521        self.max_epoch = self._epoch + epochs
522
523        if not getattr(self, "_is_initialized", False):
524            # check if we compile the model (only supported by pytorch 2)
525            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
526            if is_compiled(self.model):
527                warnings.warn(
528                    "You have passed a compiled model to the trainer."
529                    "It will not be possible to (de)serialize the trainer with it."
530                    "If you want to be able to do this please pass the normal model."
531                    "It can be automatically compiled by setting 'compile_model' to True"
532                )
533            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
534            self._model_kwargs = get_constructor_arguments(self.model)
535            self.model = auto_compile(self.model, self.compile_model)
536
537            self.model.to(self.device)
538            self.loss.to(self.device)
539
540            # this saves all the information that is necessary
541            # to fully load the trainer from the checkpoint
542            self.init_data = self._build_init()
543
544            if self.logger_class is None:
545                self.logger = None
546            else:
547                # may set self.name if self.name is None
548                save_root = getattr(self, "save_root", None)
549                try:
550                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
551                except (PermissionError, RuntimeError):
552                    warnings.warn(
553                        f"The checkpoint folder at {self.checkpoint_folder} could not be created."
554                        "The most likely reason for this is that you copied the checkpoint somewhere else,"
555                        "so we skip this error to enable loading the model from this checkpoint."
556                    )
557
558            try:
559                os.makedirs(self.checkpoint_folder, exist_ok=True)
560            except PermissionError:
561                warnings.warn(
562                    f"The checkpoint folder at {self.checkpoint_folder} could not be created."
563                    "The most likely reason for this is that you copied the checkpoint somewhere else,"
564                    "so we skip this error to enable loading the model from this checkpoint."
565                )
566                pass
567
568        best_metric = np.inf
569        return best_metric
570
571    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
572        """@private
573        """
574        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
575        extra_init_dict = extra_save_dict.pop("init", {})
576        save_dict = {
577            "iteration": self._iteration,
578            "epoch": self._epoch,
579            "best_epoch": self._best_epoch,
580            "best_metric": best_metric,
581            "current_metric": current_metric,
582            "model_state": self.model.state_dict(),
583            "optimizer_state": self.optimizer.state_dict(),
584            "init": self.init_data | extra_init_dict,
585            "train_time": train_time,
586            "timestamp": datetime.now().strftime("%d-%m-%Y (%H:%M:%S)"),
587        }
588        save_dict.update(**extra_save_dict)
589        if self.scaler is not None:
590            save_dict.update({"scaler_state": self.scaler.state_dict()})
591        if self.lr_scheduler is not None:
592            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
593
594        rank = getattr(self, "rank", None)
595        if rank is None or rank == 0:
596            torch.save(save_dict, save_path)
597
598    def load_checkpoint(self, checkpoint="best"):
599        """@private
600        """
601        if isinstance(checkpoint, str):
602            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
603            if not os.path.exists(save_path):
604                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
605                return
606            save_dict = torch.load(save_path, weights_only=False)
607        elif isinstance(checkpoint, dict):
608            save_dict = checkpoint
609        else:
610            raise RuntimeError
611
612        self._iteration = save_dict["iteration"]
613        self._epoch = save_dict["epoch"]
614        self._best_epoch = save_dict["best_epoch"]
615        self.best_metric = save_dict["best_metric"]
616        self.current_metric = save_dict["current_metric"]
617        self.train_time = save_dict.get("train_time", 0.0)
618
619        model_state = save_dict["model_state"]
620        # to enable loading compiled models
621        compiled_prefix = "_orig_mod."
622        model_state = OrderedDict(
623            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
624        )
625        self.model.load_state_dict(model_state)
626        # we need to send the network to the device before loading the optimizer state!
627        self.model.to(self.device)
628
629        self.optimizer.load_state_dict(save_dict["optimizer_state"])
630        if self.scaler is not None:
631            self.scaler.load_state_dict(save_dict["scaler_state"])
632        if self.lr_scheduler is not None:
633            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
634
635        return save_dict
636
637    def _verify_if_training_completed(self, checkpoint="latest"):
638        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
639        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
640        if save_dict and self.max_iteration == save_dict.get("iteration"):
641            return True
642        return False
643
644    def fit(
645        self,
646        iterations: Optional[int] = None,
647        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
648        epochs: Optional[int] = None,
649        save_every_kth_epoch: Optional[int] = None,
650        progress=None,
651        overwrite_training: bool = True,
652    ):
653        """Run neural network training.
654
655        Exactly one of 'iterations' or 'epochs' has to be passed.
656
657        Args:
658            iterations: How long to train, specified in iterations.
659            load_from_checkpoint: Path to a checkpoint from where training should be continued .
660            epochs: How long to train, specified in epochs.
661            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
662                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
663            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
664            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
665        """
666        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
667
668        if not overwrite_training:
669            if load_from_checkpoint is not None:
670                raise ValueError(
671                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
672                )
673
674            if self._verify_if_training_completed():
675                print(
676                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
677                    "and 'overwrite_training' is set to 'False'."
678                )
679                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
680                return
681
682        print(
683            "Start fitting for",
684            self.max_iteration - self._iteration,
685            "iterations / ",
686            self.max_epoch - self._epoch,
687            "epochs",
688        )
689        print("with", len(self.train_loader), "iterations per epoch")
690
691        if self.mixed_precision:
692            train_epoch = self._train_epoch_mixed
693            validate = self._validate_mixed
694            print("Training with mixed precision")
695        else:
696            train_epoch = self._train_epoch
697            validate = self._validate
698            print("Training with single precision")
699
700        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
701        if progress is None:
702            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
703        else:
704            progress.total = total_iterations
705            progress.set_description(f"Epoch {self._epoch}")
706
707        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
708        train_epochs = self.max_epoch - self._epoch
709        t_start = time.time()
710        for epoch in range(train_epochs):
711
712            # Ensure data is shuffled differently at each epoch.
713            try:
714                self.train_loader.sampler.set_epoch(epoch)
715            except AttributeError:
716                pass
717
718            # Run training and validation for this epoch
719            t_per_iter = train_epoch(progress)
720            current_metric = validate()
721
722            # perform all the post-epoch steps:
723
724            # apply the learning rate scheduler
725            if self.lr_scheduler is not None:
726                self.lr_scheduler.step(current_metric)
727
728            # how long did we train in total?
729            total_train_time = (time.time() - t_start) + self.train_time
730
731            # save this checkpoint as the new best checkpoint if
732            # it has the best overall validation metric
733            if current_metric < best_metric:
734                best_metric = current_metric
735                self._best_epoch = self._epoch
736                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
737
738            # save this checkpoint as the latest checkpoint
739            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
740
741            # if we save after every k-th epoch then check if we need to save now
742            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
743                self.save_checkpoint(
744                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
745                )
746
747            # if early stopping has been specified then check if the stopping condition is met
748            if self.early_stopping is not None:
749                epochs_since_best = self._epoch - self._best_epoch
750                if epochs_since_best > self.early_stopping:
751                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
752                    break
753
754            self._epoch += 1
755            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
756
757        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
758        print(f"The best epoch is number {self._best_epoch}.")
759
760        if self._generate_name:
761            self.name = None
762
763        # Update the train time
764        self.train_time = total_train_time
765
766        # TODO save the model to wandb if we have the wandb logger
767        if isinstance(self.logger, WandbLogger):
768            self.logger.get_wandb().finish()
769
    def _backprop(self, loss):
        # Single-precision backward pass: compute the gradients and apply the optimizer update.
        loss.backward()
        self.optimizer.step()
773
    def _backprop_mixed(self, loss):
        # Mixed-precision backward pass: scale the loss before backprop to avoid
        # gradient underflow, then let the scaler unscale, apply the optimizer
        # step, and update its scale factor for the next iteration.
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
778
    def _train_epoch(self, progress):
        # Run one training epoch in single precision (no autocast context).
        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
781
782    def _train_epoch_mixed(self, progress):
783        return self._train_epoch_impl(
784            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
785            self._backprop_mixed
786        )
787
788    def _forward_and_loss(self, x, y):
789        pred = self.model(x)
790        if self._iteration % self.log_image_interval == 0:
791            if pred.requires_grad:
792                pred.retain_grad()
793
794        loss = self.loss(pred, y)
795        return pred, loss
796
797    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
798        self.model.train()
799
800        n_iter = 0
801        t_per_iter = time.time()
802        for x, y in self.train_loader:
803            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
804
805            self.optimizer.zero_grad()
806
807            with forward_context():
808                pred, loss = self._forward_and_loss(x, y)
809
810            backprop(loss)
811
812            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
813            if self.logger is not None:
814                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
815
816            self._iteration += 1
817            n_iter += 1
818            if self._iteration >= self.max_iteration:
819                break
820            progress.update(1)
821
822        t_per_iter = (time.time() - t_per_iter) / n_iter
823        return t_per_iter
824
    def _validate(self):
        # Run validation in single precision (no autocast context).
        return self._validate_impl(contextlib.nullcontext)
827
828    def _validate_mixed(self):
829        return self._validate_impl(
830            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
831        )
832
833    def _validate_impl(self, forward_context):
834        self.model.eval()
835
836        metric_val = 0.0
837        loss_val = 0.0
838
839        with torch.no_grad():
840            for x, y in self.val_loader:
841                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
842                with forward_context():
843                    pred, loss = self._forward_and_loss(x, y)
844                    metric = self.metric(pred, y)
845
846                loss_val += loss.item()
847                metric_val += metric.item()
848
849        metric_val /= len(self.val_loader)
850        loss_val /= len(self.val_loader)
851        if self.logger is not None:
852            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
853        return metric_val

Trainer class for training a segmentation network.

The trainer class implements the core logic for training a network with pytorch. It implements a training loop to run training and validation, which is started with fit. The checkpoints and logs from the training run will be saved in the current working directory, or in the directory specified by save_root. Training can be continued from a checkpoint by passing its location to the load_from_checkpoint argument of fit.

A pre-configured instance of the trainer can be obtained from torch_em.default_segmentation_trainer. Alternatively, the trainer class can also be instantiated as in this example:

import torch
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
from torch_em.trainer import DefaultTrainer

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)

# Create the model and optimizer.
model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.AdamW(model.parameters())

trainer = DefaultTrainer(
    name="unet-training",
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    model=model,
    loss=DiceLoss(),  # The loss function.
    optimizer=optimizer,
    metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
    device="cuda",  # The device to use for training.
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
Arguments:
  • name: The name of the checkpoint that will be created by the trainer.
  • train_loader: The data loader containing the training data.
  • val_loader: The data loader containing the validation data.
  • model: The model to train.
  • loss: The loss function for training.
  • optimizer: The optimizer.
  • metric: The metric for validation.
  • device: The torch device to use for training. If None, will use a GPU if available.
  • lr_scheduler: The learning rate scheduler.
  • log_image_interval: The interval for saving images during logging, in training iterations.
  • mixed_precision: Whether to train with mixed precision.
  • early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
  • logger: The logger class. Will be instantiated for logging. By default uses torch_em.training.tensorboard_logger.TensorboardLogger.
  • logger_kwargs: The keyword arguments for the logger class.
  • id_: Unique identifier for the trainer. If None then name will be used.
  • save_root: The root folder for saving the checkpoint and logs.
  • compile_model: Whether to compile the model before training.
  • rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
DefaultTrainer( name: Optional[str], train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, model: torch.nn.modules.module.Module, loss: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, metric: Callable, device: Union[str, torch.device], lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, log_image_interval: int = 100, mixed_precision: bool = True, early_stopping: Optional[int] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None, rank: Optional[int] = None)
 85    def __init__(
 86        self,
 87        name: Optional[str],
 88        train_loader: torch.utils.data.DataLoader,
 89        val_loader: torch.utils.data.DataLoader,
 90        model: torch.nn.Module,
 91        loss: torch.nn.Module,
 92        optimizer: torch.optim.Optimizer,
 93        metric: Callable,
 94        device: Union[str, torch.device],
 95        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 96        log_image_interval: int = 100,
 97        mixed_precision: bool = True,
 98        early_stopping: Optional[int] = None,
 99        logger=TensorboardLogger,
100        logger_kwargs: Optional[Dict[str, Any]] = None,
101        id_: Optional[str] = None,
102        save_root: Optional[str] = None,
103        compile_model: Optional[Union[bool, str]] = None,
104        rank: Optional[int] = None,
105    ):
106        if name is None and not issubclass(logger, WandbLogger):
107            raise TypeError("Name cannot be None if not using the WandbLogger")
108
109        self._generate_name = name is None
110        self.name = name
111        self.id_ = id_ or name
112        self.train_loader = train_loader
113        self.val_loader = val_loader
114        self.model = model
115        self.loss = loss
116        self.optimizer = optimizer
117        self.metric = metric
118        self.device = torch.device(device)
119        self.lr_scheduler = lr_scheduler
120        self.log_image_interval = log_image_interval
121        self.save_root = save_root
122        self.compile_model = compile_model
123        self.rank = rank
124
125        self._iteration = 0
126        self._epoch = 0
127        self._best_epoch = 0
128
129        self.mixed_precision = mixed_precision
130        self.early_stopping = early_stopping
131        self.train_time = 0.0
132
133        if mixed_precision:
134            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
135        else:
136            self.scaler = None
137
138        self.logger_class = logger
139        self.logger_kwargs = logger_kwargs
140        self.log_image_interval = log_image_interval
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
142    @property
143    def checkpoint_folder(self):
144        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
145        # Save_root enables saving the checkpoints somewhere else than in the local folder.
146        # This is handy for filesystems with limited space, where saving the checkpoints
147        # and log files can lead to running out of space.
148        save_root = getattr(self, "save_root", None)
149        return os.path.join("./checkpoints", self.id_) if save_root is None else\
150            os.path.join(save_root, "./checkpoints", self.id_)
iteration
152    @property
153    def iteration(self):
154        return self._iteration
epoch
156    @property
157    def epoch(self):
158        return self._epoch
def fit( self, iterations: Optional[int] = None, load_from_checkpoint: Union[str, os.PathLike, NoneType] = None, epochs: Optional[int] = None, save_every_kth_epoch: Optional[int] = None, progress=None, overwrite_training: bool = True):
644    def fit(
645        self,
646        iterations: Optional[int] = None,
647        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
648        epochs: Optional[int] = None,
649        save_every_kth_epoch: Optional[int] = None,
650        progress=None,
651        overwrite_training: bool = True,
652    ):
653        """Run neural network training.
654
655        Exactly one of 'iterations' or 'epochs' has to be passed.
656
657        Args:
658            iterations: How long to train, specified in iterations.
659            load_from_checkpoint: Path to a checkpoint from where training should be continued .
660            epochs: How long to train, specified in epochs.
661            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
662                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
663            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
664            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
665        """
666        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
667
668        if not overwrite_training:
669            if load_from_checkpoint is not None:
670                raise ValueError(
671                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
672                )
673
674            if self._verify_if_training_completed():
675                print(
676                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
677                    "and 'overwrite_training' is set to 'False'."
678                )
679                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
680                return
681
682        print(
683            "Start fitting for",
684            self.max_iteration - self._iteration,
685            "iterations / ",
686            self.max_epoch - self._epoch,
687            "epochs",
688        )
689        print("with", len(self.train_loader), "iterations per epoch")
690
691        if self.mixed_precision:
692            train_epoch = self._train_epoch_mixed
693            validate = self._validate_mixed
694            print("Training with mixed precision")
695        else:
696            train_epoch = self._train_epoch
697            validate = self._validate
698            print("Training with single precision")
699
700        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
701        if progress is None:
702            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
703        else:
704            progress.total = total_iterations
705            progress.set_description(f"Epoch {self._epoch}")
706
707        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
708        train_epochs = self.max_epoch - self._epoch
709        t_start = time.time()
710        for epoch in range(train_epochs):
711
712            # Ensure data is shuffled differently at each epoch.
713            try:
714                self.train_loader.sampler.set_epoch(epoch)
715            except AttributeError:
716                pass
717
718            # Run training and validation for this epoch
719            t_per_iter = train_epoch(progress)
720            current_metric = validate()
721
722            # perform all the post-epoch steps:
723
724            # apply the learning rate scheduler
725            if self.lr_scheduler is not None:
726                self.lr_scheduler.step(current_metric)
727
728            # how long did we train in total?
729            total_train_time = (time.time() - t_start) + self.train_time
730
731            # save this checkpoint as the new best checkpoint if
732            # it has the best overall validation metric
733            if current_metric < best_metric:
734                best_metric = current_metric
735                self._best_epoch = self._epoch
736                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
737
738            # save this checkpoint as the latest checkpoint
739            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
740
741            # if we save after every k-th epoch then check if we need to save now
742            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
743                self.save_checkpoint(
744                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
745                )
746
747            # if early stopping has been specified then check if the stopping condition is met
748            if self.early_stopping is not None:
749                epochs_since_best = self._epoch - self._best_epoch
750                if epochs_since_best > self.early_stopping:
751                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
752                    break
753
754            self._epoch += 1
755            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
756
757        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
758        print(f"The best epoch is number {self._best_epoch}.")
759
760        if self._generate_name:
761            self.name = None
762
763        # Update the train time
764        self.train_time = total_train_time
765
766        # TODO save the model to wandb if we have the wandb logger
767        if isinstance(self.logger, WandbLogger):
768            self.logger.get_wandb().finish()

Run neural network training.

Exactly one of 'iterations' or 'epochs' has to be passed.

Arguments:
  • iterations: How long to train, specified in iterations.
  • load_from_checkpoint: Path to a checkpoint from where training should be continued.
  • epochs: How long to train, specified in epochs.
  • save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
  • progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
  • overwrite_training: Whether to overwrite existing checkpoints in the save directory.
class DefaultTrainer.Deserializer:
160    class Deserializer:
161        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.
162
163        Examples:
164            To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class.
165            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
166
167            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
168            >>> class MyTrainer(DefaultTrainer):
169            >>>     def __init__(self, *args, the_answer: int, **kwargs):
170            >>>         super().__init__(*args, **kwargs)
171            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
172            >>>                                       # see DefaultTrainer.Serializer
173            >>>
174            >>>     class Deserializer(DefaultTrainer.Deserializer):
175            >>>         def load_the_answer(self):
176            >>>             generic_answer = self.init_data["the_answer"]
177            >>>             # (device dependent) special deserialization
178            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
179            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
180            >>>             else:
181            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
182
183        Args:
184            init_data: The initialization data of the trainer.
185            save_path: The path where the checkpoint was saved.
186            device: The device.
187        """
188
189        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
190            self.init_data = init_data
191            self.save_path = save_path
192            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
193            self.trainer_kwargs: Dict[str, Any] = dict(
194                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
195            )
196
197        def load(self, kwarg_name: str, optional):
198            """@private
199            """
200            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
201            if kwarg_name == "device":
202                pass  # deserialized in __init__
203            elif kwarg_name.endswith("_loader"):
204                self.load_data_loader(kwarg_name, optional)
205            else:
206                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
207                load(kwarg_name, optional=optional)
208
209        def load_data_loader(self, loader_name, optional) -> None:
210            """@private
211            """
212            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
213            if ds is None and optional:
214                return
215
216            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
217            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
218            # monkey patch shuffle loader_name to the loader
219            loader.shuffle = loader_kwargs.get("shuffle", False)
220            self.trainer_kwargs[loader_name] = loader
221
222        def load_generic(
223            self,
224            kwarg_name: str,
225            *dynamic_args: Dict,
226            optional: bool,
227            only_class: bool = False,
228            dynamic_kwargs: Optional[Dict[str, Any]] = None,
229        ) -> None:
230            """@private
231            """
232            if kwarg_name in self.init_data:
233                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
234                return
235
236            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
237            if this_cls is None:
238                if optional:
239                    return
240                else:
241                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
242
243            assert isinstance(this_cls, str), this_cls
244            assert "." in this_cls, this_cls
245            cls_p, cls_m = this_cls.rsplit(".", 1)
246            this_cls = getattr(import_module(cls_p), cls_m)
247            if only_class:
248                self.trainer_kwargs[kwarg_name] = this_cls
249            else:
250                self.trainer_kwargs[kwarg_name] = this_cls(
251                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
252                )
253
254        def load_name(self, kwarg_name: str, optional: bool):
255            """@private
256            """
257            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
258
259        def load_optimizer(self, kwarg_name: str, optional: bool):
260            """@private
261            """
262            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
263
264        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
265            """@private
266            """
267            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
268
269        # todo: remove and rename kwarg 'logger' to 'logger_class'
270        def load_logger(self, kwarg_name: str, optional: bool):
271            """@private
272            """
273            assert kwarg_name == "logger"
274            self.load_generic("logger", optional=optional, only_class=True)

Determines how to deserialize the trainer kwargs from serialized 'init_data'.

Examples:

To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class. Note that DefaultTrainer.Deserializer.load_generic() covers most cases already.

This example adds the_answer kwarg, which requires 'calculations' upon initialization:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
>>>                                       # see DefaultTrainer.Serializer
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self):
>>>             generic_answer = self.init_data["the_answer"]
>>>             # (device dependent) special deserialization
>>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
>>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
>>>             else:
>>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
Arguments:
  • init_data: The initialization data of the trainer.
  • save_path: The path where the checkpoint was saved.
  • device: The device.
DefaultTrainer.Deserializer(init_data: Dict, save_path: str, device: Union[str, torch.device])
189        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
190            self.init_data = init_data
191            self.save_path = save_path
192            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
193            self.trainer_kwargs: Dict[str, Any] = dict(
194                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
195            )
init_data
save_path
trainer_kwargs: Dict[str, Any]
class DefaultTrainer.Serializer:
class Serializer:
    """Implements how to serialize trainer kwargs from a trainer instance.

    Examples:
        To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
        Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
        called by the `dump()` method when appropriate cover most cases already.

        This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
        'the_answer' attribute:
        >>> class MyTrainer(DefaultTrainer):
        >>>     def __init__(self, *args, the_answer: int, **kwargs):
        >>>         super().__init__(*args, **kwargs)
        >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
        >>>         # but let's make things more interesting...
        >>>         self.the = the_answer // 10
        >>>         self.answer = the_answer % 10
        >>>
        >>>     class Serializer(DefaultTrainer.Serializer):
        >>>         trainer: MyTrainer
        >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
        >>>             assert kwarg_name == "the_answer"
        >>>             # populate self.init_data with the serialized data required by Deserializer
        >>>             # to restore the trainer kwargs
        >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

        This example with both Serializer and Deserializer adds `the_answer` kwarg,
        while saving it in two separate entries 'the' and 'answer'
        >>> class MyTrainer(DefaultTrainer):
        >>>     def __init__(self, *args, the_answer: int, **kwargs):
        >>>         super().__init__(*args, **kwargs)
        >>>         self.the_answer = the_answer
        >>>
        >>>     class Serializer(DefaultTrainer.Serializer):
        >>>         trainer: MyTrainer
        >>>         def dump_the_answer(self, kwarg_name: str):
        >>>             assert kwarg_name == "the_answer"
        >>>             self.init_data.update({
        >>>                 "the": self.trainer.the_answer // 10,
        >>>                 "answer": self.trainer.the_answer % 10
        >>>             })
        >>>
        >>>     class Deserializer(DefaultTrainer.Deserializer):
        >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
        >>>             assert kwarg_name == "the_answer"
        >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
        >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

    Args:
        trainer: The trainer instance.
    """

    def __init__(self, trainer: DefaultTrainer):
        self.trainer = trainer
        # Populated with the serialized kwargs during the serialization process.
        self.init_data: Dict[str, Any] = {}

    def dump(self, kwarg_name: str) -> None:
        """@private
        """
        # Dispatch order: custom dump_<kwarg> method > data loader > class kwarg >
        # builtin value > generic instance.
        dumper = getattr(self, f"dump_{kwarg_name}", None)
        if dumper is not None:
            dumper(kwarg_name)
        elif kwarg_name.endswith("_loader"):
            self.dump_data_loader(kwarg_name)
        elif kwarg_name.endswith("_class"):
            self.dump_generic_class(kwarg_name)
        elif not hasattr(self.trainer, kwarg_name):
            raise AttributeError(
                f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
            )
        else:
            obj = getattr(self.trainer, kwarg_name)
            # Exact type check on purpose (not isinstance): subclasses of builtins
            # may carry extra state that would not round-trip through init_data.
            if obj is None or type(obj) in (
                bool, bytearray, bytes, dict, float, frozenset, int, list, set, str, tuple,
            ):
                self.dump_generic_builtin(kwarg_name)
            else:
                self.dump_generic_instance(kwarg_name)

    def dump_generic_builtin(self, kwarg_name: str) -> None:
        """@private
        """
        # Builtin values are stored in init_data as-is.
        assert hasattr(self.trainer, kwarg_name)
        self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

    def dump_generic_class(self, kwarg_name: str) -> None:
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        assert kwarg_name.endswith("_class")
        obj = getattr(self.trainer, kwarg_name)
        # Store the fully qualified class name so the Deserializer can re-import it.
        self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

    def dump_generic_instance(self, kwarg_name: str) -> None:
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        instance = getattr(self.trainer, kwarg_name)
        # Serialize as class path + constructor kwargs so the instance can be re-created on load.
        self.init_data.update(
            {
                f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
            }
        )

    def dump_device(self, kwarg_name: str):
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        # torch.device is stored via its string representation (e.g. 'cuda:0').
        self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

    def dump_data_loader(self, kwarg_name: str) -> None:
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        loader = getattr(self.trainer, kwarg_name)
        if loader is None:
            return
        # Save the underlying dataset and the loader's constructor kwargs
        # instead of the loader object itself.
        self.init_data.update(
            {
                f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
            }
        )

    def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
        """@private
        """
        self.dump_generic_class(f"{kwarg_name}_class")

    def dump_model(self, kwarg_name: str):
        """@private
        """
        if is_compiled(self.trainer.model):
            # A compiled model wraps the original; dump the class and kwargs
            # that the trainer recorded before compilation.
            self.init_data.update(
                {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
            )
        else:
            self.dump_generic_instance("model")

Implements how to serialize trainer kwargs from a trainer instance.

Examples:

To extend the serialization process you can inherit from this Serializer in a derived Trainer class. Note that the methods dump_generic_builtin(), dump_generic_class() and dump_generic_instance() called by the dump() method when appropriate cover most cases already.

This example adds the_answer kwarg, which requires extra steps on dumping only because we don't keep a 'the_answer' attribute:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
>>>         # but let's make things more interesting...
>>>         self.the = the_answer // 10
>>>         self.answer = the_answer % 10
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
>>>             assert kwarg_name == "the_answer"
>>>             # populate self.init_data with the serialized data required by Deserializer
>>>             # to restore the trainer kwargs
>>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

This example with both Serializer and Deserializer adds the_answer kwarg, while saving it in two separate entries 'the' and 'answer'

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str):
>>>             assert kwarg_name == "the_answer"
>>>             self.init_data.update({
>>>                 "the": self.trainer.the_answer // 10,
>>>                 "answer": self.trainer.the_answer % 10
>>>             })
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self, kwarg_name: str, optional: bool):
>>>             assert kwarg_name == "the_answer"
>>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
Arguments:
  • trainer: The trainer instance.
DefaultTrainer.Serializer(trainer: DefaultTrainer)
378        def __init__(self, trainer: DefaultTrainer):
379            self.trainer = trainer
380            self.init_data = {}  # to be populated during serialization process
trainer
init_data