torch_em.trainer.default_trainer
from __future__ import annotations

import contextlib
import inspect
import os
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
import torch.cuda.amp as amp
from tqdm import tqdm

from .tensorboard_logger import TensorboardLogger
from .wandb_logger import WandbLogger
from ..util import auto_compile, get_constructor_arguments, is_compiled


class DefaultTrainer:
    """Trainer class for 2d/3d training on a single GPU."""

    def __init__(
        self,
        name: Optional[str],
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        loss,
        optimizer,
        metric,
        device: Union[str, torch.device],
        lr_scheduler=None,
        log_image_interval=100,
        mixed_precision=True,
        early_stopping=None,
        logger=TensorboardLogger,
        logger_kwargs: Optional[Dict[str, Any]] = None,
        id_: Optional[str] = None,
        save_root: Optional[str] = None,
        compile_model: Optional[Union[bool, str]] = None,
    ):
        if name is None and not issubclass(logger, WandbLogger):
            raise TypeError("Name cannot be None if not using the WandbLogger")

        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
            raise ValueError(f"{self.__class__} requires each dataloader to have 'shuffle' attribute.")

        self._generate_name = name is None
        self.name = name
        self.id_ = id_ or name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metric
        self.device = device
        self.lr_scheduler = lr_scheduler
        self.log_image_interval = log_image_interval
        self.save_root = save_root
        self.compile_model = compile_model

        self._iteration = 0
        self._epoch = 0
        self._best_epoch = 0

        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.train_time = 0.0

        self.scaler = amp.GradScaler() if mixed_precision else None

        self.logger_class = logger
        self.logger_kwargs = logger_kwargs
        self.log_image_interval = log_image_interval

    @property  # because the logger may generate and set trainer.id on logger.__init__
    def checkpoint_folder(self):
        assert self.id_ is not None
        # save_root enables saving the checkpoints somewhere else than in the local
        # folder. This is handy for filesystems with limited space, where saving the checkpoints
        # and log files can easily lead to running out of space.
        save_root = getattr(self, "save_root", None)
        return os.path.join("./checkpoints", self.id_) if save_root is None else\
            os.path.join(save_root, "./checkpoints", self.id_)

    @property
    def iteration(self):
        return self._iteration

    @property
    def epoch(self):
        return self._epoch

    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'

        Examples:
            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>                                       # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
        """

        def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """`optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'"""

            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch the shuffle attribute to the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # todo: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            assert kwarg_name == "logger"
            self.load_generic("logger", optional=optional, only_class=True)

    @staticmethod
    def _get_save_dict(save_path, device):
        if not os.path.exists(save_path):
            raise ValueError(f"Cannot find checkpoint {save_path}")
        return torch.load(save_path, map_location=device)

    @classmethod
    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
        # make sure the correct device is set if we don't have access to CUDA
        if not torch.cuda.is_available():
            device = "cpu"
        save_dict = cls._get_save_dict(save_path, device)
        deserializer = cls.Deserializer(save_dict["init"], save_path, device)

        has_kwargs = False
        deserialized = []
        for name, parameter in inspect.signature(cls).parameters.items():
            if name == "kwargs":
                has_kwargs = True
                continue
            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
            deserialized.append(name)

        # to deserialize kwargs we can't rely on inspecting the signature, so we
        # go through the remaining kwarg names in init data instead
        if has_kwargs:
            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
            for name in kwarg_names:
                if name.endswith("_kwargs"):
                    continue
                elif name.endswith("_dataset"):
                    deserializer.load(name.replace("dataset", "loader"), optional=False)
                elif name.endswith("_class"):
                    deserializer.load(name.replace("_class", ""), optional=False)
                else:
                    deserializer.load(name, optional=False)

        trainer = cls(**deserializer.trainer_kwargs)
        trainer._initialize(0, save_dict)
        trainer._is_initialized = True
        return trainer

    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
            called by the `dump()` method when appropriate cover most cases already.

            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
            'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
            >>>         # but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
            >>>             assert kwarg_name == "the_answer"
            >>>             # populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer'
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
        """

        def __init__(self, trainer: DefaultTrainer):
            self.trainer = trainer
            self.init_data = {}  # to be populated during serialization process

        def dump(self, kwarg_name: str) -> None:
            dumper = getattr(self, f"dump_{kwarg_name}", None)
            if dumper is not None:
                dumper(kwarg_name)
            elif kwarg_name.endswith("_loader"):
                self.dump_data_loader(kwarg_name)
            elif kwarg_name.endswith("_class"):
                self.dump_generic_class(kwarg_name)
            elif not hasattr(self.trainer, kwarg_name):
                raise AttributeError(
                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
                )
            else:
                assert hasattr(self.trainer, kwarg_name)
                obj = getattr(self.trainer, kwarg_name)
                if obj is None or type(obj) in (
                    bool,
                    bytearray,
                    bytes,
                    dict,
                    float,
                    frozenset,
                    int,
                    list,
                    set,
                    str,
                    tuple,
                ):
                    self.dump_generic_builtin(kwarg_name)
                else:
                    self.dump_generic_instance(kwarg_name)

        def dump_generic_builtin(self, kwarg_name: str) -> None:
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

        def dump_generic_class(self, kwarg_name: str) -> None:
            assert hasattr(self.trainer, kwarg_name)
            assert kwarg_name.endswith("_class")
            obj = getattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

        def dump_generic_instance(self, kwarg_name: str) -> None:
            assert hasattr(self.trainer, kwarg_name)
            instance = getattr(self.trainer, kwarg_name)
            self.init_data.update(
                {
                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
                }
            )

        def dump_device(self, kwarg_name: str):
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

        def dump_data_loader(self, kwarg_name: str) -> None:
            assert hasattr(self.trainer, kwarg_name)
            loader = getattr(self.trainer, kwarg_name)
            if loader is None:
                return
            self.init_data.update(
                {
                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
                }
            )

        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
            self.dump_generic_class(f"{kwarg_name}_class")

        def dump_model(self, kwarg_name: str):
            if is_compiled(self.trainer.model):
                self.init_data.update(
                    {
                        "model_class": self.trainer._model_class,
                        "model_kwargs": self.trainer._model_kwargs,
                    }
                )
            else:
                self.dump_generic_instance("model")

    def _build_init(self) -> Dict[str, Any]:
        serializer = self.Serializer(self)
        for name in inspect.signature(self.__class__).parameters:
            # special rules to serialize kwargs
            # if a trainer class inherits from DefaultTrainer and has **kwargs
            # they need to be saved in self._kwargs
            if name == "kwargs":
                if not hasattr(self, "_kwargs"):
                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
                          "Please add self._kwargs to its __init__ function"
                    raise RuntimeError(msg)
                kwargs = getattr(self, "_kwargs")
                for kwarg_name in kwargs:
                    serializer.dump(kwarg_name)
                continue
            serializer.dump(name)

        return serializer.init_data

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        assert self.train_loader is not None
        assert self.val_loader is not None
        assert self.model is not None
        assert self.loss is not None
        assert self.optimizer is not None
        assert self.metric is not None
        assert self.device is not None

        if load_from_checkpoint is not None:
            self.load_checkpoint(load_from_checkpoint)

        if sum((iterations is not None, epochs is not None)) != 1:
            raise ValueError(
                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer. "
                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
            )

        if epochs is None:
            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
        else:
            iterations = epochs * len(self.train_loader)

        self.max_iteration = self._iteration + iterations
        self.max_epoch = self._epoch + epochs

        if not getattr(self, "_is_initialized", False):
            # check if we compile the model (only supported by pytorch 2)
            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
            if is_compiled(self.model):
                warnings.warn(
                    "You have passed a compiled model to the trainer. "
                    "It will not be possible to (de)serialize the trainer with it. "
                    "If you want to be able to do this please pass the normal model. "
                    "It can be automatically compiled by setting 'compile_model' to True"
                )
            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
            self._model_kwargs = get_constructor_arguments(self.model)
            self.model = auto_compile(self.model, self.compile_model)

            self.model.to(self.device)
            self.loss.to(self.device)

            # this saves all the information that is necessary
            # to fully load the trainer from the checkpoint
            self.init_data = self._build_init()

            if self.logger_class is None:
                self.logger = None
            else:
                # may set self.name if self.name is None
                save_root = getattr(self, "save_root", None)
                self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))

            try:
                os.makedirs(self.checkpoint_folder, exist_ok=True)
            except PermissionError:
                warnings.warn(
                    f"The checkpoint folder at {self.checkpoint_folder} could not be created. "
                    "The most likely reason for this is that you copied the checkpoint somewhere else, "
                    "so we skip this error to enable loading the model from this checkpoint."
                )
                pass

        best_metric = np.inf
        return best_metric

    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
        extra_init_dict = extra_save_dict.pop("init", {})
        save_dict = {
            "iteration": self._iteration,
            "epoch": self._epoch,
            "best_epoch": self._best_epoch,
            "best_metric": best_metric,
            "current_metric": current_metric,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "init": self.init_data | extra_init_dict,
            "train_time": train_time,
        }
        save_dict.update(**extra_save_dict)
        if self.scaler is not None:
            save_dict.update({"scaler_state": self.scaler.state_dict()})
        if self.lr_scheduler is not None:
            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})
        torch.save(save_dict, save_path)

    def load_checkpoint(self, checkpoint="best"):
        if isinstance(checkpoint, str):
            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
            if not os.path.exists(save_path):
                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
                return
            save_dict = torch.load(save_path)
        elif isinstance(checkpoint, dict):
            save_dict = checkpoint
        else:
            raise RuntimeError

        self._iteration = save_dict["iteration"]
        self._epoch = save_dict["epoch"]
        self._best_epoch = save_dict["best_epoch"]
        self.best_metric = save_dict["best_metric"]
        self.current_metric = save_dict["current_metric"]
        self.train_time = save_dict.get("train_time", 0.0)

        model_state = save_dict["model_state"]
        # to enable loading compiled models
        compiled_prefix = "_orig_mod."
        model_state = OrderedDict(
            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
        )
        self.model.load_state_dict(model_state)
        # we need to send the network to the device before loading the optimizer state!
        self.model.to(self.device)

        self.optimizer.load_state_dict(save_dict["optimizer_state"])
        if self.scaler is not None:
            self.scaler.load_state_dict(save_dict["scaler_state"])
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])

        return save_dict

    def fit(
        self,
        iterations=None,
        load_from_checkpoint=None,
        epochs=None,
        save_every_kth_epoch=None,
        progress=None,
    ):
        """Run neural network training.

        Exactly one of 'iterations' or 'epochs' has to be passed.

        Parameters:
            iterations [int] - how long to train, specified in iterations (default: None)
            load_from_checkpoint [str] - path to a checkpoint from where training should be continued (default: None)
            epochs [int] - how long to train, specified in epochs (default: None)
            save_every_kth_epoch [int] - save checkpoints after every kth epoch separately.
                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
            progress [progress_bar] - optional progress bar for integration with external tools.
                Expected to follow the tqdm interface.
        """
        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
        print(
            "Start fitting for",
            self.max_iteration - self._iteration,
            "iterations / ",
            self.max_epoch - self._epoch,
            "epochs",
        )
        print("with", len(self.train_loader), "iterations per epoch")

        if self.mixed_precision:
            train_epoch = self._train_epoch_mixed
            validate = self._validate_mixed
            print("Training with mixed precision")
        else:
            train_epoch = self._train_epoch
            validate = self._validate
            print("Training with single precision")

        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
        if progress is None:
            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
        else:
            progress.total = total_iterations
            progress.set_description(f"Epoch {self._epoch}")

        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
        train_epochs = self.max_epoch - self._epoch
        t_start = time.time()
        for _ in range(train_epochs):

            # run training and validation for this epoch
            t_per_iter = train_epoch(progress)
            current_metric = validate()

            # perform all the post-epoch steps:

            # apply the learning rate scheduler
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(current_metric)

            # how long did we train in total?
            total_train_time = (time.time() - t_start) + self.train_time

            # save this checkpoint as the new best checkpoint if
            # it has the best overall validation metric
            if current_metric < best_metric:
                best_metric = current_metric
                self._best_epoch = self._epoch
                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)

            # save this checkpoint as the latest checkpoint
            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)

            # if we save after every k-th epoch then check if we need to save now
            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
                self.save_checkpoint(
                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
                )

            # if early stopping has been specified then check if the stopping condition is met
            if self.early_stopping is not None:
                epochs_since_best = self._epoch - self._best_epoch
                if epochs_since_best > self.early_stopping:
                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
                    break

            self._epoch += 1
            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)

        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
        print(f"The best epoch is number {self._best_epoch}.")

        if self._generate_name:
            self.name = None

        # Update the train time
        self.train_time = total_train_time

        # TODO save the model to wandb if we have the wandb logger
        if isinstance(self.logger, WandbLogger):
            self.logger.get_wandb().finish()

    def _backprop(self, loss):
        loss.backward()
        self.optimizer.step()

    def _backprop_mixed(self, loss):
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

    def _train_epoch(self, progress):
        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)

    def _train_epoch_mixed(self, progress):
        return self._train_epoch_impl(progress, amp.autocast, self._backprop_mixed)

    def _forward_and_loss(self, x, y):
        pred = self.model(x)
        if self._iteration % self.log_image_interval == 0:
            if pred.requires_grad:
                pred.retain_grad()

        loss = self.loss(pred, y)
        return pred, loss

    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                pred, loss = self._forward_and_loss(x, y)

            backprop(loss)

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate(self):
        return self._validate_impl(contextlib.nullcontext)

    def _validate_mixed(self):
        return self._validate_impl(amp.autocast)

    def _validate_impl(self, forward_context):
        self.model.eval()

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                    metric = self.metric(pred, y)

                loss_val += loss.item()
                metric_val += metric.item()

        metric_val /= len(self.val_loader)
        loss_val /= len(self.val_loader)
        if self.logger is not None:
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
        return metric_val
class DefaultTrainer
Trainer class for 2d/3d training on a single GPU.
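A minimal usage sketch, assuming a toy 2d setup with random tensors (the model, data, and name below are illustrative and not part of torch_em; real pipelines typically build the loaders with torch_em's dataset helpers). Note that the trainer requires each loader to carry a 'shuffle' attribute, mirroring the check in __init__:

import torch
from torch.utils.data import DataLoader, TensorDataset

from torch_em.trainer.default_trainer import DefaultTrainer

# toy 2d setup with random tensors
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(8, 1, 3, padding=1)
)
dataset = TensorDataset(torch.rand(16, 1, 64, 64), torch.rand(16, 1, 64, 64))
train_loader = DataLoader(dataset, batch_size=4)
val_loader = DataLoader(dataset, batch_size=4)
# the trainer expects each loader to have a 'shuffle' attribute (see the check in __init__)
train_loader.shuffle, val_loader.shuffle = True, False

trainer = DefaultTrainer(
    name="toy-model",
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    loss=torch.nn.MSELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1.0e-4),
    metric=torch.nn.MSELoss(),
    device="cuda" if torch.cuda.is_available() else "cpu",
    mixed_precision=False,
    logger=None,  # skip tensorboard logging in this sketch
)
trainer.fit(iterations=100)  # checkpoints are written to ./checkpoints/toy-model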
Run neural network training.
Exactly one of 'iterations' or 'epochs' has to be passed.
Arguments:
- iterations [int] - how long to train, specified in iterations (default: None)
- load_from_checkpoint [str] - path to a checkpoint from where training should be continued (default: None)
- epochs [int] - how long to train, specified in epochs (default: None)
- save_every_kth_epoch [int] - save checkpoints after every kth epoch separately. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
- progress [progress_bar] - optional progress bar for integration with external tools. Expected to follow the tqdm interface.
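A hedged usage sketch of the calls above; `trainer` is a fully constructed DefaultTrainer and the iteration/epoch counts are placeholders:

    # Train for a fixed number of iterations; 'best.pt' and 'latest.pt' are written every epoch,
    # plus 'epoch-10.pt', 'epoch-20.pt', ... because of save_every_kth_epoch.
    trainer.fit(iterations=50_000, save_every_kth_epoch=10)

    # Alternatively, train by epochs and continue from an earlier run
    # (assuming load_from_checkpoint accepts the same checkpoint names as load_checkpoint):
    # trainer.fit(epochs=100, load_from_checkpoint="latest")

As stated above, exactly one of iterations or epochs should be passed; _initialize presumably derives max_iteration and max_epoch from whichever is given, and the per-epoch loop then handles checkpointing and early stopping.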
class Deserializer:
    """Determines how to deserialize the trainer kwargs from serialized 'init_data'

    Examples:
        To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
        Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

        This example adds the `the_answer` kwarg, which requires 'calculations' upon initialization:
        >>> class MyTrainer(DefaultTrainer):
        >>>     def __init__(self, *args, the_answer: int, **kwargs):
        >>>         super().__init__(*args, **kwargs)
        >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
        >>>                                       # see DefaultTrainer.Serializer
        >>>
        >>>     class Deserializer(DefaultTrainer.Deserializer):
        >>>         def load_the_answer(self):
        >>>             generic_answer = self.init_data["the_answer"]
        >>>             # (device dependent) special deserialization
        >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
        >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
        >>>             else:
        >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
    """

    def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
        self.init_data = init_data
        self.save_path = save_path
        # populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'
        self.trainer_kwargs: Dict[str, Any] = dict(
            device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
        )

    def load(self, kwarg_name: str, optional):
        """`optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'"""

        if kwarg_name == "device":
            pass  # deserialized in __init__
        elif kwarg_name.endswith("_loader"):
            self.load_data_loader(kwarg_name, optional)
        else:
            load = getattr(self, f"load_{kwarg_name}", self.load_generic)
            load(kwarg_name, optional=optional)

    def load_data_loader(self, loader_name, optional) -> None:
        ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
        if ds is None and optional:
            return

        loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
        loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
        # monkey patch the shuffle attribute onto the loader
        loader.shuffle = loader_kwargs.get("shuffle", False)
        self.trainer_kwargs[loader_name] = loader

    def load_generic(
        self,
        kwarg_name: str,
        *dynamic_args,
        optional: bool,
        only_class: bool = False,
        dynamic_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        if kwarg_name in self.init_data:
            self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
            return

        this_cls = self.init_data.get(f"{kwarg_name}_class", None)
        if this_cls is None:
            if optional:
                return
            else:
                raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

        assert isinstance(this_cls, str), this_cls
        assert "." in this_cls, this_cls
        cls_p, cls_m = this_cls.rsplit(".", 1)
        this_cls = getattr(import_module(cls_p), cls_m)
        if only_class:
            self.trainer_kwargs[kwarg_name] = this_cls
        else:
            self.trainer_kwargs[kwarg_name] = this_cls(
                *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
            )

    def load_name(self, kwarg_name: str, optional: bool):
        self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

    def load_optimizer(self, kwarg_name: str, optional: bool):
        self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

    def load_lr_scheduler(self, kwarg_name: str, optional: bool):
        self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

    # todo: remove and rename kwarg 'logger' to 'logger_class'
    def load_logger(self, kwarg_name: str, optional: bool):
        assert kwarg_name == "logger"
        self.load_generic("logger", optional=optional, only_class=True)
Determines how to deserialize the trainer kwargs from serialized 'init_data'.

Examples:

To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

This example adds the `the_answer` kwarg, which requires 'calculations' upon initialization:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
>>>                                       # see DefaultTrainer.Serializer
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self):
>>>             generic_answer = self.init_data["the_answer"]
>>>             # (device dependent) special deserialization
>>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
>>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
>>>             else:
>>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
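The entry point that drives the Deserializer is not shown in this section, so the following is only a sketch of the protocol it expects: one `load` call per constructor kwarg, after which `trainer_kwargs` can be passed to the trainer class. The helper name `build_trainer_kwargs` is hypothetical.

    import inspect

    def build_trainer_kwargs(trainer_cls, init_data, save_path, device=None):
        # Instantiate the trainer's Deserializer and let it fill trainer_kwargs kwarg by kwarg.
        deserializer = trainer_cls.Deserializer(init_data, save_path, device)
        for name, param in inspect.signature(trainer_cls.__init__).parameters.items():
            if name in ("self", "args", "kwargs"):
                continue
            optional = param.default is not inspect.Parameter.empty
            deserializer.load(name, optional)  # dispatches to load_<name>, load_data_loader or load_generic
        return deserializer.trainer_kwargs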
class Serializer:
    """Implements how to serialize trainer kwargs from a trainer instance

    Examples:
        To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
        Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
        called by the `dump()` method when appropriate cover most cases already.

        This example adds the `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
        'the_answer' attribute:
        >>> class MyTrainer(DefaultTrainer):
        >>>     def __init__(self, *args, the_answer: int, **kwargs):
        >>>         super().__init__(*args, **kwargs)
        >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
        >>>         # but let's make things more interesting...
        >>>         self.the = the_answer // 10
        >>>         self.answer = the_answer % 10
        >>>
        >>>     class Serializer(DefaultTrainer.Serializer):
        >>>         trainer: MyTrainer
        >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
        >>>             assert kwarg_name == "the_answer"
        >>>             # populate self.init_data with the serialized data required by Deserializer
        >>>             # to restore the trainer kwargs
        >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

        This example with both Serializer and Deserializer adds the `the_answer` kwarg,
        while saving it in two separate entries 'the' and 'answer':
        >>> class MyTrainer(DefaultTrainer):
        >>>     def __init__(self, *args, the_answer: int, **kwargs):
        >>>         super().__init__(*args, **kwargs)
        >>>         self.the_answer = the_answer
        >>>
        >>>     class Serializer(DefaultTrainer.Serializer):
        >>>         trainer: MyTrainer
        >>>         def dump_the_answer(self, kwarg_name: str):
        >>>             assert kwarg_name == "the_answer"
        >>>             self.init_data.update({
        >>>                 "the": self.trainer.the_answer // 10,
        >>>                 "answer": self.trainer.the_answer % 10
        >>>             })
        >>>
        >>>     class Deserializer(DefaultTrainer.Deserializer):
        >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
        >>>             assert kwarg_name == "the_answer"
        >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
        >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
    """

    def __init__(self, trainer: DefaultTrainer):
        self.trainer = trainer
        self.init_data = {}  # to be populated during serialization process

    def dump(self, kwarg_name: str) -> None:
        dumper = getattr(self, f"dump_{kwarg_name}", None)
        if dumper is not None:
            dumper(kwarg_name)
        elif kwarg_name.endswith("_loader"):
            self.dump_data_loader(kwarg_name)
        elif kwarg_name.endswith("_class"):
            self.dump_generic_class(kwarg_name)
        elif not hasattr(self.trainer, kwarg_name):
            raise AttributeError(
                f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
            )
        else:
            assert hasattr(self.trainer, kwarg_name)
            obj = getattr(self.trainer, kwarg_name)
            if obj is None or type(obj) in (
                bool, bytearray, bytes, dict, float, frozenset, int, list, set, str, tuple,
            ):
                self.dump_generic_builtin(kwarg_name)
            else:
                self.dump_generic_instance(kwarg_name)

    def dump_generic_builtin(self, kwarg_name: str) -> None:
        assert hasattr(self.trainer, kwarg_name)
        self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

    def dump_generic_class(self, kwarg_name: str) -> None:
        assert hasattr(self.trainer, kwarg_name)
        assert kwarg_name.endswith("_class")
        obj = getattr(self.trainer, kwarg_name)
        self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

    def dump_generic_instance(self, kwarg_name: str) -> None:
        assert hasattr(self.trainer, kwarg_name)
        instance = getattr(self.trainer, kwarg_name)
        self.init_data.update(
            {
                f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
            }
        )

    def dump_device(self, kwarg_name: str):
        assert hasattr(self.trainer, kwarg_name)
        self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

    def dump_data_loader(self, kwarg_name: str) -> None:
        assert hasattr(self.trainer, kwarg_name)
        loader = getattr(self.trainer, kwarg_name)
        if loader is None:
            return
        self.init_data.update(
            {
                f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
            }
        )

    def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
        self.dump_generic_class(f"{kwarg_name}_class")

    def dump_model(self, kwarg_name: str):
        if is_compiled(self.trainer.model):
            self.init_data.update(
                {
                    "model_class": self.trainer._model_class,
                    "model_kwargs": self.trainer._model_kwargs,
                }
            )
        else:
            self.dump_generic_instance("model")
Implements how to serialize trainer kwargs from a trainer instance.

Examples:

To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`,
called by the `dump()` method when appropriate, cover most cases already.

This example adds the `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
'the_answer' attribute:

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
>>>         # but let's make things more interesting...
>>>         self.the = the_answer // 10
>>>         self.answer = the_answer % 10
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
>>>             assert kwarg_name == "the_answer"
>>>             # populate self.init_data with the serialized data required by Deserializer
>>>             # to restore the trainer kwargs
>>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

This example with both Serializer and Deserializer adds the `the_answer` kwarg,
while saving it in two separate entries 'the' and 'answer':

>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str):
>>>             assert kwarg_name == "the_answer"
>>>             self.init_data.update({
>>>                 "the": self.trainer.the_answer // 10,
>>>                 "answer": self.trainer.the_answer % 10
>>>             })
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self, kwarg_name: str, optional: bool):
>>>             assert kwarg_name == "the_answer"
>>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
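The code that drives the Serializer when a checkpoint is written is not part of this section, so the following is only a sketch of the dump protocol: one `dump` call per constructor kwarg, collecting the 'init_data' dict that the Deserializer later consumes. The helper name `collect_init_data` is hypothetical.

    import inspect

    def collect_init_data(trainer):
        # Instantiate the trainer's Serializer and dump every constructor kwarg into init_data.
        serializer = type(trainer).Serializer(trainer)
        for name in inspect.signature(type(trainer).__init__).parameters:
            if name in ("self", "args", "kwargs"):
                continue
            serializer.dump(name)  # fills init_data with '<name>' or '<name>_class' / '<name>_kwargs' entries
        return serializer.init_data  # stored alongside the model state when a checkpoint is saved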
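To make the data-loader round trip concrete: dump_data_loader stores the dataset object plus the loader's constructor arguments, and load_data_loader rebuilds an equivalent DataLoader from them and re-attaches the shuffle flag that DefaultTrainer checks for. A self-contained sketch, with the kwargs written out by hand instead of being recovered via get_constructor_arguments:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # A toy dataset so the example is self-contained.
    dataset = TensorDataset(torch.zeros(8, 1, 32, 32), torch.zeros(8, 1, 32, 32))

    # What dump_data_loader records for the train loader (simplified):
    init_data = {
        "train_dataset": dataset,
        "train_loader_kwargs": {"batch_size": 2, "shuffle": True, "num_workers": 0},
    }

    # What load_data_loader reconstructs from that record:
    loader = DataLoader(init_data["train_dataset"], **init_data["train_loader_kwargs"])
    loader.shuffle = init_data["train_loader_kwargs"].get("shuffle", False)  # the attribute DefaultTrainer expects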