torch_em.trainer.default_trainer

  1from __future__ import annotations
  2
  3import contextlib
  4import inspect
  5import os
  6import time
  7import warnings
  8from collections import OrderedDict
  9from importlib import import_module
 10from typing import Any, Callable, Dict, Optional, Union
 11
 12import numpy as np
 13import torch
 14import torch.cuda.amp as amp
 15from tqdm import tqdm
 16
 17from .tensorboard_logger import TensorboardLogger
 18from .wandb_logger import WandbLogger
 19from ..util import auto_compile, get_constructor_arguments, is_compiled
 20
 21
 22class DefaultTrainer:
 23    """Trainer class for 2d/3d training on a single GPU."""
 24
 25    def __init__(
 26        self,
 27        name: Optional[str],
 28        train_loader: torch.utils.data.DataLoader,
 29        val_loader: torch.utils.data.DataLoader,
 30        model: torch.nn.Module,
 31        loss,
 32        optimizer,
 33        metric,
 34        device: Union[str, torch.device],
 35        lr_scheduler=None,
 36        log_image_interval=100,
 37        mixed_precision=True,
 38        early_stopping=None,
 39        logger=TensorboardLogger,
 40        logger_kwargs: Optional[Dict[str, Any]] = None,
 41        id_: Optional[str] = None,
 42        save_root: Optional[str] = None,
 43        compile_model: Optional[Union[bool, str]] = None,
 44    ):
 45        if name is None and not issubclass(logger, WandbLogger):
 46            raise TypeError("Name cannot be None if not using the WandbLogger")
 47
 48        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
 49            raise ValueError(f"{self.__class__} requires each dataloader to have a 'shuffle' attribute.")
 50
 51        self._generate_name = name is None
 52        self.name = name
 53        self.id_ = id_ or name
 54        self.train_loader = train_loader
 55        self.val_loader = val_loader
 56        self.model = model
 57        self.loss = loss
 58        self.optimizer = optimizer
 59        self.metric = metric
 60        self.device = device
 61        self.lr_scheduler = lr_scheduler
 62        self.log_image_interval = log_image_interval
 63        self.save_root = save_root
 64        self.compile_model = compile_model
 65
 66        self._iteration = 0
 67        self._epoch = 0
 68        self._best_epoch = 0
 69
 70        self.mixed_precision = mixed_precision
 71        self.early_stopping = early_stopping
 72        self.train_time = 0.0
 73
 74        self.scaler = amp.GradScaler() if mixed_precision else None
 75
 76        self.logger_class = logger
 77        self.logger_kwargs = logger_kwargs
 78        self.log_image_interval = log_image_interval
 79
 80    @property  # because the logger may generate and set trainer.id_ on logger.__init__
 81    def checkpoint_folder(self):
 82        assert self.id_ is not None
 83        # save_root enables saving the checkpoints somewhere other than the local
 84        # folder. This is handy for filesystems with limited space, where saving the checkpoints
 85        # and log files can easily lead to running out of space.
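           # For illustration (paths hypothetical): with save_root=None and id_="my-model" the checkpoints
           # go to ./checkpoints/my-model; with save_root="/scratch" they end up under /scratch/checkpoints/my-model.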
 86        save_root = getattr(self, "save_root", None)
 87        return os.path.join("./checkpoints", self.id_) if save_root is None else\
 88            os.path.join(save_root, "./checkpoints", self.id_)
 89
 90    @property
 91    def iteration(self):
 92        return self._iteration
 93
 94    @property
 95    def epoch(self):
 96        return self._epoch
 97
 98    class Deserializer:
 99        """Determines how to deserialize the trainer kwargs from serialized 'init_data'
100
101        Examples:
102            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
103            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
104
105            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
106            >>> class MyTrainer(DefaultTrainer):
107            >>>     def __init__(self, *args, the_answer: int, **kwargs):
108            >>>         super().__init__(*args, **kwargs)
109            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
110            >>>                                       # see DefaultTrainer.Serializer
111            >>>
112            >>>     class Deserializer(DefaultTrainer.Deserializer):
113            >>>         def load_the_answer(self):
114            >>>             generic_answer = self.init_data["the_answer"]
115            >>>             # (device dependent) special deserialization
116            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
117            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
118            >>>             else:
119            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
120        """
121
122        def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
123            self.init_data = init_data
124            self.save_path = save_path
125            # populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'
126            self.trainer_kwargs: Dict[str, Any] = dict(
127                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
128            )
129
130        def load(self, kwarg_name: str, optional):
131            """`optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'"""
132
133            if kwarg_name == "device":
134                pass  # deserialized in __init__
135            elif kwarg_name.endswith("_loader"):
136                self.load_data_loader(kwarg_name, optional)
137            else:
138                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
139                load(kwarg_name, optional=optional)
140
141        def load_data_loader(self, loader_name, optional) -> None:
142            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
143            if ds is None and optional:
144                return
145
146            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
147            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
148            # monkey patch the shuffle attribute onto the loader
149            loader.shuffle = loader_kwargs.get("shuffle", False)
150            self.trainer_kwargs[loader_name] = loader
151
152        def load_generic(
153            self,
154            kwarg_name: str,
155            *dynamic_args,
156            optional: bool,
157            only_class: bool = False,
158            dynamic_kwargs: Optional[Dict[str, Any]] = None,
159        ) -> None:
160            if kwarg_name in self.init_data:
161                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
162                return
163
164            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
165            if this_cls is None:
166                if optional:
167                    return
168                else:
169                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
170
171            assert isinstance(this_cls, str), this_cls
172            assert "." in this_cls, this_cls
173            cls_p, cls_m = this_cls.rsplit(".", 1)
174            this_cls = getattr(import_module(cls_p), cls_m)
175            if only_class:
176                self.trainer_kwargs[kwarg_name] = this_cls
177            else:
178                self.trainer_kwargs[kwarg_name] = this_cls(
179                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
180                )
181
182        def load_name(self, kwarg_name: str, optional: bool):
183            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
184
185        def load_optimizer(self, kwarg_name: str, optional: bool):
186            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
187
188        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
189            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
190
191        # todo: remove and rename kwarg 'logger' to 'logger_class'
192        def load_logger(self, kwarg_name: str, optional: bool):
193            assert kwarg_name == "logger"
194            self.load_generic("logger", optional=optional, only_class=True)
195
196    @staticmethod
197    def _get_save_dict(save_path, device):
198        if not os.path.exists(save_path):
199            raise ValueError(f"Cannot find checkpoint {save_path}")
200        return torch.load(save_path, map_location=device)
201
202    @classmethod
203    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
204        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
205        # make sure the correct device is set if we don't have access to CUDA
206        if not torch.cuda.is_available():
207            device = "cpu"
208        save_dict = cls._get_save_dict(save_path, device)
209        deserializer = cls.Deserializer(save_dict["init"], save_path, device)
210
211        has_kwargs = False
212        deserialized = []
213        for name, parameter in inspect.signature(cls).parameters.items():
214            if name == "kwargs":
215                has_kwargs = True
216                continue
217            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
218            deserialized.append(name)
219
220        # to deserialize kwargs we can't rely on inspecting the signature, so we
221        # go through the remaining kwarg names in the init data instead
222        if has_kwargs:
223            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
224            for name in kwarg_names:
225                if name.endswith("_kwargs"):
226                    continue
227                elif name.endswith("_dataset"):
228                    deserializer.load(name.replace("dataset", "loader"), optional=False)
229                elif name.endswith("_class"):
230                    deserializer.load(name.replace("_class", ""), optional=False)
231                else:
232                    deserializer.load(name, optional=False)
233
234        trainer = cls(**deserializer.trainer_kwargs)
235        trainer._initialize(0, save_dict)
236        trainer._is_initialized = True
237        return trainer
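           # Minimal usage sketch (the checkpoint folder name "my-model" is hypothetical):
           #   trainer = DefaultTrainer.from_checkpoint("./checkpoints/my-model", name="best")
           #   model = trainer.model  # carries the weights of the "best" checkpoint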
238
239    class Serializer:
240        """Implements how to serialize trainer kwargs from a trainer instance
241
242        Examples:
243            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
244            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
245            called by the `dump()` method when appropriate cover most cases already.
246
247            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
248            'the_answer' attribute:
249            >>> class MyTrainer(DefaultTrainer):
250            >>>     def __init__(self, *args, the_answer: int, **kwargs):
251            >>>         super().__init__(*args, **kwargs)
252            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
253            >>>         # but let's make things more interesting...
254            >>>         self.the = the_answer // 10
255            >>>         self.answer = the_answer % 10
256            >>>
257            >>>     class Serializer(DefaultTrainer.Serializer):
258            >>>         trainer: MyTrainer
259            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
260            >>>             assert kwarg_name == "the_answer"
261            >>>             # populate self.init_data with the serialized data required by Deserializer
262            >>>             # to restore the trainer kwargs
263            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer
264
265            This example with both Serializer and Deserializer adds `the_answer` kwarg,
266            while saving it in two separate entries 'the' and 'answer'
267            >>> class MyTrainer(DefaultTrainer):
268            >>>     def __init__(self, *args, the_answer: int, **kwargs):
269            >>>         super().__init__(*args, **kwargs)
270            >>>         self.the_answer = the_answer
271            >>>
272            >>>     class Serializer(DefaultTrainer.Serializer):
273            >>>         trainer: MyTrainer
274            >>>         def dump_the_answer(self, kwarg_name: str):
275            >>>             assert kwarg_name == "the_answer"
276            >>>             self.init_data.update({
277            >>>                 "the": self.trainer.the_answer // 10,
278            >>>                 "answer": self.trainer.the_answer % 10
279            >>>             })
280            >>>
281            >>>     class Deserializer(DefaultTrainer.Deserializer):
282            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
283            >>>             assert kwarg_name == "the_answer"
284            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
285            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
286        """
287
288        def __init__(self, trainer: DefaultTrainer):
289            self.trainer = trainer
290            self.init_data = {}  # to be populated during serialization process
291
292        def dump(self, kwarg_name: str) -> None:
293            dumper = getattr(self, f"dump_{kwarg_name}", None)
294            if dumper is not None:
295                dumper(kwarg_name)
296            elif kwarg_name.endswith("_loader"):
297                self.dump_data_loader(kwarg_name)
298            elif kwarg_name.endswith("_class"):
299                self.dump_generic_class(kwarg_name)
300            elif not hasattr(self.trainer, kwarg_name):
301                raise AttributeError(
302                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
303                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
304                )
305            else:
306                assert hasattr(self.trainer, kwarg_name)
307                obj = getattr(self.trainer, kwarg_name)
308                if obj is None or type(obj) in (
309                    bool,
310                    bytearray,
311                    bytes,
312                    dict,
313                    float,
314                    frozenset,
315                    int,
316                    list,
317                    set,
318                    str,
319                    tuple,
320                ):
321                    self.dump_generic_builtin(kwarg_name)
322                else:
323                    self.dump_generic_instance(kwarg_name)
324
325        def dump_generic_builtin(self, kwarg_name: str) -> None:
326            assert hasattr(self.trainer, kwarg_name)
327            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)
328
329        def dump_generic_class(self, kwarg_name: str) -> None:
330            assert hasattr(self.trainer, kwarg_name)
331            assert kwarg_name.endswith("_class")
332            obj = getattr(self.trainer, kwarg_name)
333            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"
334
335        def dump_generic_instance(self, kwarg_name: str) -> None:
336            assert hasattr(self.trainer, kwarg_name)
337            instance = getattr(self.trainer, kwarg_name)
338            self.init_data.update(
339                {
340                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
341                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
342                }
343            )
344
345        def dump_device(self, kwarg_name: str):
346            assert hasattr(self.trainer, kwarg_name)
347            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))
348
349        def dump_data_loader(self, kwarg_name: str) -> None:
350            assert hasattr(self.trainer, kwarg_name)
351            loader = getattr(self.trainer, kwarg_name)
352            if loader is None:
353                return
354            self.init_data.update(
355                {
356                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
357                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
358                }
359            )
360
361        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
362            self.dump_generic_class(f"{kwarg_name}_class")
363
364        def dump_model(self, kwarg_name: str):
365            if is_compiled(self.trainer.model):
366                self.init_data.update(
367                    {
368                        "model_class": self.trainer._model_class,
369                        "model_kwargs": self.trainer._model_kwargs,
370                    }
371                )
372            else:
373                self.dump_generic_instance("model")
374
375    def _build_init(self) -> Dict[str, Any]:
376        serializer = self.Serializer(self)
377        for name in inspect.signature(self.__class__).parameters:
378            # special rules to serialize kwargs
379            # if a trainer class inherits from DefaultTrainer and has **kwargs
380            # they need to be saved in self._kwargs
381            if name == "kwargs":
382                if not hasattr(self, "_kwargs"):
383                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
384                          "Please add self._kwargs to its __init__ function"
385                    raise RuntimeError(msg)
386                kwargs = getattr(self, "_kwargs")
387                for kwarg_name in kwargs:
388                    serializer.dump(kwarg_name)
389                continue
390            serializer.dump(name)
391
392        return serializer.init_data
393
394    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
395        assert self.train_loader is not None
396        assert self.val_loader is not None
397        assert self.model is not None
398        assert self.loss is not None
399        assert self.optimizer is not None
400        assert self.metric is not None
401        assert self.device is not None
402
403        if load_from_checkpoint is not None:
404            self.load_checkpoint(load_from_checkpoint)
405
406        if sum((iterations is not None, epochs is not None)) != 1:
407            raise ValueError(
408                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer. "
409                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
410            )
411
412        if epochs is None:
413            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
414        else:
415            iterations = epochs * len(self.train_loader)
416
417        self.max_iteration = self._iteration + iterations
418        self.max_epoch = self._epoch + epochs
419
420        if not getattr(self, "_is_initialized", False):
421            # check if we compile the model (only supported by pytorch 2)
422            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
423            if is_compiled(self.model):
424                warnings.warn(
425                    "You have passed a compiled model to the trainer. "
426                    "It will not be possible to (de)serialize the trainer with it. "
427                    "If you want to be able to do this, please pass the uncompiled model. "
428                    "It can then be compiled automatically by setting 'compile_model' to True."
429                )
430            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
431            self._model_kwargs = get_constructor_arguments(self.model)
432            self.model = auto_compile(self.model, self.compile_model)
433
434            self.model.to(self.device)
435            self.loss.to(self.device)
436
437            # this saves all the information that is necessary
438            # to fully load the trainer from the checkpoint
439            self.init_data = self._build_init()
440
441            if self.logger_class is None:
442                self.logger = None
443            else:
444                # may set self.name if self.name is None
445                save_root = getattr(self, "save_root", None)
446                self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
447
448            try:
449                os.makedirs(self.checkpoint_folder, exist_ok=True)
450            except PermissionError:
451                warnings.warn(
452                    f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
453                    "The most likely reason for this is that you copied the checkpoint somewhere else, "
454                    "so we skip this error to enable loading the model from this checkpoint."
455                )
456                pass
457
458        best_metric = np.inf
459        return best_metric
460
461    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
462        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
463        extra_init_dict = extra_save_dict.pop("init", {})
464        save_dict = {
465            "iteration": self._iteration,
466            "epoch": self._epoch,
467            "best_epoch": self._best_epoch,
468            "best_metric": best_metric,
469            "current_metric": current_metric,
470            "model_state": self.model.state_dict(),
471            "optimizer_state": self.optimizer.state_dict(),
472            "init": self.init_data | extra_init_dict,
473            "train_time": train_time,
474        }
475        save_dict.update(**extra_save_dict)
476        if self.scaler is not None:
477            save_dict.update({"scaler_state": self.scaler.state_dict()})
478        if self.lr_scheduler is not None:
479            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
480        torch.save(save_dict, save_path)
481
482    def load_checkpoint(self, checkpoint="best"):
483        if isinstance(checkpoint, str):
484            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
485            if not os.path.exists(save_path):
486                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
487                return
488            save_dict = torch.load(save_path)
489        elif isinstance(checkpoint, dict):
490            save_dict = checkpoint
491        else:
492            raise RuntimeError
493
494        self._iteration = save_dict["iteration"]
495        self._epoch = save_dict["epoch"]
496        self._best_epoch = save_dict["best_epoch"]
497        self.best_metric = save_dict["best_metric"]
498        self.current_metric = save_dict["current_metric"]
499        self.train_time = save_dict.get("train_time", 0.0)
500
501        model_state = save_dict["model_state"]
502        # to enable loading compiled models
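           # (torch.compile prefixes the parameter names, e.g. '_orig_mod.conv.weight' -> 'conv.weight')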
503        compiled_prefix = "_orig_mod."
504        model_state = OrderedDict(
505            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
506        )
507        self.model.load_state_dict(model_state)
508        # we need to send the network to the device before loading the optimizer state!
509        self.model.to(self.device)
510
511        self.optimizer.load_state_dict(save_dict["optimizer_state"])
512        if self.scaler is not None:
513            self.scaler.load_state_dict(save_dict["scaler_state"])
514        if self.lr_scheduler is not None:
515            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
516
517        return save_dict
518
519    def fit(
520        self,
521        iterations=None,
522        load_from_checkpoint=None,
523        epochs=None,
524        save_every_kth_epoch=None,
525        progress=None,
526    ):
527        """Run neural network training.
528
529        Exactly one of 'iterations' or 'epochs' has to be passed.
530
531        Parameters:
532            iterations [int] - how long to train, specified in iterations (default: None)
533            load_from_checkpoint [str] - path to a checkpoint from where training should be continued (default: None)
534            epochs [int] - how long to train, specified in epochs (default: None)
535            save_every_kth_epoch [int] - save checkpoints after every kth epoch separately.
536                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
537            progress [progress_bar] - optional progress bar for integration with external tools.
538                Expected to follow the tqdm interface.
539        """
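           # Illustrative calls (assuming a constructed 'trainer' instance):
           #   trainer.fit(iterations=10000) or trainer.fit(epochs=25, save_every_kth_epoch=5)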
540        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
541        print(
542            "Start fitting for",
543            self.max_iteration - self._iteration,
544            "iterations / ",
545            self.max_epoch - self._epoch,
546            "epochs",
547        )
548        print("with", len(self.train_loader), "iterations per epoch")
549
550        if self.mixed_precision:
551            train_epoch = self._train_epoch_mixed
552            validate = self._validate_mixed
553            print("Training with mixed precision")
554        else:
555            train_epoch = self._train_epoch
556            validate = self._validate
557            print("Training with single precision")
558
559        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
560        if progress is None:
561            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
562        else:
563            progress.total = total_iterations
564            progress.set_description(f"Epoch {self._epoch}")
565
566        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
567        train_epochs = self.max_epoch - self._epoch
568        t_start = time.time()
569        for _ in range(train_epochs):
570
571            # run training and validation for this epoch
572            t_per_iter = train_epoch(progress)
573            current_metric = validate()
574
575            # perform all the post-epoch steps:
576
577            # apply the learning rate scheduler
578            if self.lr_scheduler is not None:
579                self.lr_scheduler.step(current_metric)
580
581            # how long did we train in total?
582            total_train_time = (time.time() - t_start) + self.train_time
583
584            # save this checkpoint as the new best checkpoint if
585            # it has the best overall validation metric
586            if current_metric < best_metric:
587                best_metric = current_metric
588                self._best_epoch = self._epoch
589                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
590
591            # save this checkpoint as the latest checkpoint
592            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
593
594            # if we save after every k-th epoch then check if we need to save now
595            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
596                self.save_checkpoint(
597                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
598                )
599
600            # if early stopping has been specified then check if the stopping condition is met
601            if self.early_stopping is not None:
602                epochs_since_best = self._epoch - self._best_epoch
603                if epochs_since_best > self.early_stopping:
604                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
605                    break
606
607            self._epoch += 1
608            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
609
610        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
611        print(f"The best epoch is number {self._best_epoch}.")
612
613        if self._generate_name:
614            self.name = None
615
616        # Update the train time
617        self.train_time = total_train_time
618
619        # TODO save the model to wandb if we have the wandb logger
620        if isinstance(self.logger, WandbLogger):
621            self.logger.get_wandb().finish()
622
623    def _backprop(self, loss):
624        loss.backward()
625        self.optimizer.step()
626
627    def _backprop_mixed(self, loss):
628        self.scaler.scale(loss).backward()
629        self.scaler.step(self.optimizer)
630        self.scaler.update()
631
632    def _train_epoch(self, progress):
633        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
634
635    def _train_epoch_mixed(self, progress):
636        return self._train_epoch_impl(progress, amp.autocast, self._backprop_mixed)
637
638    def _forward_and_loss(self, x, y):
639        pred = self.model(x)
640        if self._iteration % self.log_image_interval == 0:
641            if pred.requires_grad:
642                pred.retain_grad()
643
644        loss = self.loss(pred, y)
645        return pred, loss
646
647    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
648        self.model.train()
649
650        n_iter = 0
651        t_per_iter = time.time()
652        for x, y in self.train_loader:
653            x, y = x.to(self.device), y.to(self.device)
654
655            self.optimizer.zero_grad()
656
657            with forward_context():
658                pred, loss = self._forward_and_loss(x, y)
659
660            backprop(loss)
661
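               # report the learning rate of the first parameter group together with the training loss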
662            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
663            if self.logger is not None:
664                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
665
666            self._iteration += 1
667            n_iter += 1
668            if self._iteration >= self.max_iteration:
669                break
670            progress.update(1)
671
672        t_per_iter = (time.time() - t_per_iter) / n_iter
673        return t_per_iter
674
675    def _validate(self):
676        return self._validate_impl(contextlib.nullcontext)
677
678    def _validate_mixed(self):
679        return self._validate_impl(amp.autocast)
680
681    def _validate_impl(self, forward_context):
682        self.model.eval()
683
684        metric_val = 0.0
685        loss_val = 0.0
686
687        with torch.no_grad():
688            for x, y in self.val_loader:
689                x, y = x.to(self.device), y.to(self.device)
690                with forward_context():
691                    pred, loss = self._forward_and_loss(x, y)
692                    metric = self.metric(pred, y)
693
694                loss_val += loss.item()
695                metric_val += metric.item()
696
697        metric_val /= len(self.val_loader)
698        loss_val /= len(self.val_loader)
699        if self.logger is not None:
700            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
701        return metric_val
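

A minimal end-to-end usage sketch (not part of the module source): the toy model, data and the
"example-model" name are placeholders, and the logger is disabled to avoid a tensorboard dependency;
real projects would plug in torch_em datasets, loaders and networks instead.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from torch_em.trainer.default_trainer import DefaultTrainer

    # Toy data and model, standing in for torch_em datasets and networks.
    x, y = torch.rand(8, 1, 64, 64), torch.rand(8, 1, 64, 64)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=2, shuffle=True)
    val_loader = DataLoader(TensorDataset(x, y), batch_size=2)
    # The trainer expects each loader to expose a 'shuffle' attribute.
    train_loader.shuffle, val_loader.shuffle = True, False

    model = torch.nn.Sequential(
        torch.nn.Conv2d(1, 8, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(8, 1, 3, padding=1)
    )
    loss = torch.nn.MSELoss()

    trainer = DefaultTrainer(
        name="example-model",
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        loss=loss,
        optimizer=torch.optim.Adam(model.parameters(), lr=1.0e-4),
        metric=loss,
        device="cuda" if torch.cuda.is_available() else "cpu",
        mixed_precision=torch.cuda.is_available(),
        logger=None,
    )
    trainer.fit(epochs=2)  # checkpoints are written to ./checkpoints/example-model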
605                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
606                    break
607
608            self._epoch += 1
609            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
610
611        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
612        print(f"The best epoch is number {self._best_epoch}.")
613
614        if self._generate_name:
615            self.name = None
616
617        # Update the train time
618        self.train_time = total_train_time
619
620        # TODO save the model to wandb if we have the wandb logger
621        if isinstance(self.logger, WandbLogger):
622            self.logger.get_wandb().finish()
623
624    def _backprop(self, loss):
625        loss.backward()
626        self.optimizer.step()
627
628    def _backprop_mixed(self, loss):
629        self.scaler.scale(loss).backward()
630        self.scaler.step(self.optimizer)
631        self.scaler.update()
632
633    def _train_epoch(self, progress):
634        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
635
636    def _train_epoch_mixed(self, progress):
637        return self._train_epoch_impl(progress, amp.autocast, self._backprop_mixed)
638
639    def _forward_and_loss(self, x, y):
640        pred = self.model(x)
641        if self._iteration % self.log_image_interval == 0:
642            if pred.requires_grad:
643                pred.retain_grad()
644
645        loss = self.loss(pred, y)
646        return pred, loss
647
648    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
649        self.model.train()
650
651        n_iter = 0
652        t_per_iter = time.time()
653        for x, y in self.train_loader:
654            x, y = x.to(self.device), y.to(self.device)
655
656            self.optimizer.zero_grad()
657
658            with forward_context():
659                pred, loss = self._forward_and_loss(x, y)
660
661            backprop(loss)
662
663            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
664            if self.logger is not None:
665                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
666
667            self._iteration += 1
668            n_iter += 1
669            if self._iteration >= self.max_iteration:
670                break
671            progress.update(1)
672
673        t_per_iter = (time.time() - t_per_iter) / n_iter
674        return t_per_iter
675
676    def _validate(self):
677        return self._validate_impl(contextlib.nullcontext)
678
679    def _validate_mixed(self):
680        return self._validate_impl(amp.autocast)
681
682    def _validate_impl(self, forward_context):
683        self.model.eval()
684
685        metric_val = 0.0
686        loss_val = 0.0
687
688        with torch.no_grad():
689            for x, y in self.val_loader:
690                x, y = x.to(self.device), y.to(self.device)
691                with forward_context():
692                    pred, loss = self._forward_and_loss(x, y)
693                    metric = self.metric(pred, y)
694
695                loss_val += loss.item()
696                metric_val += metric.item()
697
698        metric_val /= len(self.val_loader)
699        loss_val /= len(self.val_loader)
700        if self.logger is not None:
701            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
702        return metric_val

Trainer class for 2d/3d training on a single GPU.

DefaultTrainer( name: Optional[str], train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, model: torch.nn.modules.module.Module, loss, optimizer, metric, device: Union[str, torch.device], lr_scheduler=None, log_image_interval=100, mixed_precision=True, early_stopping=None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None)
26    def __init__(
27        self,
28        name: Optional[str],
29        train_loader: torch.utils.data.DataLoader,
30        val_loader: torch.utils.data.DataLoader,
31        model: torch.nn.Module,
32        loss,
33        optimizer,
34        metric,
35        device: Union[str, torch.device],
36        lr_scheduler=None,
37        log_image_interval=100,
38        mixed_precision=True,
39        early_stopping=None,
40        logger=TensorboardLogger,
41        logger_kwargs: Optional[Dict[str, Any]] = None,
42        id_: Optional[str] = None,
43        save_root: Optional[str] = None,
44        compile_model: Optional[Union[bool, str]] = None,
45    ):
46        if name is None and not issubclass(logger, WandbLogger):
47            raise TypeError("Name cannot be None if not using the WandbLogger")
48
49        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
50            raise ValueError(f"{self.__class__} requires each dataloader to have 'shuffle' attribute.")
51
52        self._generate_name = name is None
53        self.name = name
54        self.id_ = id_ or name
55        self.train_loader = train_loader
56        self.val_loader = val_loader
57        self.model = model
58        self.loss = loss
59        self.optimizer = optimizer
60        self.metric = metric
61        self.device = device
62        self.lr_scheduler = lr_scheduler
63        self.log_image_interval = log_image_interval
64        self.save_root = save_root
65        self.compile_model = compile_model
66
67        self._iteration = 0
68        self._epoch = 0
69        self._best_epoch = 0
70
71        self.mixed_precision = mixed_precision
72        self.early_stopping = early_stopping
73        self.train_time = 0.0
74
75        self.scaler = amp.GradScaler() if mixed_precision else None
76
77        self.logger_class = logger
78        self.logger_kwargs = logger_kwargs
79        self.log_image_interval = log_image_interval
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
mixed_precision
early_stopping
train_time
scaler
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
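
A minimal construction sketch; the toy data and the stand-in model below are hypothetical placeholders (real projects would typically build loaders and models with torch_em's own utilities):

import torch

from torch_em.trainer.default_trainer import DefaultTrainer

# hypothetical toy data: random 2d images with binary targets
x = torch.rand(8, 1, 64, 64)
y = (torch.rand(8, 1, 64, 64) > 0.5).float()
train_ds = torch.utils.data.TensorDataset(x, y)
val_ds = torch.utils.data.TensorDataset(x, y)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=2)
# DefaultTrainer expects a 'shuffle' attribute on both loaders
train_loader.shuffle, val_loader.shuffle = True, False

model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for a real segmentation model
trainer = DefaultTrainer(
    name="toy-model",
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    loss=torch.nn.BCEWithLogitsLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    metric=torch.nn.BCEWithLogitsLoss(),  # any callable returning a 'lower is better' score
    device="cuda" if torch.cuda.is_available() else "cpu",
    mixed_precision=torch.cuda.is_available(),
)
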
@classmethod
def from_checkpoint(cls, checkpoint_folder, name='best', device=None):
203    @classmethod
204    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
205        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
206        # make sure the correct device is set if we don't have access to CUDA
207        if not torch.cuda.is_available():
208            device = "cpu"
209        save_dict = cls._get_save_dict(save_path, device)
210        deserializer = cls.Deserializer(save_dict["init"], save_path, device)
211
212        has_kwargs = False
213        deserialized = []
214        for name, parameter in inspect.signature(cls).parameters.items():
215            if name == "kwargs":
216                has_kwargs = True
217                continue
218            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
219            deserialized.append(name)
220
221        # to deserialize kwargs we can't rely on inspecting the signature, so we
222        # go through the remaining kwarg names in init data instead
223        if has_kwargs:
224            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
225            for name in kwarg_names:
226                if name.endswith("_kwargs"):
227                    continue
228                elif name.endswith("_dataset"):
229                    deserializer.load(name.replace("dataset", "loader"), optional=False)
230                elif name.endswith("_class"):
231                    deserializer.load(name.replace("_class", ""), optional=False)
232                else:
233                    deserializer.load(name, optional=False)
234
235        trainer = cls(**deserializer.trainer_kwargs)
236        trainer._initialize(0, save_dict)
237        trainer._is_initialized = True
238        return trainer
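
A short usage sketch (the checkpoint folder below is hypothetical); note that the device falls back to the CPU automatically when CUDA is not available:

from torch_em.trainer.default_trainer import DefaultTrainer

# re-create the trainer, including the model, from the serialized 'init' data in best.pt
trainer = DefaultTrainer.from_checkpoint("./checkpoints/toy-model", name="best")
model = trainer.model
model.eval()  # e.g. for running inference with the restored model
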
def save_checkpoint( self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
462    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
463        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
464        extra_init_dict = extra_save_dict.pop("init", {})
465        save_dict = {
466            "iteration": self._iteration,
467            "epoch": self._epoch,
468            "best_epoch": self._best_epoch,
469            "best_metric": best_metric,
470            "current_metric": current_metric,
471            "model_state": self.model.state_dict(),
472            "optimizer_state": self.optimizer.state_dict(),
473            "init": self.init_data | extra_init_dict,
474            "train_time": train_time,
475        }
476        save_dict.update(**extra_save_dict)
477        if self.scaler is not None:
478            save_dict.update({"scaler_state": self.scaler.state_dict()})
479        if self.lr_scheduler is not None:
480            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
481        torch.save(save_dict, save_path)
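
Checkpoints other than 'best' and 'latest' can also be written manually; any extra keyword arguments are stored as additional entries in the saved dict. A hedged sketch (the name and metric values are made up):

# writes <checkpoint_folder>/stage1.pt with an extra 'stage' entry in the save dict
trainer.save_checkpoint("stage1", current_metric=0.42, best_metric=0.38, stage=1)
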
def load_checkpoint(self, checkpoint='best'):
483    def load_checkpoint(self, checkpoint="best"):
484        if isinstance(checkpoint, str):
485            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
486            if not os.path.exists(save_path):
487                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
488                return
489            save_dict = torch.load(save_path)
490        elif isinstance(checkpoint, dict):
491            save_dict = checkpoint
492        else:
493            raise RuntimeError
494
495        self._iteration = save_dict["iteration"]
496        self._epoch = save_dict["epoch"]
497        self._best_epoch = save_dict["best_epoch"]
498        self.best_metric = save_dict["best_metric"]
499        self.current_metric = save_dict["current_metric"]
500        self.train_time = save_dict.get("train_time", 0.0)
501
502        model_state = save_dict["model_state"]
503        # to enable loading compiled models
504        compiled_prefix = "_orig_mod."
505        model_state = OrderedDict(
506            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
507        )
508        self.model.load_state_dict(model_state)
509        # we need to send the network to the device before loading the optimizer state!
510        self.model.to(self.device)
511
512        self.optimizer.load_state_dict(save_dict["optimizer_state"])
513        if self.scaler is not None:
514            self.scaler.load_state_dict(save_dict["scaler_state"])
515        if self.lr_scheduler is not None:
516            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
517
518        return save_dict
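
A short sketch of restoring the trainer state in place; load_checkpoint returns the loaded dict, or None (after a warning) if the checkpoint file does not exist:

save_dict = trainer.load_checkpoint("latest")
if save_dict is not None:
    print("resuming from iteration", save_dict["iteration"], "in epoch", save_dict["epoch"])
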
def fit( self, iterations=None, load_from_checkpoint=None, epochs=None, save_every_kth_epoch=None, progress=None):
520    def fit(
521        self,
522        iterations=None,
523        load_from_checkpoint=None,
524        epochs=None,
525        save_every_kth_epoch=None,
526        progress=None,
527    ):
528        """Run neural network training.
529
530        Exactly one of 'iterations' or 'epochs' has to be passed.
531
532        Parameters:
533            iterations [int] - how long to train, specified in iterations (default: None)
534            load_from_checkpoint [str] - path to a checkpoint from where training should be continued (default: None)
535            epochs [int] - how long to train, specified in epochs (default: None)
536            save_every_kth_epoch [int] - save checkpoints after every kth epoch separately.
537                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
538            progress [progress_bar] - optional progress bar for integration with external tools.
539                Expected to follow the tqdm interface.
540        """
541        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
542        print(
543            "Start fitting for",
544            self.max_iteration - self._iteration,
545            "iterations / ",
546            self.max_epoch - self._epoch,
547            "epochs",
548        )
549        print("with", len(self.train_loader), "iterations per epoch")
550
551        if self.mixed_precision:
552            train_epoch = self._train_epoch_mixed
553            validate = self._validate_mixed
554            print("Training with mixed precision")
555        else:
556            train_epoch = self._train_epoch
557            validate = self._validate
558            print("Training with single precision")
559
560        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
561        if progress is None:
562            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
563        else:
564            progress.total = total_iterations
565            progress.set_description(f"Epoch {self._epoch}")
566
567        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
568        train_epochs = self.max_epoch - self._epoch
569        t_start = time.time()
570        for _ in range(train_epochs):
571
572            # run training and validation for this epoch
573            t_per_iter = train_epoch(progress)
574            current_metric = validate()
575
576            # perform all the post-epoch steps:
577
578            # apply the learning rate scheduler
579            if self.lr_scheduler is not None:
580                self.lr_scheduler.step(current_metric)
581
582            # how long did we train in total?
583            total_train_time = (time.time() - t_start) + self.train_time
584
585            # save this checkpoint as the new best checkpoint if
586            # it has the best overall validation metric
587            if current_metric < best_metric:
588                best_metric = current_metric
589                self._best_epoch = self._epoch
590                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
591
592            # save this checkpoint as the latest checkpoint
593            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
594
595            # if we save after every k-th epoch then check if we need to save now
596            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
597                self.save_checkpoint(
598                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
599                )
600
601            # if early stopping has been specified then check if the stopping condition is met
602            if self.early_stopping is not None:
603                epochs_since_best = self._epoch - self._best_epoch
604                if epochs_since_best > self.early_stopping:
605                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
606                    break
607
608            self._epoch += 1
609            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
610
611        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
612        print(f"The best epoch is number {self._best_epoch}.")
613
614        if self._generate_name:
615            self.name = None
616
617        # Update the train time
618        self.train_time = total_train_time
619
620        # TODO save the model to wandb if we have the wandb logger
621        if isinstance(self.logger, WandbLogger):
622            self.logger.get_wandb().finish()

Run neural network training.

Exactly one of 'iterations' or 'epochs' has to be passed.

Arguments:
  • iterations [int] - how long to train, specified in iterations (default: None)
  • load_from_checkpoint [str] - path to a checkpoint from where training should be continued (default: None)
  • epochs [int] - how long to train, specified in epochs (default: None)
  • save_every_kth_epoch [int] - save checkpoints after every kth epoch separately. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
  • progress [progress_bar] - optional progress bar for integration with external tools. Expected to follow the tqdm interface.
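
Continuing the hypothetical trainer from the construction sketch above, a typical call looks as follows (the numbers are illustrative only):

# train for a fixed number of iterations and also keep a separate checkpoint every 5th epoch
trainer.fit(iterations=10000, save_every_kth_epoch=5)

# or specify the training length in epochs instead (pass exactly one of the two)
# trainer.fit(epochs=25)
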
class DefaultTrainer.Deserializer:
 99    class Deserializer:
100        """Determines how to deserialize the trainer kwargs from serialized 'init_data'
101
102        Examples:
103            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
104            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
105
106            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
107            >>> class MyTrainer(DefaultTrainer):
108            >>>     def __init__(self, *args, the_answer: int, **kwargs):
109            >>>         super().__init__(*args, **kwargs)
110            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
111            >>>                                       # see DefaultTrainer.Serializer
112            >>>
113            >>>     class Deserializer(DefaultTrainer.Deserializer):
114            >>>         def load_the_answer(self):
115            >>>             generic_answer = self.init_data["the_answer"]
116            >>>             # (device dependent) special deserialization
117            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
118            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
119            >>>             else:
120            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
121        """
122
123        def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
124            self.init_data = init_data
125            self.save_path = save_path
126            # populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'
127            self.trainer_kwargs: Dict[str, Any] = dict(
128                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
129            )
130
131        def load(self, kwarg_name: str, optional):
132            """`optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'"""
133
134            if kwarg_name == "device":
135                pass  # deserialized in __init__
136            elif kwarg_name.endswith("_loader"):
137                self.load_data_loader(kwarg_name, optional)
138            else:
139                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
140                load(kwarg_name, optional=optional)
141
142        def load_data_loader(self, loader_name, optional) -> None:
143            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
144            if ds is None and optional:
145                return
146
147            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
148            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
149            # monkey patch the shuffle attribute onto the loader
150            loader.shuffle = loader_kwargs.get("shuffle", False)
151            self.trainer_kwargs[loader_name] = loader
152
153        def load_generic(
154            self,
155            kwarg_name: str,
156            *dynamic_args,
157            optional: bool,
158            only_class: bool = False,
159            dynamic_kwargs: Optional[Dict[str, Any]] = None,
160        ) -> None:
161            if kwarg_name in self.init_data:
162                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
163                return
164
165            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
166            if this_cls is None:
167                if optional:
168                    return
169                else:
170                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
171
172            assert isinstance(this_cls, str), this_cls
173            assert "." in this_cls, this_cls
174            cls_p, cls_m = this_cls.rsplit(".", 1)
175            this_cls = getattr(import_module(cls_p), cls_m)
176            if only_class:
177                self.trainer_kwargs[kwarg_name] = this_cls
178            else:
179                self.trainer_kwargs[kwarg_name] = this_cls(
180                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
181                )
182
183        def load_name(self, kwarg_name: str, optional: bool):
184            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
185
186        def load_optimizer(self, kwarg_name: str, optional: bool):
187            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
188
189        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
190            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
191
192        # todo: remove and rename kwarg 'logger' to 'logger_class'
193        def load_logger(self, kwarg_name: str, optional: bool):
194            assert kwarg_name == "logger"
195            self.load_generic("logger", optional=optional, only_class=True)

Determines how to deserialize the trainer kwargs from serialized 'init_data'

Examples:

To extend the initialization process you can inherit from this Deserializer in a derived Trainer class. Note that DefaultTrainer.Deserializer.load_generic() covers most cases already.

This example adds the_answer kwarg, which requires 'calculations' upon initialization:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
>>>                                       # see DefaultTrainer.Serializer
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self):
>>>             generic_answer = self.init_data["the_answer"]
>>>             # (device dependent) special deserialization
>>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
>>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
>>>             else:
>>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
DefaultTrainer.Deserializer(init_data: dict, save_path: str, device: Union[str, torch.device])
123        def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
124            self.init_data = init_data
125            self.save_path = save_path
126            # populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'
127            self.trainer_kwargs: Dict[str, Any] = dict(
128                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
129            )
init_data
save_path
trainer_kwargs: Dict[str, Any]
def load(self, kwarg_name: str, optional):
131        def load(self, kwarg_name: str, optional):
132            """`optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'"""
133
134            if kwarg_name == "device":
135                pass  # deserialized in __init__
136            elif kwarg_name.endswith("_loader"):
137                self.load_data_loader(kwarg_name, optional)
138            else:
139                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
140                load(kwarg_name, optional=optional)

optional is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'

def load_data_loader(self, loader_name, optional) -> None:
142        def load_data_loader(self, loader_name, optional) -> None:
143            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
144            if ds is None and optional:
145                return
146
147            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
148            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
149            # monkey patch the shuffle attribute onto the loader
150            loader.shuffle = loader_kwargs.get("shuffle", False)
151            self.trainer_kwargs[loader_name] = loader
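
For reference, a hedged sketch of the init_data entries this method consumes and how the loader is rebuilt from them (the dataset and kwargs are made up):

import torch

dummy_dataset = torch.utils.data.TensorDataset(torch.rand(4, 1, 32, 32))  # hypothetical dataset
init_data = {
    "train_dataset": dummy_dataset,
    "train_loader_kwargs": {"batch_size": 2, "shuffle": True},
}
# load_data_loader("train_loader", ...) then reconstructs the loader roughly like this:
loader = torch.utils.data.DataLoader(init_data["train_dataset"], **init_data["train_loader_kwargs"])
loader.shuffle = init_data["train_loader_kwargs"].get("shuffle", False)
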
def load_generic( self, kwarg_name: str, *dynamic_args, optional: bool, only_class: bool = False, dynamic_kwargs: Optional[Dict[str, Any]] = None) -> None:
153        def load_generic(
154            self,
155            kwarg_name: str,
156            *dynamic_args,
157            optional: bool,
158            only_class: bool = False,
159            dynamic_kwargs: Optional[Dict[str, Any]] = None,
160        ) -> None:
161            if kwarg_name in self.init_data:
162                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
163                return
164
165            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
166            if this_cls is None:
167                if optional:
168                    return
169                else:
170                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
171
172            assert isinstance(this_cls, str), this_cls
173            assert "." in this_cls, this_cls
174            cls_p, cls_m = this_cls.rsplit(".", 1)
175            this_cls = getattr(import_module(cls_p), cls_m)
176            if only_class:
177                self.trainer_kwargs[kwarg_name] = this_cls
178            else:
179                self.trainer_kwargs[kwarg_name] = this_cls(
180                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
181                )
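
The class resolution in load_generic splits the stored dotted path and imports the module; a standalone sketch of that mechanism ("torch.optim.Adam" is only an example path):

from importlib import import_module

import torch

cls_path = "torch.optim.Adam"  # e.g. as stored under init_data["optimizer_class"]
module_name, cls_name = cls_path.rsplit(".", 1)
optimizer_class = getattr(import_module(module_name), cls_name)
assert optimizer_class is torch.optim.Adam
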
def load_name(self, kwarg_name: str, optional: bool):
183        def load_name(self, kwarg_name: str, optional: bool):
184            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
def load_optimizer(self, kwarg_name: str, optional: bool):
186        def load_optimizer(self, kwarg_name: str, optional: bool):
187            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
def load_lr_scheduler(self, kwarg_name: str, optional: bool):
189        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
190            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
def load_logger(self, kwarg_name: str, optional: bool):
193        def load_logger(self, kwarg_name: str, optional: bool):
194            assert kwarg_name == "logger"
195            self.load_generic("logger", optional=optional, only_class=True)
class DefaultTrainer.Serializer:
240    class Serializer:
241        """Implements how to serialize trainer kwargs from a trainer instance
242
243        Examples:
244            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
245            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`,
246            which are called by the `dump()` method when appropriate, already cover most cases.
247
248            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
249            'the_answer' attribute:
250            >>> class MyTrainer(DefaultTrainer):
251            >>>     def __init__(self, *args, the_answer: int, **kwargs):
252            >>>         super().__init__(*args, **kwargs)
253            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
254            >>>         # but let's make things more interesting...
255            >>>         self.the = the_answer // 10
256            >>>         self.answer = the_answer % 10
257            >>>
258            >>>     class Serializer(DefaultTrainer.Serializer):
259            >>>         trainer: MyTrainer
260            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
261            >>>             assert kwarg_name == "the_answer"
262            >>>             # populate self.init_data with the serialized data required by Deserializer
263            >>>             # to restore the trainer kwargs
264            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer
265
266            This example with both Serializer and Deserializer adds `the_answer` kwarg,
267            while saving it in two separate entries 'the' and 'answer'
268            >>> class MyTrainer(DefaultTrainer):
269            >>>     def __init__(self, *args, the_answer: int, **kwargs):
270            >>>         super().__init__(*args, **kwargs)
271            >>>         self.the_answer = the_answer
272            >>>
273            >>>     class Serializer(DefaultTrainer.Serializer):
274            >>>         trainer: MyTrainer
275            >>>         def dump_the_answer(self, kwarg_name: str):
276            >>>             assert kwarg_name == "the_answer"
277            >>>             self.init_data.update({
278            >>>                 "the": self.trainer.the_answer // 10,
279            >>>                 "answer": self.trainer.the_answer % 10
280            >>>             })
281            >>>
282            >>>     class Deserializer(DefaultTrainer.Deserializer):
283            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
284            >>>             assert kwarg_name == "the_answer"
285            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
286            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
287        """
288
289        def __init__(self, trainer: DefaultTrainer):
290            self.trainer = trainer
291            self.init_data = {}  # to be populated during serialization process
292
293        def dump(self, kwarg_name: str) -> None:
294            dumper = getattr(self, f"dump_{kwarg_name}", None)
295            if dumper is not None:
296                dumper(kwarg_name)
297            elif kwarg_name.endswith("_loader"):
298                self.dump_data_loader(kwarg_name)
299            elif kwarg_name.endswith("_class"):
300                self.dump_generic_class(kwarg_name)
301            elif not hasattr(self.trainer, kwarg_name):
302                raise AttributeError(
303                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
304                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
305                )
306            else:
307                assert hasattr(self.trainer, kwarg_name)
308                obj = getattr(self.trainer, kwarg_name)
309                if obj is None or type(obj) in (
310                    bool,
311                    bytearray,
312                    bytes,
313                    dict,
314                    float,
315                    frozenset,
316                    int,
317                    list,
318                    set,
319                    str,
320                    tuple,
321                ):
322                    self.dump_generic_builtin(kwarg_name)
323                else:
324                    self.dump_generic_instance(kwarg_name)
325
326        def dump_generic_builtin(self, kwarg_name: str) -> None:
327            assert hasattr(self.trainer, kwarg_name)
328            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)
329
330        def dump_generic_class(self, kwarg_name: str) -> None:
331            assert hasattr(self.trainer, kwarg_name)
332            assert kwarg_name.endswith("_class")
333            obj = getattr(self.trainer, kwarg_name)
334            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"
335
336        def dump_generic_instance(self, kwarg_name: str) -> None:
337            assert hasattr(self.trainer, kwarg_name)
338            instance = getattr(self.trainer, kwarg_name)
339            self.init_data.update(
340                {
341                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
342                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
343                }
344            )
345
346        def dump_device(self, kwarg_name: str):
347            assert hasattr(self.trainer, kwarg_name)
348            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))
349
350        def dump_data_loader(self, kwarg_name: str) -> None:
351            assert hasattr(self.trainer, kwarg_name)
352            loader = getattr(self.trainer, kwarg_name)
353            if loader is None:
354                return
355            self.init_data.update(
356                {
357                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
358                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
359                }
360            )
361
362        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
363            self.dump_generic_class(f"{kwarg_name}_class")
364
365        def dump_model(self, kwarg_name: str):
366            if is_compiled(self.trainer.model):
367                self.init_data.update(
368                    {
369                        "model_class": self.trainer._model_class,
370                        "model_kwargs": self.trainer._model_kwargs,
371                    }
372                )
373            else:
374                self.dump_generic_instance("model")

Implements how to serialize trainer kwargs from a trainer instance

Examples:

To extend the serialization process you can inherit from this Serializer in a derived Trainer class. Note that the methods dump_generic_builtin(), dump_generic_class() and dump_generic_instance(), which are called by the dump() method when appropriate, already cover most cases.

This example adds the_answer kwarg, which requires extra steps on dumping only because we don't keep a 'the_answer' attribute:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
>>>         # but let's make things more interesting...
>>>         self.the = the_answer // 10
>>>         self.answer = the_answer % 10
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
>>>             assert kwarg_name == "the_answer"
>>>             # populate self.init_data with the serialized data required by Deserializer
>>>             # to restore the trainer kwargs
>>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

This example with both Serializer and Deserializer adds the_answer kwarg, while saving it in two separate entries 'the' and 'answer'

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str):
>>>             assert kwarg_name == "the_answer"
>>>             self.init_data.update({
>>>                 "the": self.trainer.the_answer // 10,
>>>                 "answer": self.trainer.the_answer % 10
>>>             })
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self, kwarg_name: str, optional: bool):
>>>             assert kwarg_name == "the_answer"
>>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
DefaultTrainer.Serializer(trainer: DefaultTrainer)
289        def __init__(self, trainer: DefaultTrainer):
290            self.trainer = trainer
291            self.init_data = {}  # to be populated during serialization process
trainer
init_data
def dump(self, kwarg_name: str) -> None:
293        def dump(self, kwarg_name: str) -> None:
294            dumper = getattr(self, f"dump_{kwarg_name}", None)
295            if dumper is not None:
296                dumper(kwarg_name)
297            elif kwarg_name.endswith("_loader"):
298                self.dump_data_loader(kwarg_name)
299            elif kwarg_name.endswith("_class"):
300                self.dump_generic_class(kwarg_name)
301            elif not hasattr(self.trainer, kwarg_name):
302                raise AttributeError(
303                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
304                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
305                )
306            else:
307                assert hasattr(self.trainer, kwarg_name)
308                obj = getattr(self.trainer, kwarg_name)
309                if obj is None or type(obj) in (
310                    bool,
311                    bytearray,
312                    bytes,
313                    dict,
314                    float,
315                    frozenset,
316                    int,
317                    list,
318                    set,
319                    str,
320                    tuple,
321                ):
322                    self.dump_generic_builtin(kwarg_name)
323                else:
324                    self.dump_generic_instance(kwarg_name)
def dump_generic_builtin(self, kwarg_name: str) -> None:
326        def dump_generic_builtin(self, kwarg_name: str) -> None:
327            assert hasattr(self.trainer, kwarg_name)
328            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)
def dump_generic_class(self, kwarg_name: str) -> None:
330        def dump_generic_class(self, kwarg_name: str) -> None:
331            assert hasattr(self.trainer, kwarg_name)
332            assert kwarg_name.endswith("_class")
333            obj = getattr(self.trainer, kwarg_name)
334            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"
def dump_generic_instance(self, kwarg_name: str) -> None:
336        def dump_generic_instance(self, kwarg_name: str) -> None:
337            assert hasattr(self.trainer, kwarg_name)
338            instance = getattr(self.trainer, kwarg_name)
339            self.init_data.update(
340                {
341                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
342                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
343                }
344            )
def dump_device(self, kwarg_name: str):
346        def dump_device(self, kwarg_name: str):
347            assert hasattr(self.trainer, kwarg_name)
348            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))
def dump_data_loader(self, kwarg_name: str) -> None:
350        def dump_data_loader(self, kwarg_name: str) -> None:
351            assert hasattr(self.trainer, kwarg_name)
352            loader = getattr(self.trainer, kwarg_name)
353            if loader is None:
354                return
355            self.init_data.update(
356                {
357                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
358                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
359                }
360            )
def dump_logger(self, kwarg_name: str):
362        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
363            self.dump_generic_class(f"{kwarg_name}_class")
def dump_model(self, kwarg_name: str):
365        def dump_model(self, kwarg_name: str):
366            if is_compiled(self.trainer.model):
367                self.init_data.update(
368                    {
369                        "model_class": self.trainer._model_class,
370                        "model_kwargs": self.trainer._model_kwargs,
371                    }
372                )
373            else:
374                self.dump_generic_instance("model")
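
To see how Serializer and Deserializer fit together, here is a hedged round-trip sketch for a single kwarg ('loss' in this case; the trainer instance and the save path are assumed from the earlier construction sketch):

# serialize: 'loss' is a generic instance, so dump() records loss_class and loss_kwargs
serializer = DefaultTrainer.Serializer(trainer)
serializer.dump("loss")
init_data = serializer.init_data

# deserialize: the dotted class path is imported again and the loss is re-instantiated
deserializer = DefaultTrainer.Deserializer(init_data, save_path="./checkpoints/toy-model/best.pt", device="cpu")
deserializer.load("loss", optional=False)
restored_loss = deserializer.trainer_kwargs["loss"]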