torch_em.trainer.default_trainer
from __future__ import annotations

import os
import time
import inspect
import warnings
import contextlib
from tqdm import tqdm
from collections import OrderedDict
from functools import partial
from importlib import import_module
from typing import Any, Callable, Dict, Optional, Union, Literal

import numpy as np

import torch

from .wandb_logger import WandbLogger
from .tensorboard_logger import TensorboardLogger
from ..util import auto_compile, get_constructor_arguments, is_compiled


class DefaultTrainer:
    """Trainer class for training a segmentation network.

    The trainer class implements the core logic for training a network with pytorch.
    It implements a training loop to run training and validation, which is started with `fit`.
    The checkpoints and logs from the training run will be saved in the current working directory,
    or in the directory specified by `save_root`. Training can be continued from a checkpoint
    by passing its location to the `load_from_checkpoint` argument of `fit`.

    A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
    Alternatively, the trainer class can also be instantiated as in this example:
    ```python
    import torch
    from torch_em.loss import DiceLoss
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader
    from torch_em.trainer import DefaultTrainer

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)

    # Create the model and optimizer.
    model = UNet2d(in_channels=1, out_channels=1)
    optimizer = torch.optim.AdamW(model.parameters())

    trainer = DefaultTrainer(
        name="unet-training",
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
        model=model,
        loss=DiceLoss(),  # The loss function.
        optimizer=optimizer,
        metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
        device="cuda",  # The device to use for training.
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        model: The model to train.
        loss: The loss function for training.
        optimizer: The optimizer.
        metric: The metric for validation.
        device: The torch device to use for training. If None, will use a GPU if available.
        lr_scheduler: The learning rate scheduler.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
    """

    def __init__(
        self,
        name: Optional[str],
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        metric: Callable,
        device: Union[str, torch.device],
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        log_image_interval: int = 100,
        mixed_precision: bool = True,
        early_stopping: Optional[int] = None,
        logger=TensorboardLogger,
        logger_kwargs: Optional[Dict[str, Any]] = None,
        id_: Optional[str] = None,
        save_root: Optional[str] = None,
        compile_model: Optional[Union[bool, str]] = None,
        rank: Optional[int] = None,
    ):
        if name is None and not issubclass(logger, WandbLogger):
            raise TypeError("Name cannot be None if not using the WandbLogger")

        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
            raise ValueError(f"{self.__class__} requires each dataloader to have a 'shuffle' attribute.")

        self._generate_name = name is None
        self.name = name
        self.id_ = id_ or name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metric
        self.device = torch.device(device)
        self.lr_scheduler = lr_scheduler
        self.log_image_interval = log_image_interval
        self.save_root = save_root
        self.compile_model = compile_model
        self.rank = rank

        self._iteration = 0
        self._epoch = 0
        self._best_epoch = 0

        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.train_time = 0.0

        if mixed_precision:
            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
        else:
            self.scaler = None

        self.logger_class = logger
        self.logger_kwargs = logger_kwargs

    @property
    def checkpoint_folder(self):
        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
        # 'save_root' enables saving the checkpoints somewhere other than the local folder.
        # This is handy for filesystems with limited space, where saving the checkpoints
        # and log files can lead to running out of space.
        save_root = getattr(self, "save_root", None)
        return os.path.join("./checkpoints", self.id_) if save_root is None else \
            os.path.join(save_root, "./checkpoints", self.id_)

    @property
    def iteration(self):
        return self._iteration

    @property
    def epoch(self):
        return self._epoch
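
    # A small sketch of how `checkpoint_folder` is derived (added for illustration,
    # not part of the original source; the trainer name and save_root are hypothetical):
    #
    #   os.path.join("./checkpoints", "unet-training")
    #   # -> './checkpoints/unet-training'             (save_root is None)
    #   os.path.join("/scratch/runs", "./checkpoints", "unet-training")
    #   # -> '/scratch/runs/./checkpoints/unet-training'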

    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

        Examples:
            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds the `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>                                       # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

        Args:
            init_data: The initialization data of the trainer.
            save_path: The path where the checkpoint was saved.
            device: The device.
        """

        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """@private
            """
            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            """@private
            """
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch the shuffle attribute onto the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args: Dict,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            """@private
            """
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # todo: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            """@private
            """
            assert kwarg_name == "logger"
            self.load_generic("logger", optional=optional, only_class=True)

    @staticmethod
    def _get_save_dict(save_path, device):
        if not os.path.exists(save_path):
            raise ValueError(f"Cannot find checkpoint {save_path}")
        return torch.load(save_path, map_location=device, weights_only=False)

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_folder: Union[os.PathLike, str],
        name: Literal["best", "latest"] = "best",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """@private
        """
        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
        # make sure the correct device is set if we don't have access to CUDA
        if not torch.cuda.is_available():
            device = "cpu"
        save_dict = cls._get_save_dict(save_path, device)
        deserializer = cls.Deserializer(save_dict["init"], save_path, device)

        has_kwargs = False
        deserialized = []
        for name, parameter in inspect.signature(cls).parameters.items():
            if name == "kwargs":
                has_kwargs = True
                continue
            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
            deserialized.append(name)

        # To deserialize kwargs we can't rely on inspecting the signature, so we
        # go through the remaining kwarg names in the init data instead.
        if has_kwargs:
            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
            for name in kwarg_names:
                if name.endswith("_kwargs"):
                    continue
                elif name.endswith("_dataset"):
                    deserializer.load(name.replace("dataset", "loader"), optional=False)
                elif name.endswith("_class"):
                    deserializer.load(name.replace("_class", ""), optional=False)
                else:
                    deserializer.load(name, optional=False)

        trainer = cls(**deserializer.trainer_kwargs)
        trainer._initialize(0, save_dict)
        trainer._is_initialized = True
        return trainer
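
    # Minimal usage sketch for `from_checkpoint` (added for illustration; the
    # checkpoint path is hypothetical and must contain a 'best.pt' written by `fit`):
    #
    #   trainer = DefaultTrainer.from_checkpoint("./checkpoints/unet-training", name="best")
    #   model = trainer.model.eval()  # e.g. for inference with the best weights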

    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance.

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
            called by the `dump()` method when appropriate cover most cases already.

            This example adds the `the_answer` kwarg, which requires extra steps on dumping only because we
            don't keep a 'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the
            >>>         # new kwarg, but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer'
            >>>             assert kwarg_name == "the_answer"
            >>>             # populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds the `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer':
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10,
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

        Args:
            trainer: The trainer instance.
        """
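
        # Note (added for illustration, not part of the original source): for the trainer
        # from the class docstring example, the Serializer populates 'init_data' with entries
        # such as 'name', 'train_dataset', 'train_loader_kwargs', 'model_class', 'model_kwargs',
        # 'loss_class', 'loss_kwargs', 'optimizer_class' and 'optimizer_kwargs'. These are
        # exactly the entries that the Deserializer consumes in `from_checkpoint`.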

        def __init__(self, trainer: DefaultTrainer):
            self.trainer = trainer
            self.init_data = {}  # to be populated during the serialization process

        def dump(self, kwarg_name: str) -> None:
            """@private
            """
            dumper = getattr(self, f"dump_{kwarg_name}", None)
            if dumper is not None:
                dumper(kwarg_name)
            elif kwarg_name.endswith("_loader"):
                self.dump_data_loader(kwarg_name)
            elif kwarg_name.endswith("_class"):
                self.dump_generic_class(kwarg_name)
            elif not hasattr(self.trainer, kwarg_name):
                raise AttributeError(
                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
                )
            else:
                assert hasattr(self.trainer, kwarg_name)
                obj = getattr(self.trainer, kwarg_name)
                if obj is None or type(obj) in (
                    bool,
                    bytearray,
                    bytes,
                    dict,
                    float,
                    frozenset,
                    int,
                    list,
                    set,
                    str,
                    tuple,
                ):
                    self.dump_generic_builtin(kwarg_name)
                else:
                    self.dump_generic_instance(kwarg_name)

        def dump_generic_builtin(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

        def dump_generic_class(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            assert kwarg_name.endswith("_class")
            obj = getattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

        def dump_generic_instance(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            instance = getattr(self.trainer, kwarg_name)
            self.init_data.update(
                {
                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
                }
            )

        def dump_device(self, kwarg_name: str):
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

        def dump_data_loader(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            loader = getattr(self.trainer, kwarg_name)
            if loader is None:
                return
            self.init_data.update(
                {
                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
                }
            )

        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
            """@private
            """
            self.dump_generic_class(f"{kwarg_name}_class")

        def dump_model(self, kwarg_name: str):
            """@private
            """
            if is_compiled(self.trainer.model):
                self.init_data.update(
                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
                )
            else:
                self.dump_generic_instance("model")

    def _build_init(self) -> Dict[str, Any]:
        serializer = self.Serializer(self)
        for name in inspect.signature(self.__class__).parameters:
            # Special rules to serialize kwargs:
            # if a trainer class inherits from DefaultTrainer and has **kwargs
            # in its signature, they need to be saved in self._kwargs.
            if name == "kwargs":
                if not hasattr(self, "_kwargs"):
                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
                          "Please add self._kwargs to its __init__ function"
                    raise RuntimeError(msg)
                kwargs = getattr(self, "_kwargs")
                for kwarg_name in kwargs:
                    serializer.dump(kwarg_name)
                continue
            serializer.dump(name)

        return serializer.init_data

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        assert self.train_loader is not None
        assert self.val_loader is not None
        assert self.model is not None
        assert self.loss is not None
        assert self.optimizer is not None
        assert self.metric is not None
        assert self.device is not None

        if load_from_checkpoint is not None:
            self.load_checkpoint(load_from_checkpoint)

        if sum((iterations is not None, epochs is not None)) != 1:
            raise ValueError(
                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer. "
                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
            )

        if epochs is None:
            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
        else:
            iterations = epochs * len(self.train_loader)

        self.max_iteration = self._iteration + iterations
        self.max_epoch = self._epoch + epochs

        if not getattr(self, "_is_initialized", False):
            # Check if we compile the model (only supported by pytorch 2).
            # To enable (de)serialization of compiled models, we keep track of the model class and kwargs.
            if is_compiled(self.model):
                warnings.warn(
                    "You have passed a compiled model to the trainer. "
                    "It will not be possible to (de)serialize the trainer with it. "
                    "If you want to be able to do this please pass the normal model. "
                    "It can be automatically compiled by setting 'compile_model' to True."
                )
            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
            self._model_kwargs = get_constructor_arguments(self.model)
            self.model = auto_compile(self.model, self.compile_model)

            self.model.to(self.device)
            self.loss.to(self.device)

            # This saves all the information that is necessary
            # to fully load the trainer from the checkpoint.
            self.init_data = self._build_init()

            if self.logger_class is None:
                self.logger = None
            else:
                # Instantiating the logger may set self.name if self.name is None.
                save_root = getattr(self, "save_root", None)
                try:
                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
                except PermissionError:
                    warnings.warn(
                        f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
                        "The most likely reason for this is that you copied the checkpoint somewhere else, "
                        "so we skip this error to enable loading the model from this checkpoint."
                    )

            try:
                os.makedirs(self.checkpoint_folder, exist_ok=True)
            except PermissionError:
                warnings.warn(
                    f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
                    "The most likely reason for this is that you copied the checkpoint somewhere else, "
                    "so we skip this error to enable loading the model from this checkpoint."
                )

        best_metric = np.inf
        return best_metric

    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
        """@private
        """
        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
        extra_init_dict = extra_save_dict.pop("init", {})
        save_dict = {
            "iteration": self._iteration,
            "epoch": self._epoch,
            "best_epoch": self._best_epoch,
            "best_metric": best_metric,
            "current_metric": current_metric,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "init": self.init_data | extra_init_dict,
            "train_time": train_time,
        }
        save_dict.update(**extra_save_dict)
        if self.scaler is not None:
            save_dict.update({"scaler_state": self.scaler.state_dict()})
        if self.lr_scheduler is not None:
            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})

        rank = getattr(self, "rank", None)
        if rank is None or rank == 0:
            torch.save(save_dict, save_path)

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        if isinstance(checkpoint, str):
            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
            if not os.path.exists(save_path):
                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
                return
            save_dict = torch.load(save_path, weights_only=False)
        elif isinstance(checkpoint, dict):
            save_dict = checkpoint
        else:
            raise RuntimeError(f"Invalid checkpoint type {type(checkpoint)}, expected a string or a dict.")

        self._iteration = save_dict["iteration"]
        self._epoch = save_dict["epoch"]
        self._best_epoch = save_dict["best_epoch"]
        self.best_metric = save_dict["best_metric"]
        self.current_metric = save_dict["current_metric"]
        self.train_time = save_dict.get("train_time", 0.0)

        model_state = save_dict["model_state"]
        # Strip the prefix added by torch.compile to enable loading compiled models.
        compiled_prefix = "_orig_mod."
        model_state = OrderedDict(
            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
        )
        self.model.load_state_dict(model_state)
        # We need to send the network to the device before loading the optimizer state!
        self.model.to(self.device)

        self.optimizer.load_state_dict(save_dict["optimizer_state"])
        if self.scaler is not None:
            self.scaler.load_state_dict(save_dict["scaler_state"])
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])

        return save_dict

    def _verify_if_training_completed(self, checkpoint="latest"):
        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
        if save_dict and self.max_iteration == save_dict.get("iteration"):
            return True
        return False
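
    # Worked example for the iterations <-> epochs conversion in `_initialize`
    # (added for illustration; the loader length is hypothetical): with 500 batches
    # per epoch, fit(iterations=25000) trains for ceil(25000 / 500) = 50 epochs,
    # and fit(epochs=50) conversely trains for 50 * 500 = 25000 iterations.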

    def fit(
        self,
        iterations: Optional[int] = None,
        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
        epochs: Optional[int] = None,
        save_every_kth_epoch: Optional[int] = None,
        progress=None,
        overwrite_training: bool = True,
    ):
        """Run neural network training.

        Exactly one of 'iterations' or 'epochs' has to be passed.

        Args:
            iterations: How long to train, specified in iterations.
            load_from_checkpoint: Path to a checkpoint from where training should be continued.
            epochs: How long to train, specified in epochs.
            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
        """
        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)

        if not overwrite_training:
            if load_from_checkpoint is not None:
                raise ValueError(
                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
                )

            if self._verify_if_training_completed():
                print(
                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
                    "and 'overwrite_training' is set to 'False'."
                )
                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
                return

        print(
            "Start fitting for",
            self.max_iteration - self._iteration,
            "iterations /",
            self.max_epoch - self._epoch,
            "epochs",
        )
        print("with", len(self.train_loader), "iterations per epoch")

        if self.mixed_precision:
            train_epoch = self._train_epoch_mixed
            validate = self._validate_mixed
            print("Training with mixed precision")
        else:
            train_epoch = self._train_epoch
            validate = self._validate
            print("Training with single precision")

        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
        if progress is None:
            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
        else:
            progress.total = total_iterations
            progress.set_description(f"Epoch {self._epoch}")

        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
        train_epochs = self.max_epoch - self._epoch
        t_start = time.time()
        for epoch in range(train_epochs):

            # Ensure data is shuffled differently at each epoch.
            try:
                self.train_loader.sampler.set_epoch(epoch)
            except AttributeError:
                pass

            # Run training and validation for this epoch.
            t_per_iter = train_epoch(progress)
            current_metric = validate()

            # Perform all the post-epoch steps:

            # Apply the learning rate scheduler.
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(current_metric)

            # How long did we train in total?
            total_train_time = (time.time() - t_start) + self.train_time

            # Save this checkpoint as the new best checkpoint if
            # it has the best overall validation metric.
            if current_metric < best_metric:
                best_metric = current_metric
                self._best_epoch = self._epoch
                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)

            # Save this checkpoint as the latest checkpoint.
            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)

            # If we save after every k-th epoch then check if we need to save now.
            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
                self.save_checkpoint(
                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
                )

            # If early stopping has been specified then check if the stopping condition is met.
            if self.early_stopping is not None:
                epochs_since_best = self._epoch - self._best_epoch
                if epochs_since_best > self.early_stopping:
                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
                    break

            self._epoch += 1
            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)

        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
        print(f"The best epoch is number {self._best_epoch}.")

        if self._generate_name:
            self.name = None

        # Update the train time.
        self.train_time = total_train_time

        # TODO save the model to wandb if we have the wandb logger
        if isinstance(self.logger, WandbLogger):
            self.logger.get_wandb().finish()
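
    # Usage sketch for `fit` (added for illustration; the values are hypothetical):
    #
    #   trainer.fit(epochs=50, save_every_kth_epoch=10)  # also writes epoch-10.pt, ..., epoch-50.pt
    #   # Resume from the latest checkpoint and train for 10 more epochs:
    #   trainer.fit(epochs=10, load_from_checkpoint="latest")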

    def _backprop(self, loss):
        loss.backward()
        self.optimizer.step()

    def _backprop_mixed(self, loss):
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

    def _train_epoch(self, progress):
        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)

    def _train_epoch_mixed(self, progress):
        return self._train_epoch_impl(
            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
            self._backprop_mixed
        )

    def _forward_and_loss(self, x, y):
        pred = self.model(x)
        if self._iteration % self.log_image_interval == 0:
            if pred.requires_grad:
                pred.retain_grad()

        loss = self.loss(pred, y)
        return pred, loss

    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()

            with forward_context():
                pred, loss = self._forward_and_loss(x, y)

            backprop(loss)

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate(self):
        return self._validate_impl(contextlib.nullcontext)

    def _validate_mixed(self):
        return self._validate_impl(
            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
        )

    def _validate_impl(self, forward_context):
        self.model.eval()

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                    metric = self.metric(pred, y)

                loss_val += loss.item()
                metric_val += metric.item()

        metric_val /= len(self.val_loader)
        loss_val /= len(self.val_loader)
        if self.logger is not None:
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
        return metric_val
"cuda") 832 ) 833 834 def _validate_impl(self, forward_context): 835 self.model.eval() 836 837 metric_val = 0.0 838 loss_val = 0.0 839 840 with torch.no_grad(): 841 for x, y in self.val_loader: 842 x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) 843 with forward_context(): 844 pred, loss = self._forward_and_loss(x, y) 845 metric = self.metric(pred, y) 846 847 loss_val += loss.item() 848 metric_val += metric.item() 849 850 metric_val /= len(self.val_loader) 851 loss_val /= len(self.val_loader) 852 if self.logger is not None: 853 self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred) 854 return metric_val
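To make the checkpoint logic above concrete, here is a minimal sketch of inspecting a saved checkpoint and resuming training from it. The run name `unet-training` and the paths are illustrative assumptions, not fixed by the library:
```python
import torch

# Hypothetical checkpoint written by 'save_checkpoint' during a run named "unet-training".
checkpoint_path = "./checkpoints/unet-training/latest.pt"

# The save dict contains non-tensor entries such as 'init', so
# 'weights_only=False' is required, as in 'load_checkpoint' above.
save_dict = torch.load(checkpoint_path, weights_only=False)
print(save_dict["iteration"], save_dict["epoch"], save_dict["best_metric"])

# Training can then be continued by passing the checkpoint location to 'fit':
# trainer.fit(iterations=int(5e4), load_from_checkpoint=checkpoint_path)
```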
    def __init__(
        self,
        name: Optional[str],
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        metric: Callable,
        device: Union[str, torch.device],
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        log_image_interval: int = 100,
        mixed_precision: bool = True,
        early_stopping: Optional[int] = None,
        logger=TensorboardLogger,
        logger_kwargs: Optional[Dict[str, Any]] = None,
        id_: Optional[str] = None,
        save_root: Optional[str] = None,
        compile_model: Optional[Union[bool, str]] = None,
        rank: Optional[int] = None,
    ):
        if name is None and not issubclass(logger, WandbLogger):
            raise TypeError("Name cannot be None if not using the WandbLogger")

        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
            raise ValueError(f"{self.__class__} requires each dataloader to have 'shuffle' attribute.")

        self._generate_name = name is None
        self.name = name
        self.id_ = id_ or name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metric
        self.device = torch.device(device)
        self.lr_scheduler = lr_scheduler
        self.log_image_interval = log_image_interval
        self.save_root = save_root
        self.compile_model = compile_model
        self.rank = rank

        self._iteration = 0
        self._epoch = 0
        self._best_epoch = 0

        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.train_time = 0.0

        if mixed_precision:
            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
        else:
            self.scaler = None

        self.logger_class = logger
        self.logger_kwargs = logger_kwargs
        self.log_image_interval = log_image_interval
    @property
    def checkpoint_folder(self):
        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
        # 'save_root' enables saving the checkpoints somewhere else than in the local folder.
        # This is handy for filesystems with limited space, where saving the checkpoints
        # and log files can lead to running out of space.
        save_root = getattr(self, "save_root", None)
        return os.path.join("./checkpoints", self.id_) if save_root is None else \
            os.path.join(save_root, "./checkpoints", self.id_)
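A short sketch of the resulting paths; the id `unet-training` and the save root `/scratch/models` are hypothetical:
```python
import os

# Without 'save_root', checkpoints go to the current working directory:
print(os.path.join("./checkpoints", "unet-training"))
# -> ./checkpoints/unet-training

# With 'save_root', they are placed under that root instead:
print(os.path.join("/scratch/models", "./checkpoints", "unet-training"))
# -> /scratch/models/./checkpoints/unet-training
```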
    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

        Examples:
            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>         # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

        Args:
            init_data: The initialization data of the trainer.
            save_path: The path where the checkpoint was saved.
            device: The device.
        """

        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """@private
            """
            # 'optional' is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'.
            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            """@private
            """
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # Monkey patch the shuffle attribute onto the loader.
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args: Dict,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            """@private
            """
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # TODO: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            """@private
            """
            assert kwarg_name == "logger"
            self.load_generic("logger", optional=optional, only_class=True)
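The heart of `load_generic` is resolving a serialized `<module>.<ClassName>` string back to a class via a dynamic import. A self-contained sketch of that pattern, using `torch.optim.AdamW` purely as an example value:
```python
from importlib import import_module

# A serialized class path, as stored e.g. under 'optimizer_class' in init_data.
this_cls = "torch.optim.AdamW"

# Split into module path and class name, then import and resolve.
cls_p, cls_m = this_cls.rsplit(".", 1)      # -> ("torch.optim", "AdamW")
cls = getattr(import_module(cls_p), cls_m)  # -> the AdamW class object
print(cls)
```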
    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance.

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`,
            which are called by the `dump()` method when appropriate, cover most cases already.

            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep
            a 'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new
            >>>         # kwarg, but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer'
            >>>             assert kwarg_name == "the_answer"
            >>>             # Populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs.
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer':
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'.
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

        Args:
            trainer: The trainer instance.
        """

        def __init__(self, trainer: DefaultTrainer):
            self.trainer = trainer
            self.init_data = {}  # to be populated during the serialization process

        def dump(self, kwarg_name: str) -> None:
            """@private
            """
            dumper = getattr(self, f"dump_{kwarg_name}", None)
            if dumper is not None:
                dumper(kwarg_name)
            elif kwarg_name.endswith("_loader"):
                self.dump_data_loader(kwarg_name)
            elif kwarg_name.endswith("_class"):
                self.dump_generic_class(kwarg_name)
            elif not hasattr(self.trainer, kwarg_name):
                raise AttributeError(
                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
                )
            else:
                assert hasattr(self.trainer, kwarg_name)
                obj = getattr(self.trainer, kwarg_name)
                if obj is None or type(obj) in (
                    bool, bytearray, bytes, dict, float, frozenset, int, list, set, str, tuple
                ):
                    self.dump_generic_builtin(kwarg_name)
                else:
                    self.dump_generic_instance(kwarg_name)

        def dump_generic_builtin(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

        def dump_generic_class(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            assert kwarg_name.endswith("_class")
            obj = getattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

        def dump_generic_instance(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            instance = getattr(self.trainer, kwarg_name)
            self.init_data.update(
                {
                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
                }
            )

        def dump_device(self, kwarg_name: str):
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

        def dump_data_loader(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            loader = getattr(self.trainer, kwarg_name)
            if loader is None:
                return
            self.init_data.update(
                {
                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
                }
            )

        # TODO: remove and rename kwarg 'logger' to 'logger_class'
        def dump_logger(self, kwarg_name: str):
            """@private
            """
            self.dump_generic_class(f"{kwarg_name}_class")

        def dump_model(self, kwarg_name: str):
            """@private
            """
            if is_compiled(self.trainer.model):
                self.init_data.update(
                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
                )
            else:
                self.dump_generic_instance("model")
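To illustrate the dump format, here is a hedged sketch of the kind of entries the `dump_*` methods above write into `init_data`; the concrete classes and kwargs are made up for the example:
```python
# Shape of 'init_data' after serialization (values are illustrative):
init_data = {
    # dump_device stores the device as a plain string.
    "device": "cuda",
    # dump_generic_instance stores the class path plus constructor kwargs.
    "optimizer_class": "torch.optim.adamw.AdamW",
    "optimizer_kwargs": {"lr": 1e-4},
    # dump_data_loader stores the dataset object and the loader kwargs, e.g.
    # "train_dataset": <dataset>, "train_loader_kwargs": {"batch_size": 1, "shuffle": True},
}
print(sorted(init_data))
```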