torch_em.trainer.default_trainer

  1from __future__ import annotations
  2
  3import os
  4import time
  5import inspect
  6import warnings
  7import contextlib
  8from tqdm import tqdm
  9from functools import partial
 10from datetime import datetime
 11from collections import OrderedDict
 12from importlib import import_module
 13from typing import Any, Callable, Dict, Optional, Union, Literal
 14
 15import numpy as np
 16
 17import torch
 18
 19from .wandb_logger import WandbLogger
 20from .tensorboard_logger import TensorboardLogger
 21from ..util import auto_compile, get_constructor_arguments, is_compiled
 22
 23
 24class DefaultTrainer:
 25    """Trainer class for training a segmentation network.
 26
 27    The trainer class implements the core logic for training a network with pytorch.
 28    It implements a training loop to run training and validation, which is started with `fit`.
 29    The checkpoints and logs from the training run will be saved in the current working directory,
 30    or in the directory specified by `save_root`. Training can be continued from a checkpoint
 31    by passing its location to the `load_from_checkpoint` argument of `fit` (see the resume sketch below).
 32
 33    A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
 34    Alternatively, the trainer class can also be instantiated as in this example:
 35    ```python
 36    import torch
 37    from torch_em.loss import DiceLoss
 38    from torch_em.model import UNet2d
 39    from torch_em.data.datasets.light_microscopy import get_dsb_loader
 40    from torch_em.trainer import DefaultTrainer
 41
 42    # The training data will be downloaded to this location.
 43    data_root = "/path/to/save/the/training/data"
 44    patch_shape = (256, 256)
 45
 46    # Create the model and optimizer.
 47    model = UNet2d(in_channels=1, out_channels=1)
 48    optimizer = torch.optim.AdamW(model.parameters())
 49
 50    trainer = DefaultTrainer(
 51        name="unet-training",
 52        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
 53        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
 54        model=model,
 55        loss=DiceLoss(),  # The loss function.
 56        optimizer=optimizer,
 57        metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
 58        device="cuda",  # The device to use for training.
 59    )
 60    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
 61    ```
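
    Training can also be resumed from an earlier run of the same trainer. A minimal sketch, assuming the run
    above already wrote checkpoints ("latest.pt" / "best.pt") to "./checkpoints/unet-training":
    ```python
    # Continue from the latest checkpoint and train for another 10,000 iterations.
    trainer.fit(iterations=int(1e4), load_from_checkpoint="latest")
    ```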
 62
 63    Args:
 64        name: The name of the checkpoint that will be created by the trainer.
 65        train_loader: The data loader containing the training data.
 66        val_loader: The data loader containing the validation data.
 67        model: The model to train.
 68        loss: The loss function for training.
 69        optimizer: The optimizer.
 70        metric: The metric for validation.
 71        device: The torch device to use for training. If None, will use a GPU if available.
 72        lr_scheduler: The learning rate scheduler.
 73        log_image_interval: The interval for saving images during logging, in training iterations.
 74        mixed_precision: Whether to train with mixed precision.
 75        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
 76        logger: The logger class. Will be instantiated for logging.
 77            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
 78        logger_kwargs: The keyword arguments for the logger class.
 79        id_: Unique identifier for the trainer. If None then `name` will be used.
 80        save_root: The root folder for saving the checkpoint and logs.
 81        compile_model: Whether to compile the model before training.
 82        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
 83    """
 84    def __init__(
 85        self,
 86        name: Optional[str],
 87        train_loader: torch.utils.data.DataLoader,
 88        val_loader: torch.utils.data.DataLoader,
 89        model: torch.nn.Module,
 90        loss: torch.nn.Module,
 91        optimizer: torch.optim.Optimizer,
 92        metric: Callable,
 93        device: Union[str, torch.device],
 94        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 95        log_image_interval: int = 100,
 96        mixed_precision: bool = True,
 97        early_stopping: Optional[int] = None,
 98        logger=TensorboardLogger,
 99        logger_kwargs: Optional[Dict[str, Any]] = None,
100        id_: Optional[str] = None,
101        save_root: Optional[str] = None,
102        compile_model: Optional[Union[bool, str]] = None,
103        rank: Optional[int] = None,
104    ):
105        if name is None and not issubclass(logger, WandbLogger):
106            raise TypeError("Name cannot be None if not using the WandbLogger")
107
108        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
109            raise ValueError(f"{self.__class__} requires each dataloader to have a 'shuffle' attribute.")
110
111        self._generate_name = name is None
112        self.name = name
113        self.id_ = id_ or name
114        self.train_loader = train_loader
115        self.val_loader = val_loader
116        self.model = model
117        self.loss = loss
118        self.optimizer = optimizer
119        self.metric = metric
120        self.device = torch.device(device)
121        self.lr_scheduler = lr_scheduler
122        self.log_image_interval = log_image_interval
123        self.save_root = save_root
124        self.compile_model = compile_model
125        self.rank = rank
126
127        self._iteration = 0
128        self._epoch = 0
129        self._best_epoch = 0
130
131        self.mixed_precision = mixed_precision
132        self.early_stopping = early_stopping
133        self.train_time = 0.0
134
135        if mixed_precision:
136            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
137        else:
138            self.scaler = None
139
140        self.logger_class = logger
141        self.logger_kwargs = logger_kwargs
142        self.log_image_interval = log_image_interval
143
144    @property
145    def checkpoint_folder(self):
146        assert self.id_ is not None  # Because the logger may generate and set trainer.id_ on logger.__init__.
147        # The save_root argument enables saving the checkpoints somewhere other than the local folder.
148        # This is handy for filesystems with limited space, where saving the checkpoints
149        # and log files can lead to running out of space.
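        # For example (hypothetical values): save_root="/scratch" and id_="unet-training"
        # resolve to "/scratch/./checkpoints/unet-training", i.e. "/scratch/checkpoints/unet-training".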
150        save_root = getattr(self, "save_root", None)
151        return os.path.join("./checkpoints", self.id_) if save_root is None else\
152            os.path.join(save_root, "./checkpoints", self.id_)
153
154    @property
155    def iteration(self):
156        return self._iteration
157
158    @property
159    def epoch(self):
160        return self._epoch
161
162    class Deserializer:
163        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.
164
165        Examples:
166            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
167            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
168
169            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
170            >>> class MyTrainer(DefaultTrainer):
171            >>>     def __init__(self, *args, the_answer: int, **kwargs):
172            >>>         super().__init__(*args, **kwargs)
173            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
174            >>>                                       # see DefaultTrainer.Serializer
175            >>>
176            >>>     class Deserializer(DefaultTrainer.Deserializer):
177            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
178            >>>             generic_answer = self.init_data["the_answer"]
179            >>>             # (device dependent) special deserialization
180            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
181            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
182            >>>             else:
183            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
184
185        Args:
186            init_data: The initialization data of the trainer.
187            save_path: The path where the checkpoint was saved.
188            device: The device.
189        """
190
191        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
192            self.init_data = init_data
193            self.save_path = save_path
194            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
195            self.trainer_kwargs: Dict[str, Any] = dict(
196                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
197            )
198
199        def load(self, kwarg_name: str, optional):
200            """@private
201            """
202            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
203            if kwarg_name == "device":
204                pass  # deserialized in __init__
205            elif kwarg_name.endswith("_loader"):
206                self.load_data_loader(kwarg_name, optional)
207            else:
208                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
209                load(kwarg_name, optional=optional)
210
211        def load_data_loader(self, loader_name, optional) -> None:
212            """@private
213            """
214            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
215            if ds is None and optional:
216                return
217
218            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
219            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
220            # Monkey patch the shuffle attribute onto the loader.
221            loader.shuffle = loader_kwargs.get("shuffle", False)
222            self.trainer_kwargs[loader_name] = loader
223
224        def load_generic(
225            self,
226            kwarg_name: str,
227            *dynamic_args: Dict,
228            optional: bool,
229            only_class: bool = False,
230            dynamic_kwargs: Optional[Dict[str, Any]] = None,
231        ) -> None:
232            """@private
233            """
234            if kwarg_name in self.init_data:
235                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
236                return
237
238            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
239            if this_cls is None:
240                if optional:
241                    return
242                else:
243                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
244
245            assert isinstance(this_cls, str), this_cls
246            assert "." in this_cls, this_cls
247            cls_p, cls_m = this_cls.rsplit(".", 1)
248            this_cls = getattr(import_module(cls_p), cls_m)
249            if only_class:
250                self.trainer_kwargs[kwarg_name] = this_cls
251            else:
252                self.trainer_kwargs[kwarg_name] = this_cls(
253                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
254                )
255
256        def load_name(self, kwarg_name: str, optional: bool):
257            """@private
258            """
259            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
260
261        def load_optimizer(self, kwarg_name: str, optional: bool):
262            """@private
263            """
264            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
265
266        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
267            """@private
268            """
269            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
270
271        # todo: remove and rename kwarg 'logger' to 'logger_class'
272        def load_logger(self, kwarg_name: str, optional: bool):
273            """@private
274            """
275            assert kwarg_name == "logger"
276            self.load_generic("logger", optional=optional, only_class=True)
277
278    @staticmethod
279    def _get_save_dict(save_path, device):
280        if not os.path.exists(save_path):
281            raise ValueError(f"Cannot find checkpoint {save_path}")
282        return torch.load(save_path, map_location=device, weights_only=False)
283
284    @classmethod
285    def from_checkpoint(
286        cls,
287        checkpoint_folder: Union[os.PathLike, str],
288        name: Literal["best", "latest"] = "best",
289        device: Optional[Union[str, torch.device]] = None,
290    ):
291        """@private
292        """
293        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
294        # make sure the correct device is set if we don't have access to CUDA
295        if not torch.cuda.is_available():
296            device = "cpu"
297        save_dict = cls._get_save_dict(save_path, device)
298        deserializer = cls.Deserializer(save_dict["init"], save_path, device)
299
300        has_kwargs = False
301        deserialized = []
302        for name, parameter in inspect.signature(cls).parameters.items():
303            if name == "kwargs":
304                has_kwargs = True
305                continue
306            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
307            deserialized.append(name)
308
309        # To deserialize kwargs we can't rely on inspecting the signature, so we
310        # go through the remaining kwarg names in the init data instead.
311        if has_kwargs:
312            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
313            for name in kwarg_names:
314                if name.endswith("_kwargs"):
315                    continue
316                elif name.endswith("_dataset"):
317                    deserializer.load(name.replace("dataset", "loader"), optional=False)
318                elif name.endswith("_class"):
319                    deserializer.load(name.replace("_class", ""), optional=False)
320                else:
321                    deserializer.load(name, optional=False)
322
323        trainer = cls(**deserializer.trainer_kwargs)
324        trainer._initialize(0, save_dict)
325        trainer._is_initialized = True
326        return trainer
327
328    class Serializer:
329        """Implements how to serialize trainer kwargs from a trainer instance.
330
331        Examples:
332            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
333            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`,
334            which are called by the `dump()` method when appropriate, cover most cases already.
335
336            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
337            'the_answer' attribute:
338            >>> class MyTrainer(DefaultTrainer):
339            >>>     def __init__(self, *args, the_answer: int, **kwargs):
340            >>>         super().__init__(*args, **kwargs)
341            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
342            >>>         # but let's make things more interesting...
343            >>>         self.the = the_answer // 10
344            >>>         self.answer = the_answer % 10
345            >>>
346            >>>     class Serializer(DefaultTrainer.Serializer):
347            >>>         trainer: MyTrainer
348            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
349            >>>             assert kwarg_name == "the_answer"
350            >>>             # populate self.init_data with the serialized data required by Deserializer
351            >>>             # to restore the trainer kwargs
352            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer
353
354            This example with both Serializer and Deserializer adds `the_answer` kwarg,
355            while saving it in two separate entries 'the' and 'answer':
356            >>> class MyTrainer(DefaultTrainer):
357            >>>     def __init__(self, *args, the_answer: int, **kwargs):
358            >>>         super().__init__(*args, **kwargs)
359            >>>         self.the_answer = the_answer
360            >>>
361            >>>     class Serializer(DefaultTrainer.Serializer):
362            >>>         trainer: MyTrainer
363            >>>         def dump_the_answer(self, kwarg_name: str):
364            >>>             assert kwarg_name == "the_answer"
365            >>>             self.init_data.update({
366            >>>                 "the": self.trainer.the_answer // 10,
367            >>>                 "answer": self.trainer.the_answer % 10
368            >>>             })
369            >>>
370            >>>     class Deserializer(DefaultTrainer.Deserializer):
371            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
372            >>>             assert kwarg_name == "the_answer"
373            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
374            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
375
376        Args:
377            trainer: The trainer instance.
378        """
379
380        def __init__(self, trainer: DefaultTrainer):
381            self.trainer = trainer
382            self.init_data = {}  # to be populated during serialization process
383
384        def dump(self, kwarg_name: str) -> None:
385            """@private
386            """
387            dumper = getattr(self, f"dump_{kwarg_name}", None)
388            if dumper is not None:
389                dumper(kwarg_name)
390            elif kwarg_name.endswith("_loader"):
391                self.dump_data_loader(kwarg_name)
392            elif kwarg_name.endswith("_class"):
393                self.dump_generic_class(kwarg_name)
394            elif not hasattr(self.trainer, kwarg_name):
395                raise AttributeError(
396                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
397                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
398                )
399            else:
400                assert hasattr(self.trainer, kwarg_name)
401                obj = getattr(self.trainer, kwarg_name)
402                if obj is None or type(obj) in (
403                    bool,
404                    bytearray,
405                    bytes,
406                    dict,
407                    float,
408                    frozenset,
409                    int,
410                    list,
411                    set,
412                    str,
413                    tuple,
414                ):
415                    self.dump_generic_builtin(kwarg_name)
416                else:
417                    self.dump_generic_instance(kwarg_name)
418
419        def dump_generic_builtin(self, kwarg_name: str) -> None:
420            """@private
421            """
422            assert hasattr(self.trainer, kwarg_name)
423            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)
424
425        def dump_generic_class(self, kwarg_name: str) -> None:
426            """@private
427            """
428            assert hasattr(self.trainer, kwarg_name)
429            assert kwarg_name.endswith("_class")
430            obj = getattr(self.trainer, kwarg_name)
431            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"
432
433        def dump_generic_instance(self, kwarg_name: str) -> None:
434            """@private
435            """
436            assert hasattr(self.trainer, kwarg_name)
437            instance = getattr(self.trainer, kwarg_name)
438            self.init_data.update(
439                {
440                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
441                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
442                }
443            )
444
445        def dump_device(self, kwarg_name: str):
446            """@private
447            """
448            assert hasattr(self.trainer, kwarg_name)
449            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))
450
451        def dump_data_loader(self, kwarg_name: str) -> None:
452            """@private
453            """
454            assert hasattr(self.trainer, kwarg_name)
455            loader = getattr(self.trainer, kwarg_name)
456            if loader is None:
457                return
458            self.init_data.update(
459                {
460                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
461                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
462                }
463            )
464
465        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
466            """@private
467            """
468            self.dump_generic_class(f"{kwarg_name}_class")
469
470        def dump_model(self, kwarg_name: str):
471            """@private
472            """
473            if is_compiled(self.trainer.model):
474                self.init_data.update(
475                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
476                )
477            else:
478                self.dump_generic_instance("model")
479
480    def _build_init(self) -> Dict[str, Any]:
481        serializer = self.Serializer(self)
482        for name in inspect.signature(self.__class__).parameters:
483            # special rules to serialize kwargs
484            # if a trainer class inherits from DefaultTrainer and has **kwargs
485            # they need to be saved in self._kwargs
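            # A minimal sketch of this contract (hypothetical subclass, not part of this module):
            #
            #     class MyTrainer(DefaultTrainer):
            #         def __init__(self, *args, the_answer: int = 42, **kwargs):
            #             super().__init__(*args, **kwargs)
            #             self.the_answer = the_answer
            #             self._kwargs = {"the_answer": the_answer}  # lets _build_init serialize the extra kwarg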
486            if name == "kwargs":
487                if not hasattr(self, "_kwargs"):
488                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
489                          "Please add self._kwargs to its __init__ function"
490                    raise RuntimeError(msg)
491                kwargs = getattr(self, "_kwargs")
492                for kwarg_name in kwargs:
493                    serializer.dump(kwarg_name)
494                continue
495            serializer.dump(name)
496
497        return serializer.init_data
498
499    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
500        assert self.train_loader is not None
501        assert self.val_loader is not None
502        assert self.model is not None
503        assert self.loss is not None
504        assert self.optimizer is not None
505        assert self.metric is not None
506        assert self.device is not None
507
508        if load_from_checkpoint is not None:
509            self.load_checkpoint(load_from_checkpoint)
510
511        if sum((iterations is not None, epochs is not None)) != 1:
512            raise ValueError(
513                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer."
514                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
515            )
516
517        if epochs is None:
518            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
519        else:
520            iterations = epochs * len(self.train_loader)
521
522        self.max_iteration = self._iteration + iterations
523        self.max_epoch = self._epoch + epochs
524
525        if not getattr(self, "_is_initialized", False):
526            # check if we compile the model (only supported by pytorch 2)
527            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
528            if is_compiled(self.model):
529                warnings.warn(
530                    "You have passed a compiled model to the trainer."
531                    "It will not be possible to (de)serialize the trainer with it."
532                    "If you want to be able to do this please pass the normal model."
533                    "It can be automatically compiled by setting 'compile_model' to True"
534                )
535            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
536            self._model_kwargs = get_constructor_arguments(self.model)
537            self.model = auto_compile(self.model, self.compile_model)
538
539            self.model.to(self.device)
540            self.loss.to(self.device)
541
542            # this saves all the information that is necessary
543            # to fully load the trainer from the checkpoint
544            self.init_data = self._build_init()
545
546            if self.logger_class is None:
547                self.logger = None
548            else:
549                # may set self.name if self.name is None
550                save_root = getattr(self, "save_root", None)
551                try:
552                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
553                except PermissionError:
554                    warnings.warn(
555                        f"The checkpoint folder at {self.checkpoint_folder} could not be created."
556                        "The most likely reason for this is that you copied the checkpoint somewhere else,"
557                        "so we skip this error to enable loading the model from this checkpoint."
558                    )
559
560            try:
561                os.makedirs(self.checkpoint_folder, exist_ok=True)
562            except PermissionError:
563                warnings.warn(
564                    f"The checkpoint folder at {self.checkpoint_folder} could not be created."
565                    "The most likely reason for this is that you copied the checkpoint somewhere else,"
566                    "so we skip this error to enable loading the model from this checkpoint."
567                )
568                pass
569
570        best_metric = np.inf
571        return best_metric
572
573    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
574        """@private
575        """
576        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
577        extra_init_dict = extra_save_dict.pop("init", {})
578        save_dict = {
579            "iteration": self._iteration,
580            "epoch": self._epoch,
581            "best_epoch": self._best_epoch,
582            "best_metric": best_metric,
583            "current_metric": current_metric,
584            "model_state": self.model.state_dict(),
585            "optimizer_state": self.optimizer.state_dict(),
586            "init": self.init_data | extra_init_dict,
587            "train_time": train_time,
588            "timestamp": datetime.now().strftime("%d-%m-%Y (%H:%M:%S)"),
589        }
590        save_dict.update(**extra_save_dict)
591        if self.scaler is not None:
592            save_dict.update({"scaler_state": self.scaler.state_dict()})
593        if self.lr_scheduler is not None:
594            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
595
596        rank = getattr(self, "rank", None)
597        if rank is None or rank == 0:
598            torch.save(save_dict, save_path)
599
600    def load_checkpoint(self, checkpoint="best"):
601        """@private
602        """
603        if isinstance(checkpoint, str):
604            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
605            if not os.path.exists(save_path):
606                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
607                return
608            save_dict = torch.load(save_path, weights_only=False)
609        elif isinstance(checkpoint, dict):
610            save_dict = checkpoint
611        else:
612            raise RuntimeError
613
614        self._iteration = save_dict["iteration"]
615        self._epoch = save_dict["epoch"]
616        self._best_epoch = save_dict["best_epoch"]
617        self.best_metric = save_dict["best_metric"]
618        self.current_metric = save_dict["current_metric"]
619        self.train_time = save_dict.get("train_time", 0.0)
620
621        model_state = save_dict["model_state"]
622        # to enable loading compiled models
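        # torch.compile wraps the model and prefixes its state dict keys with "_orig_mod."
        # (e.g. a hypothetical "encoder.weight" is stored as "_orig_mod.encoder.weight");
        # stripping the prefix below lets the plain, uncompiled model load these weights.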
623        compiled_prefix = "_orig_mod."
624        model_state = OrderedDict(
625            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
626        )
627        self.model.load_state_dict(model_state)
628        # we need to send the network to the device before loading the optimizer state!
629        self.model.to(self.device)
630
631        self.optimizer.load_state_dict(save_dict["optimizer_state"])
632        if self.scaler is not None:
633            self.scaler.load_state_dict(save_dict["scaler_state"])
634        if self.lr_scheduler is not None:
635            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
636
637        return save_dict
638
639    def _verify_if_training_completed(self, checkpoint="latest"):
640        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
641        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
642        if save_dict and self.max_iteration == save_dict.get("iteration"):
643            return True
644        return False
645
646    def fit(
647        self,
648        iterations: Optional[int] = None,
649        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
650        epochs: Optional[int] = None,
651        save_every_kth_epoch: Optional[int] = None,
652        progress=None,
653        overwrite_training: bool = True,
654    ):
655        """Run neural network training.
656
657        Exactly one of 'iterations' or 'epochs' has to be passed.
658
659        Args:
660            iterations: How long to train, specified in iterations.
661            load_from_checkpoint: Path to a checkpoint from which training should be continued.
662            epochs: How long to train, specified in epochs.
663            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
664                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
665            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
666            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
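
        Example:
            A minimal sketch, assuming a trainer constructed as in the class example above
            (the keyword values are illustrative):
            ```python
            # Train for 5 epochs and additionally keep a checkpoint after every second epoch
            # ('epoch-2.pt', 'epoch-4.pt', ...).
            trainer.fit(epochs=5, save_every_kth_epoch=2)
            ```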
667        """
668        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
669
670        if not overwrite_training:
671            if load_from_checkpoint is not None:
672                raise ValueError(
673                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
674                )
675
676            if self._verify_if_training_completed():
677                print(
678                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
679                    "and 'overwrite_training' is set to 'False'."
680                )
681                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
682                return
683
684        print(
685            "Start fitting for",
686            self.max_iteration - self._iteration,
687            "iterations / ",
688            self.max_epoch - self._epoch,
689            "epochs",
690        )
691        print("with", len(self.train_loader), "iterations per epoch")
692
693        if self.mixed_precision:
694            train_epoch = self._train_epoch_mixed
695            validate = self._validate_mixed
696            print("Training with mixed precision")
697        else:
698            train_epoch = self._train_epoch
699            validate = self._validate
700            print("Training with single precision")
701
702        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
703        if progress is None:
704            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
705        else:
706            progress.total = total_iterations
707            progress.set_description(f"Epoch {self._epoch}")
708
709        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
710        train_epochs = self.max_epoch - self._epoch
711        t_start = time.time()
712        for epoch in range(train_epochs):
713
714            # Ensure data is shuffled differently at each epoch.
715            try:
716                self.train_loader.sampler.set_epoch(epoch)
717            except AttributeError:
718                pass
719
720            # Run training and validation for this epoch
721            t_per_iter = train_epoch(progress)
722            current_metric = validate()
723
724            # perform all the post-epoch steps:
725
726            # apply the learning rate scheduler
727            if self.lr_scheduler is not None:
728                self.lr_scheduler.step(current_metric)
729
730            # how long did we train in total?
731            total_train_time = (time.time() - t_start) + self.train_time
732
733            # save this checkpoint as the new best checkpoint if
734            # it has the best overall validation metric
735            if current_metric < best_metric:
736                best_metric = current_metric
737                self._best_epoch = self._epoch
738                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
739
740            # save this checkpoint as the latest checkpoint
741            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
742
743            # if we save after every k-th epoch then check if we need to save now
744            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
745                self.save_checkpoint(
746                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
747                )
748
749            # if early stopping has been specified then check if the stopping condition is met
750            if self.early_stopping is not None:
751                epochs_since_best = self._epoch - self._best_epoch
752                if epochs_since_best > self.early_stopping:
753                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
754                    break
755
756            self._epoch += 1
757            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
758
759        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
760        print(f"The best epoch is number {self._best_epoch}.")
761
762        if self._generate_name:
763            self.name = None
764
765        # Update the train time
766        self.train_time = total_train_time
767
768        # TODO save the model to wandb if we have the wandb logger
769        if isinstance(self.logger, WandbLogger):
770            self.logger.get_wandb().finish()
771
772    def _backprop(self, loss):
773        loss.backward()
774        self.optimizer.step()
775
776    def _backprop_mixed(self, loss):
777        self.scaler.scale(loss).backward()
778        self.scaler.step(self.optimizer)
779        self.scaler.update()
780
781    def _train_epoch(self, progress):
782        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
783
784    def _train_epoch_mixed(self, progress):
785        return self._train_epoch_impl(
786            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
787            self._backprop_mixed
788        )
789
790    def _forward_and_loss(self, x, y):
791        pred = self.model(x)
792        if self._iteration % self.log_image_interval == 0:
793            if pred.requires_grad:
794                pred.retain_grad()
795
796        loss = self.loss(pred, y)
797        return pred, loss
798
799    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
800        self.model.train()
801
802        n_iter = 0
803        t_per_iter = time.time()
804        for x, y in self.train_loader:
805            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
806
807            self.optimizer.zero_grad()
808
809            with forward_context():
810                pred, loss = self._forward_and_loss(x, y)
811
812            backprop(loss)
813
814            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
815            if self.logger is not None:
816                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
817
818            self._iteration += 1
819            n_iter += 1
820            if self._iteration >= self.max_iteration:
821                break
822            progress.update(1)
823
824        t_per_iter = (time.time() - t_per_iter) / n_iter
825        return t_per_iter
826
827    def _validate(self):
828        return self._validate_impl(contextlib.nullcontext)
829
830    def _validate_mixed(self):
831        return self._validate_impl(
832            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
833        )
834
835    def _validate_impl(self, forward_context):
836        self.model.eval()
837
838        metric_val = 0.0
839        loss_val = 0.0
840
841        with torch.no_grad():
842            for x, y in self.val_loader:
843                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
844                with forward_context():
845                    pred, loss = self._forward_and_loss(x, y)
846                    metric = self.metric(pred, y)
847
848                loss_val += loss.item()
849                metric_val += metric.item()
850
851        metric_val /= len(self.val_loader)
852        loss_val /= len(self.val_loader)
853        if self.logger is not None:
854            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
855        return metric_val
462                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
463                }
464            )
465
466        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
467            """@private
468            """
469            self.dump_generic_class(f"{kwarg_name}_class")
470
471        def dump_model(self, kwarg_name: str):
472            """@private
473            """
474            if is_compiled(self.trainer.model):
475                self.init_data.update(
476                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
477                )
478            else:
479                self.dump_generic_instance("model")
480
481    def _build_init(self) -> Dict[str, Any]:
482        serializer = self.Serializer(self)
483        for name in inspect.signature(self.__class__).parameters:
484            # special rules to serialize kwargs
485            # if a trainer class inherits from DefaultTrainer and has **kwargs
486            # they need to be saved in self._kwargs
487            if name == "kwargs":
488                if not hasattr(self, "_kwargs"):
489                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
490                          "Please add self._kwargs to its __init__ function"
491                    raise RuntimeError(msg)
492                kwargs = getattr(self, "_kwargs")
493                for kwarg_name in kwargs:
494                    serializer.dump(kwarg_name)
495                continue
496            serializer.dump(name)
497
498        return serializer.init_data
499
500    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
501        assert self.train_loader is not None
502        assert self.val_loader is not None
503        assert self.model is not None
504        assert self.loss is not None
505        assert self.optimizer is not None
506        assert self.metric is not None
507        assert self.device is not None
508
509        if load_from_checkpoint is not None:
510            self.load_checkpoint(load_from_checkpoint)
511
512        if sum((iterations is not None, epochs is not None)) != 1:
513            raise ValueError(
514                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer. "
515                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
516            )
517
518        if epochs is None:
519            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
520        else:
521            iterations = epochs * len(self.train_loader)
522
523        self.max_iteration = self._iteration + iterations
524        self.max_epoch = self._epoch + epochs
525
526        if not getattr(self, "_is_initialized", False):
527            # check if we compile the model (only supported by pytorch 2)
528            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
529            if is_compiled(self.model):
530                warnings.warn(
531                    "You have passed a compiled model to the trainer. "
532                    "It will not be possible to (de)serialize the trainer with it. "
533                    "If you want to be able to do this please pass the normal model. "
534                    "It can be automatically compiled by setting 'compile_model' to True."
535                )
536            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
537            self._model_kwargs = get_constructor_arguments(self.model)
538            self.model = auto_compile(self.model, self.compile_model)
539
540            self.model.to(self.device)
541            self.loss.to(self.device)
542
543            # this saves all the information that is necessary
544            # to fully load the trainer from the checkpoint
545            self.init_data = self._build_init()
546
547            if self.logger_class is None:
548                self.logger = None
549            else:
550                # may set self.name if self.name is None
551                save_root = getattr(self, "save_root", None)
552                try:
553                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
554                except PermissionError:
555                    warnings.warn(
556                        f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
557                        "The most likely reason for this is that you copied the checkpoint somewhere else, "
558                        "so we skip this error to enable loading the model from this checkpoint."
559                    )
560
561            try:
562                os.makedirs(self.checkpoint_folder, exist_ok=True)
563            except PermissionError:
564                warnings.warn(
565                    f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
566                    "The most likely reason for this is that you copied the checkpoint somewhere else, "
567                    "so we skip this error to enable loading the model from this checkpoint."
568                )
569                pass
570
571        best_metric = np.inf
572        return best_metric
573
574    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
575        """@private
576        """
577        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
578        extra_init_dict = extra_save_dict.pop("init", {})
579        save_dict = {
580            "iteration": self._iteration,
581            "epoch": self._epoch,
582            "best_epoch": self._best_epoch,
583            "best_metric": best_metric,
584            "current_metric": current_metric,
585            "model_state": self.model.state_dict(),
586            "optimizer_state": self.optimizer.state_dict(),
587            "init": self.init_data | extra_init_dict,
588            "train_time": train_time,
589            "timestamp": datetime.now().strftime("%d-%m-%Y (%H:%M:%S)"),
590        }
591        save_dict.update(**extra_save_dict)
592        if self.scaler is not None:
593            save_dict.update({"scaler_state": self.scaler.state_dict()})
594        if self.lr_scheduler is not None:
595            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
596
597        rank = getattr(self, "rank", None)
598        if rank is None or rank == 0:
599            torch.save(save_dict, save_path)
600
601    def load_checkpoint(self, checkpoint="best"):
602        """@private
603        """
604        if isinstance(checkpoint, str):
605            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
606            if not os.path.exists(save_path):
607                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
608                return
609            save_dict = torch.load(save_path, weights_only=False)
610        elif isinstance(checkpoint, dict):
611            save_dict = checkpoint
612        else:
613            raise RuntimeError
614
615        self._iteration = save_dict["iteration"]
616        self._epoch = save_dict["epoch"]
617        self._best_epoch = save_dict["best_epoch"]
618        self.best_metric = save_dict["best_metric"]
619        self.current_metric = save_dict["current_metric"]
620        self.train_time = save_dict.get("train_time", 0.0)
621
622        model_state = save_dict["model_state"]
623        # to enable loading compiled models
624        compiled_prefix = "_orig_mod."
625        model_state = OrderedDict(
626            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
627        )
628        self.model.load_state_dict(model_state)
629        # we need to send the network to the device before loading the optimizer state!
630        self.model.to(self.device)
631
632        self.optimizer.load_state_dict(save_dict["optimizer_state"])
633        if self.scaler is not None:
634            self.scaler.load_state_dict(save_dict["scaler_state"])
635        if self.lr_scheduler is not None:
636            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])
637
638        return save_dict
639
640    def _verify_if_training_completed(self, checkpoint="latest"):
641        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
642        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
643        if save_dict and self.max_iteration == save_dict.get("iteration"):
644            return True
645        return False
646
647    def fit(
648        self,
649        iterations: Optional[int] = None,
650        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
651        epochs: Optional[int] = None,
652        save_every_kth_epoch: Optional[int] = None,
653        progress=None,
654        overwrite_training: bool = True,
655    ):
656        """Run neural network training.
657
658        Exactly one of 'iterations' or 'epochs' has to be passed.
659
660        Args:
661            iterations: How long to train, specified in iterations.
662            load_from_checkpoint: Path to a checkpoint from where training should be continued.
663            epochs: How long to train, specified in epochs.
664            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
665                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
666            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
667            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
668        """
669        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
670
671        if not overwrite_training:
672            if load_from_checkpoint is not None:
673                raise ValueError(
674                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
675                )
676
677            if self._verify_if_training_completed():
678                print(
679                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
680                    "and 'overwrite_training' is set to 'False'."
681                )
682                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
683                return
684
685        print(
686            "Start fitting for",
687            self.max_iteration - self._iteration,
688            "iterations / ",
689            self.max_epoch - self._epoch,
690            "epochs",
691        )
692        print("with", len(self.train_loader), "iterations per epoch")
693
694        if self.mixed_precision:
695            train_epoch = self._train_epoch_mixed
696            validate = self._validate_mixed
697            print("Training with mixed precision")
698        else:
699            train_epoch = self._train_epoch
700            validate = self._validate
701            print("Training with single precision")
702
703        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
704        if progress is None:
705            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
706        else:
707            progress.total = total_iterations
708            progress.set_description(f"Epoch {self._epoch}")
709
710        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
711        train_epochs = self.max_epoch - self._epoch
712        t_start = time.time()
713        for epoch in range(train_epochs):
714
715            # Ensure data is shuffled differently at each epoch.
716            try:
717                self.train_loader.sampler.set_epoch(epoch)
718            except AttributeError:
719                pass
720
721            # Run training and validation for this epoch
722            t_per_iter = train_epoch(progress)
723            current_metric = validate()
724
725            # perform all the post-epoch steps:
726
727            # apply the learning rate scheduler
728            if self.lr_scheduler is not None:
729                self.lr_scheduler.step(current_metric)
730
731            # how long did we train in total?
732            total_train_time = (time.time() - t_start) + self.train_time
733
734            # save this checkpoint as the new best checkpoint if
735            # it has the best overall validation metric
736            if current_metric < best_metric:
737                best_metric = current_metric
738                self._best_epoch = self._epoch
739                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
740
741            # save this checkpoint as the latest checkpoint
742            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
743
744            # if we save after every k-th epoch then check if we need to save now
745            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
746                self.save_checkpoint(
747                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
748                )
749
750            # if early stopping has been specified then check if the stopping condition is met
751            if self.early_stopping is not None:
752                epochs_since_best = self._epoch - self._best_epoch
753                if epochs_since_best > self.early_stopping:
754                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
755                    break
756
757            self._epoch += 1
758            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
759
760        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
761        print(f"The best epoch is number {self._best_epoch}.")
762
763        if self._generate_name:
764            self.name = None
765
766        # Update the train time
767        self.train_time = total_train_time
768
769        # TODO save the model to wandb if we have the wandb logger
770        if isinstance(self.logger, WandbLogger):
771            self.logger.get_wandb().finish()
772
773    def _backprop(self, loss):
774        loss.backward()
775        self.optimizer.step()
776
777    def _backprop_mixed(self, loss):
778        self.scaler.scale(loss).backward()
779        self.scaler.step(self.optimizer)
780        self.scaler.update()
781
782    def _train_epoch(self, progress):
783        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)
784
785    def _train_epoch_mixed(self, progress):
786        return self._train_epoch_impl(
787            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
788            self._backprop_mixed
789        )
790
791    def _forward_and_loss(self, x, y):
792        pred = self.model(x)
793        if self._iteration % self.log_image_interval == 0:
794            if pred.requires_grad:
795                pred.retain_grad()
796
797        loss = self.loss(pred, y)
798        return pred, loss
799
800    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
801        self.model.train()
802
803        n_iter = 0
804        t_per_iter = time.time()
805        for x, y in self.train_loader:
806            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
807
808            self.optimizer.zero_grad()
809
810            with forward_context():
811                pred, loss = self._forward_and_loss(x, y)
812
813            backprop(loss)
814
815            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
816            if self.logger is not None:
817                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
818
819            self._iteration += 1
820            n_iter += 1
821            if self._iteration >= self.max_iteration:
822                break
823            progress.update(1)
824
825        t_per_iter = (time.time() - t_per_iter) / n_iter
826        return t_per_iter
827
828    def _validate(self):
829        return self._validate_impl(contextlib.nullcontext)
830
831    def _validate_mixed(self):
832        return self._validate_impl(
833            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
834        )
835
836    def _validate_impl(self, forward_context):
837        self.model.eval()
838
839        metric_val = 0.0
840        loss_val = 0.0
841
842        with torch.no_grad():
843            for x, y in self.val_loader:
844                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
845                with forward_context():
846                    pred, loss = self._forward_and_loss(x, y)
847                    metric = self.metric(pred, y)
848
849                loss_val += loss.item()
850                metric_val += metric.item()
851
852        metric_val /= len(self.val_loader)
853        loss_val /= len(self.val_loader)
854        if self.logger is not None:
855            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
856        return metric_val
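
The mixed precision path above (_train_epoch_mixed, _validate_mixed and _backprop_mixed) combines torch.autocast for the forward pass with a gradient scaler for the backward pass, mirroring the torch.GradScaler set up in the constructor. The following is a minimal, self-contained sketch of that pattern outside of the trainer; the linear model, loss and random data are hypothetical placeholders and not part of torch_em.

import torch

# Hypothetical stand-in components; any model, loss and batch would do.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 1).to(device)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters())
scaler = torch.GradScaler("cpu" if device.type == "cpu" else "cuda")

x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)

optimizer.zero_grad()
# Run the forward pass and the loss computation in reduced precision.
with torch.autocast(device_type="cpu" if device.type == "cpu" else "cuda"):
    pred = model(x)
    loss = loss_fn(pred, y)

# Scale the loss for backprop, then step and update the scaler,
# as in DefaultTrainer._backprop_mixed.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
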

Trainer class for training a segmentation network.

The trainer class implements the core logic for training a network with pytorch. It implements a training loop to run training and validation, which is started with fit. The checkpoints and logs from the training run will be saved in the current working directory, or in the directory specified by save_root. Training can be continued from a checkpoint by passing its location to the load_from_checkpoint argument of fit.

A pre-configured instance of the trainer can be obtained from torch_em.default_segmentation_trainer. Alternatively, the trainer class can also be instantiated as in this example:

import torch
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
from torch_em.trainer import DefaultTrainer

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)

# Create the model and optimizer.
model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.AdamW(model.parameters())

trainer = DefaultTrainer(
    name="unet-training",
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    model=model,
    loss=DiceLoss(),  # The loss function.
    optimizer=optimizer,
    metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
    device="cuda",  # The device to use for training.
)
trainer.fit(iterations=int(2.5e4))  # Train for 25.000 iterations.
Arguments:
  • name: The name of the checkpoint that will be created by the trainer.
  • train_loader: The data loader containing the training data.
  • val_loader: The data loader containing the validation data.
  • model: The model to train.
  • loss: The loss function for training.
  • optimizer: The optimizer.
  • metric: The metric for validation.
  • device: The torch device to use for training. If None, will use a GPU if available.
  • lr_scheduler: The learning rate scheduler.
  • log_image_interval: The interval for saving images during logging, in training iterations.
  • mixed_precision: Whether to train with mixed precision.
  • early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
  • logger: The logger class. Will be instantiated for logging. By default uses torch_em.training.tensorboard_logger.TensorboardLogger.
  • logger_kwargs: The keyword arguments for the logger class.
  • id_: Unique identifier for the trainer. If None then name will be used.
  • save_root: The root folder for saving the checkpoint and logs.
  • compile_model: Whether to compile the model before training.
  • rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
DefaultTrainer( name: Optional[str], train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, model: torch.nn.modules.module.Module, loss: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, metric: Callable, device: Union[str, torch.device], lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, log_image_interval: int = 100, mixed_precision: bool = True, early_stopping: Optional[int] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None, rank: Optional[int] = None)
 85    def __init__(
 86        self,
 87        name: Optional[str],
 88        train_loader: torch.utils.data.DataLoader,
 89        val_loader: torch.utils.data.DataLoader,
 90        model: torch.nn.Module,
 91        loss: torch.nn.Module,
 92        optimizer: torch.optim.Optimizer,
 93        metric: Callable,
 94        device: Union[str, torch.device],
 95        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 96        log_image_interval: int = 100,
 97        mixed_precision: bool = True,
 98        early_stopping: Optional[int] = None,
 99        logger=TensorboardLogger,
100        logger_kwargs: Optional[Dict[str, Any]] = None,
101        id_: Optional[str] = None,
102        save_root: Optional[str] = None,
103        compile_model: Optional[Union[bool, str]] = None,
104        rank: Optional[int] = None,
105    ):
106        if name is None and not issubclass(logger, WandbLogger):
107            raise TypeError("Name cannot be None if not using the WandbLogger")
108
109        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
110            raise ValueError(f"{self.__class__} requires each dataloader to have 'shuffle' attribute.")
111
112        self._generate_name = name is None
113        self.name = name
114        self.id_ = id_ or name
115        self.train_loader = train_loader
116        self.val_loader = val_loader
117        self.model = model
118        self.loss = loss
119        self.optimizer = optimizer
120        self.metric = metric
121        self.device = torch.device(device)
122        self.lr_scheduler = lr_scheduler
123        self.log_image_interval = log_image_interval
124        self.save_root = save_root
125        self.compile_model = compile_model
126        self.rank = rank
127
128        self._iteration = 0
129        self._epoch = 0
130        self._best_epoch = 0
131
132        self.mixed_precision = mixed_precision
133        self.early_stopping = early_stopping
134        self.train_time = 0.0
135
136        if mixed_precision:
137            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
138        else:
139            self.scaler = None
140
141        self.logger_class = logger
142        self.logger_kwargs = logger_kwargs
143        self.log_image_interval = log_image_interval
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
145    @property
146    def checkpoint_folder(self):
147        assert self.id_ is not None  # Because the logger may generate and set trainer.id_ on logger.__init__.
148        # The save_root enables saving the checkpoints somewhere other than in the local folder.
149        # This is handy for filesystems with limited space, where saving the checkpoints
150        # and log files can lead to running out of space.
151        save_root = getattr(self, "save_root", None)
152        return os.path.join("./checkpoints", self.id_) if save_root is None else\
153            os.path.join(save_root, "./checkpoints", self.id_)
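
For example, with a hypothetical id_ of "unet-training" the property resolves to the following locations:

import os

id_ = "unet-training"  # hypothetical value; in practice this is trainer.id_
print(os.path.join("./checkpoints", id_))  # ./checkpoints/unet-training
# With save_root="/scratch/models" the same trainer would instead write to
# /scratch/models/checkpoints/unet-training
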
iteration
155    @property
156    def iteration(self):
157        return self._iteration
epoch
159    @property
160    def epoch(self):
161        return self._epoch
def fit( self, iterations: Optional[int] = None, load_from_checkpoint: Union[str, os.PathLike, NoneType] = None, epochs: Optional[int] = None, save_every_kth_epoch: Optional[int] = None, progress=None, overwrite_training: bool = True):
647    def fit(
648        self,
649        iterations: Optional[int] = None,
650        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
651        epochs: Optional[int] = None,
652        save_every_kth_epoch: Optional[int] = None,
653        progress=None,
654        overwrite_training: bool = True,
655    ):
656        """Run neural network training.
657
658        Exactly one of 'iterations' or 'epochs' has to be passed.
659
660        Args:
661            iterations: How long to train, specified in iterations.
662            load_from_checkpoint: Path to a checkpoint from where training should be continued.
663            epochs: How long to train, specified in epochs.
664            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
665                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
666            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
667            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
668        """
669        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
670
671        if not overwrite_training:
672            if load_from_checkpoint is not None:
673                raise ValueError(
674                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
675                )
676
677            if self._verify_if_training_completed():
678                print(
679                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
680                    "and 'overwrite_training' is set to 'False'."
681                )
682                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
683                return
684
685        print(
686            "Start fitting for",
687            self.max_iteration - self._iteration,
688            "iterations / ",
689            self.max_epoch - self._epoch,
690            "epochs",
691        )
692        print("with", len(self.train_loader), "iterations per epoch")
693
694        if self.mixed_precision:
695            train_epoch = self._train_epoch_mixed
696            validate = self._validate_mixed
697            print("Training with mixed precision")
698        else:
699            train_epoch = self._train_epoch
700            validate = self._validate
701            print("Training with single precision")
702
703        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
704        if progress is None:
705            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
706        else:
707            progress.total = total_iterations
708            progress.set_description(f"Epoch {self._epoch}")
709
710        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
711        train_epochs = self.max_epoch - self._epoch
712        t_start = time.time()
713        for epoch in range(train_epochs):
714
715            # Ensure data is shuffled differently at each epoch.
716            try:
717                self.train_loader.sampler.set_epoch(epoch)
718            except AttributeError:
719                pass
720
721            # Run training and validation for this epoch
722            t_per_iter = train_epoch(progress)
723            current_metric = validate()
724
725            # perform all the post-epoch steps:
726
727            # apply the learning rate scheduler
728            if self.lr_scheduler is not None:
729                self.lr_scheduler.step(current_metric)
730
731            # how long did we train in total?
732            total_train_time = (time.time() - t_start) + self.train_time
733
734            # save this checkpoint as the new best checkpoint if
735            # it has the best overall validation metric
736            if current_metric < best_metric:
737                best_metric = current_metric
738                self._best_epoch = self._epoch
739                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
740
741            # save this checkpoint as the latest checkpoint
742            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
743
744            # if we save after every k-th epoch then check if we need to save now
745            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
746                self.save_checkpoint(
747                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
748                )
749
750            # if early stopping has been specified then check if the stopping condition is met
751            if self.early_stopping is not None:
752                epochs_since_best = self._epoch - self._best_epoch
753                if epochs_since_best > self.early_stopping:
754                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
755                    break
756
757            self._epoch += 1
758            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)
759
760        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
761        print(f"The best epoch is number {self._best_epoch}.")
762
763        if self._generate_name:
764            self.name = None
765
766        # Update the train time
767        self.train_time = total_train_time
768
769        # TODO save the model to wandb if we have the wandb logger
770        if isinstance(self.logger, WandbLogger):
771            self.logger.get_wandb().finish()

Run neural network training.

Exactly one of 'iterations' or 'epochs' has to be passed.

Arguments:
  • iterations: How long to train, specified in iterations.
  • load_from_checkpoint: Path to a checkpoint from where training should be continued.
  • epochs: How long to train, specified in epochs.
  • save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
  • progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
  • overwrite_training: Whether to overwrite existing checkpoints in the save directory.
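
For example, a trainer configured as in the class example above can be trained in epoch units and later resumed from the latest checkpoint. This is a sketch of typical calls, not the output of a real run; the checkpoint name "latest" is resolved relative to trainer.checkpoint_folder:

# Train for 100 epochs and additionally keep a checkpoint after every 10th epoch
# (saved as epoch-10.pt, epoch-20.pt, ... next to best.pt and latest.pt).
trainer.fit(epochs=100, save_every_kth_epoch=10)

# Continue training for 5000 more iterations from the latest checkpoint.
trainer.fit(iterations=5000, load_from_checkpoint="latest")
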
class DefaultTrainer.Deserializer:
163    class Deserializer:
164        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.
165
166        Examples:
167            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
168            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
169
170            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
171            >>> class MyTrainer(DefaultTrainer):
172            >>>     def __init__(self, *args, the_answer: int, **kwargs):
173            >>>         super().__init__(*args, **kwargs)
174            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
175            >>>                                       # see DefaultTrainer.Serializer
176            >>>
177            >>>     class Deserializer(DefaultTrainer.Deserializer):
178            >>>         def load_the_answer(self):
179            >>>             generic_answer = self.init_data["the_answer"]
180            >>>             # (device dependent) special deserialization
181            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
182            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
183            >>>             else:
184            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
185
186        Args:
187            init_data: The initialization data of the trainer.
188            save_path: The path where the checkpoint was saved.
189            device: The device.
190        """
191
192        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
193            self.init_data = init_data
194            self.save_path = save_path
195            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
196            self.trainer_kwargs: Dict[str, Any] = dict(
197                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
198            )
199
200        def load(self, kwarg_name: str, optional):
201            """@private
202            """
203            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
204            if kwarg_name == "device":
205                pass  # deserialized in __init__
206            elif kwarg_name.endswith("_loader"):
207                self.load_data_loader(kwarg_name, optional)
208            else:
209                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
210                load(kwarg_name, optional=optional)
211
212        def load_data_loader(self, loader_name, optional) -> None:
213            """@private
214            """
215            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
216            if ds is None and optional:
217                return
218
219            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
220            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
221            # monkey patch the shuffle attribute onto the loader
222            loader.shuffle = loader_kwargs.get("shuffle", False)
223            self.trainer_kwargs[loader_name] = loader
224
225        def load_generic(
226            self,
227            kwarg_name: str,
228            *dynamic_args: Dict,
229            optional: bool,
230            only_class: bool = False,
231            dynamic_kwargs: Optional[Dict[str, Any]] = None,
232        ) -> None:
233            """@private
234            """
235            if kwarg_name in self.init_data:
236                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
237                return
238
239            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
240            if this_cls is None:
241                if optional:
242                    return
243                else:
244                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
245
246            assert isinstance(this_cls, str), this_cls
247            assert "." in this_cls, this_cls
248            cls_p, cls_m = this_cls.rsplit(".", 1)
249            this_cls = getattr(import_module(cls_p), cls_m)
250            if only_class:
251                self.trainer_kwargs[kwarg_name] = this_cls
252            else:
253                self.trainer_kwargs[kwarg_name] = this_cls(
254                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
255                )
256
257        def load_name(self, kwarg_name: str, optional: bool):
258            """@private
259            """
260            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]
261
262        def load_optimizer(self, kwarg_name: str, optional: bool):
263            """@private
264            """
265            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)
266
267        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
268            """@private
269            """
270            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)
271
272        # todo: remove and rename kwarg 'logger' to 'logger_class'
273        def load_logger(self, kwarg_name: str, optional: bool):
274            """@private
275            """
276            assert kwarg_name == "logger"
277            self.load_generic("logger", optional=optional, only_class=True)

Determines how to deserialize the trainer kwargs from serialized 'init_data'.

Examples:

To extend the initialization process you can inherit from this Deserializer in a derived Trainer class. Note that DefaultTrainer.Deserializer.load_generic() covers most cases already.

This example adds the_answer kwarg, which requires 'calculations' upon initialization:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
>>>                                       # see DefaultTrainer.Serializer
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self):
>>>             generic_answer = self.init_data["the_answer"]
>>>             # (device dependent) special deserialization
>>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
>>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
>>>             else:
>>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
Arguments:
  • init_data: The initialization data of the trainer.
  • save_path: The path where the checkpoint was saved.
  • device: The device.
DefaultTrainer.Deserializer(init_data: Dict, save_path: str, device: Union[str, torch.device])
192        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
193            self.init_data = init_data
194            self.save_path = save_path
195            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
196            self.trainer_kwargs: Dict[str, Any] = dict(
197                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
198            )
init_data
save_path
trainer_kwargs: Dict[str, Any]
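
The trainer's restore logic near the top of this page drives this class: it builds a Deserializer from the 'init' data stored in a checkpoint, calls load() once per constructor parameter, and then instantiates the trainer from the collected trainer_kwargs. Below is a condensed, approximate sketch of that flow; the checkpoint path is hypothetical and the library's own restore classmethod handles additional corner cases:

import inspect
import torch
from torch_em.trainer import DefaultTrainer

# Hypothetical checkpoint location. weights_only=False is required because the
# saved 'init' entry contains arbitrary Python objects such as datasets.
save_path = "./checkpoints/unet-training/best.pt"
save_dict = torch.load(save_path, weights_only=False)

deserializer = DefaultTrainer.Deserializer(save_dict["init"], save_path, device="cuda")
for name, param in inspect.signature(DefaultTrainer).parameters.items():
    # 'optional' is True for kwargs that have a default value in __init__.
    deserializer.load(name, optional=param.default is not inspect.Parameter.empty)

trainer = DefaultTrainer(**deserializer.trainer_kwargs)
trainer._initialize(0, save_dict)  # restores the model and optimizer state from the checkpoint
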
class DefaultTrainer.Serializer:
329    class Serializer:
330        """Implements how to serialize trainer kwargs from a trainer instance.
331
332        Examples:
333            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
334            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
335            called by the `dump()` method when appropriate cover most cases already.
336
337            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
338            'the_answer' attribute:
339            >>> class MyTrainer(DefaultTrainer):
340            >>>     def __init__(self, *args, the_answer: int, **kwargs):
341            >>>         super().__init__(*args, **kwargs)
342            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
343            >>>         # but let's make things more interesting...
344            >>>         self.the = the_answer // 10
345            >>>         self.answer = the_answer % 10
346            >>>
347            >>>     class Serializer(DefaultTrainer.Serializer):
348            >>>         trainer: MyTrainer
349            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
350            >>>             assert kwarg_name == "the_answer"
351            >>>             # populate self.init_data with the serialized data required by Deserializer
352            >>>             # to restore the trainer kwargs
353            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer
354
355            This example with both Serializer and Deserializer adds `the_answer` kwarg,
356            while saving it in two separate entries 'the' and 'answer':
357            >>> class MyTrainer(DefaultTrainer):
358            >>>     def __init__(self, *args, the_answer: int, **kwargs):
359            >>>         super().__init__(*args, **kwargs)
360            >>>         self.the_answer = the_answer
361            >>>
362            >>>     class Serializer(DefaultTrainer.Serializer):
363            >>>         trainer: MyTrainer
364            >>>         def dump_the_answer(self, kwarg_name: str):
365            >>>             assert kwarg_name == "the_answer"
366            >>>             self.init_data.update({
367            >>>                 "the": self.trainer.the_answer // 10,
368            >>>                 "answer": self.trainer.the_answer % 10
369            >>>             })
370            >>>
371            >>>     class Deserializer(DefaultTrainer.Deserializer):
372            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
373            >>>             assert kwarg_name == "the_answer"
374            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
375            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
376
377        Args:
378            trainer: The trainer instance.
379        """
380
381        def __init__(self, trainer: DefaultTrainer):
382            self.trainer = trainer
383            self.init_data = {}  # to be populated during serialization process
384
385        def dump(self, kwarg_name: str) -> None:
386            """@private
387            """
388            dumper = getattr(self, f"dump_{kwarg_name}", None)
389            if dumper is not None:
390                dumper(kwarg_name)
391            elif kwarg_name.endswith("_loader"):
392                self.dump_data_loader(kwarg_name)
393            elif kwarg_name.endswith("_class"):
394                self.dump_generic_class(kwarg_name)
395            elif not hasattr(self.trainer, kwarg_name):
396                raise AttributeError(
397                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
398                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
399                )
400            else:
401                assert hasattr(self.trainer, kwarg_name)
402                obj = getattr(self.trainer, kwarg_name)
403                if obj is None or type(obj) in (
404                    bool,
405                    bytearray,
406                    bytes,
407                    dict,
408                    float,
409                    frozenset,
410                    int,
411                    list,
412                    set,
413                    str,
414                    tuple,
415                ):
416                    self.dump_generic_builtin(kwarg_name)
417                else:
418                    self.dump_generic_instance(kwarg_name)
419
420        def dump_generic_builtin(self, kwarg_name: str) -> None:
421            """@private
422            """
423            assert hasattr(self.trainer, kwarg_name)
424            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)
425
426        def dump_generic_class(self, kwarg_name: str) -> None:
427            """@private
428            """
429            assert hasattr(self.trainer, kwarg_name)
430            assert kwarg_name.endswith("_class")
431            obj = getattr(self.trainer, kwarg_name)
432            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"
433
434        def dump_generic_instance(self, kwarg_name: str) -> None:
435            """@private
436            """
437            assert hasattr(self.trainer, kwarg_name)
438            instance = getattr(self.trainer, kwarg_name)
439            self.init_data.update(
440                {
441                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
442                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
443                }
444            )
445
446        def dump_device(self, kwarg_name: str):
447            """@private
448            """
449            assert hasattr(self.trainer, kwarg_name)
450            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))
451
452        def dump_data_loader(self, kwarg_name: str) -> None:
453            """@private
454            """
455            assert hasattr(self.trainer, kwarg_name)
456            loader = getattr(self.trainer, kwarg_name)
457            if loader is None:
458                return
459            self.init_data.update(
460                {
461                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
462                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
463                }
464            )
465
466        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
467            """@private
468            """
469            self.dump_generic_class(f"{kwarg_name}_class")
470
471        def dump_model(self, kwarg_name: str):
472            """@private
473            """
474            if is_compiled(self.trainer.model):
475                self.init_data.update(
476                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
477                )
478            else:
479                self.dump_generic_instance("model")

Implements how to serialize trainer kwargs from a trainer instance.

Examples:

To extend the serialization process you can inherit from this Serializer in a derived Trainer class. Note that the methods dump_generic_builtin(), dump_generic_class() and dump_generic_instance(), called by the dump() method when appropriate, cover most cases already.

This example adds the_answer kwarg, which requires extra steps on dumping only because we don't keep a 'the_answer' attribute:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
>>>         # but let's make things more interesting...
>>>         self.the = the_answer // 10
>>>         self.answer = the_answer % 10
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
>>>             assert kwarg_name == "the_answer"
>>>             # populate self.init_data with the serialized data required by Deserializer
>>>             # to restore the trainer kwargs
>>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

This example with both Serializer and Deserializer adds the_answer kwarg, while saving it in two separate entries 'the' and 'answer':

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str):
>>>             assert kwarg_name == "the_answer"
>>>             self.init_data.update({
>>>                 "the": self.trainer.the_answer // 10,
>>>                 "answer": self.trainer.the_answer % 10
>>>             })
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self, kwarg_name: str, optional: bool):
>>>             assert kwarg_name == "the_answer"
>>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
Arguments:
  • trainer: The trainer instance.
DefaultTrainer.Serializer(trainer: DefaultTrainer)
381        def __init__(self, trainer: DefaultTrainer):
382            self.trainer = trainer
383            self.init_data = {}  # to be populated during serialization process
trainer
init_data
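
On the saving side, _build_init() walks the constructor signature and asks the Serializer to dump each kwarg; the resulting init_data dictionary is what save_checkpoint() stores under the 'init' key. A minimal sketch of that dispatch, assuming a trainer instance created as in the class example above:

serializer = DefaultTrainer.Serializer(trainer)

# Each dump_* method (or one of the generic fallbacks) fills serializer.init_data.
serializer.dump("device")        # init_data["device"] = "cuda"
serializer.dump("model")         # init_data["model_class"], init_data["model_kwargs"]
serializer.dump("train_loader")  # init_data["train_dataset"], init_data["train_loader_kwargs"]

print(sorted(serializer.init_data))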