torch_em.trainer.default_trainer
1from __future__ import annotations 2 3import os 4import time 5import inspect 6import warnings 7import contextlib 8from tqdm import tqdm 9from functools import partial 10from datetime import datetime 11from collections import OrderedDict 12from importlib import import_module 13from typing import Any, Callable, Dict, Optional, Union, Literal 14 15import numpy as np 16 17import torch 18 19from .wandb_logger import WandbLogger 20from .tensorboard_logger import TensorboardLogger 21from ..util import auto_compile, get_constructor_arguments, is_compiled 22 23 24class DefaultTrainer: 25 """Trainer class for training a segmentation network. 26 27 The trainer class implements the core logic for training a network with pytorch. 28 It implements a training loop to run training and validation, which is started with `fit`. 29 The checkpoints and logs from the training run will be saved in the current working directory, 30 or in the directory specifified by `save_root`. Training can be continued from a checkpoint 31 by passing it's location to the `load_from_checkpoint` argument of `fit`. 32 33 A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`. 34 Alternatively, the trainer class can also be instantiated as in this example: 35 ```python 36 import torch 37 from torch_em.loss import DiceLoss 38 from torch_em.model import UNet2d 39 from torch_em.data.datasets.light_microscopy import get_dsb_loader 40 from torch_em.trainer import DefaultTrainer 41 42 # The training data will be downloaded to this location. 43 data_root = "/path/to/save/the/training/data" 44 patch_shape = (256, 256) 45 46 # Create the model and optimizer. 
47 model = UNet2d(in_channels=1, out_channels=1) 48 optimizer = torch.optim.AdamW(model.parameters()) 49 50 trainer = DefaultTrainer( 51 name="unet-training", 52 train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"), 53 val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"), 54 model=model, 55 loss=DiceLoss(), # The loss function. 56 optimizer=optimizer, 57 metric=DiceLoss(), # The metric. The trainer expects smaller values to represent better results. 58 device="cuda", # The device to use for training. 59 ) 60 trainer.fit(iterations=int(2.5e4)) # Train for 25.000 iterations. 61 ``` 62 63 Args: 64 name: The name of the checkpoint that will be created by the trainer. 65 train_loader: The data loader containing the training data. 66 val_loader: The data loader containing the validation data. 67 model: The model to train. 68 loss: The loss function for training. 69 optimizer: The optimizer. 70 metric: The metric for validation. 71 device: The torch device to use for training. If None, will use a GPU if available. 72 lr_scheduler: The learning rate scheduler. 73 log_image_interval: The interval for saving images during logging, in training iterations. 74 mixed_precision: Whether to train with mixed precision. 75 early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used. 76 logger: The logger class. Will be instantiated for logging. 77 By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`. 78 logger_kwargs: The keyword arguments for the logger class. 79 id_: Unique identifier for the trainer. If None then `name` will be used. 80 save_root: The root folder for saving the checkpoint and logs. 81 compile_model: Whether to compile the model before training. 82 rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details. 
83 """ 84 def __init__( 85 self, 86 name: Optional[str], 87 train_loader: torch.utils.data.DataLoader, 88 val_loader: torch.utils.data.DataLoader, 89 model: torch.nn.Module, 90 loss: torch.nn.Module, 91 optimizer: torch.optim.Optimizer, 92 metric: Callable, 93 device: Union[str, torch.device], 94 lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, 95 log_image_interval: int = 100, 96 mixed_precision: bool = True, 97 early_stopping: Optional[int] = None, 98 logger=TensorboardLogger, 99 logger_kwargs: Optional[Dict[str, Any]] = None, 100 id_: Optional[str] = None, 101 save_root: Optional[str] = None, 102 compile_model: Optional[Union[bool, str]] = None, 103 rank: Optional[int] = None, 104 ): 105 if name is None and not issubclass(logger, WandbLogger): 106 raise TypeError("Name cannot be None if not using the WandbLogger") 107 108 self._generate_name = name is None 109 self.name = name 110 self.id_ = id_ or name 111 self.train_loader = train_loader 112 self.val_loader = val_loader 113 self.model = model 114 self.loss = loss 115 self.optimizer = optimizer 116 self.metric = metric 117 self.device = torch.device(device) 118 self.lr_scheduler = lr_scheduler 119 self.log_image_interval = log_image_interval 120 self.save_root = save_root 121 self.compile_model = compile_model 122 self.rank = rank 123 124 self._iteration = 0 125 self._epoch = 0 126 self._best_epoch = 0 127 128 self.mixed_precision = mixed_precision 129 self.early_stopping = early_stopping 130 self.train_time = 0.0 131 132 if mixed_precision: 133 self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda") 134 else: 135 self.scaler = None 136 137 self.logger_class = logger 138 self.logger_kwargs = logger_kwargs 139 self.log_image_interval = log_image_interval 140 141 @property 142 def checkpoint_folder(self): 143 assert self.id_ is not None # Because the logger may generate and set trainer.id on logger.__init__. 
144 # Save_root enables saving the checkpoints somewhere else than in the local folder. 145 # This is handy for filesystems with limited space, where saving the checkpoints 146 # and log files can lead to running out of space. 147 save_root = getattr(self, "save_root", None) 148 return os.path.join("./checkpoints", self.id_) if save_root is None else\ 149 os.path.join(save_root, "./checkpoints", self.id_) 150 151 @property 152 def iteration(self): 153 return self._iteration 154 155 @property 156 def epoch(self): 157 return self._epoch 158 159 class Deserializer: 160 """Determines how to deserialize the trainer kwargs from serialized 'init_data'. 161 162 Examples: 163 To extend the initialization process you can inherite from this Deserializer in an inherited Trainer class. 164 Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already. 165 166 This example adds `the_answer` kwarg, which requires 'calculations' upon initialization: 167 >>> class MyTrainer(DefaultTrainer): 168 >>> def __init__(self, *args, the_answer: int, **kwargs): 169 >>> super().__init__(*args, **kwargs) 170 >>> self.the_answer = the_answer # this allows the default Serializer to save the new kwarg, 171 >>> # see DefaultTrainer.Serializer 172 >>> 173 >>> class Deserializer(DefaultTrainer.Deserializer): 174 >>> def load_the_answer(self): 175 >>> generic_answer = self.init_data["the_answer"] 176 >>> # (device dependent) special deserialization 177 >>> if self.trainer_kwargs["device"].type == "cpu": # accessing previously deserialized kwarg 178 >>> self.trainer_kwargs["the_answer"] = generic_answer + 1 179 >>> else: 180 >>> self.trainer_kwargs["the_answer"] = generic_answer * 2 181 182 Args: 183 init_data: The initialization data of the trainer. 184 save_path: The path where the checkpoint was saved. 185 device: The device. 
186 """ 187 188 def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]): 189 self.init_data = init_data 190 self.save_path = save_path 191 # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'. 192 self.trainer_kwargs: Dict[str, Any] = dict( 193 device=torch.device(self.init_data["device"]) if device is None else torch.device(device) 194 ) 195 196 def load(self, kwarg_name: str, optional): 197 """@private 198 """ 199 # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name' 200 if kwarg_name == "device": 201 pass # deserialized in __init__ 202 elif kwarg_name.endswith("_loader"): 203 self.load_data_loader(kwarg_name, optional) 204 else: 205 load = getattr(self, f"load_{kwarg_name}", self.load_generic) 206 load(kwarg_name, optional=optional) 207 208 def load_data_loader(self, loader_name, optional) -> None: 209 """@private 210 """ 211 ds = self.init_data.get(loader_name.replace("_loader", "_dataset")) 212 if ds is None and optional: 213 return 214 215 loader_kwargs = self.init_data[f"{loader_name}_kwargs"] 216 loader = torch.utils.data.DataLoader(ds, **loader_kwargs) 217 # monkey patch shuffle loader_name to the loader 218 loader.shuffle = loader_kwargs.get("shuffle", False) 219 self.trainer_kwargs[loader_name] = loader 220 221 def load_generic( 222 self, 223 kwarg_name: str, 224 *dynamic_args: Dict, 225 optional: bool, 226 only_class: bool = False, 227 dynamic_kwargs: Optional[Dict[str, Any]] = None, 228 ) -> None: 229 """@private 230 """ 231 if kwarg_name in self.init_data: 232 self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name] 233 return 234 235 this_cls = self.init_data.get(f"{kwarg_name}_class", None) 236 if this_cls is None: 237 if optional: 238 return 239 else: 240 raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}") 241 242 assert isinstance(this_cls, str), this_cls 243 assert "." 
in this_cls, this_cls 244 cls_p, cls_m = this_cls.rsplit(".", 1) 245 this_cls = getattr(import_module(cls_p), cls_m) 246 if only_class: 247 self.trainer_kwargs[kwarg_name] = this_cls 248 else: 249 self.trainer_kwargs[kwarg_name] = this_cls( 250 *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {}) 251 ) 252 253 def load_name(self, kwarg_name: str, optional: bool): 254 """@private 255 """ 256 self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1] 257 258 def load_optimizer(self, kwarg_name: str, optional: bool): 259 """@private 260 """ 261 self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional) 262 263 def load_lr_scheduler(self, kwarg_name: str, optional: bool): 264 """@private 265 """ 266 self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional) 267 268 # todo: remove and rename kwarg 'logger' to 'logger_class' 269 def load_logger(self, kwarg_name: str, optional: bool): 270 """@private 271 """ 272 assert kwarg_name == "logger" 273 self.load_generic("logger", optional=optional, only_class=True) 274 275 @staticmethod 276 def _get_save_dict(save_path, device): 277 if not os.path.exists(save_path): 278 raise ValueError(f"Cannot find checkpoint {save_path}") 279 return torch.load(save_path, map_location=device, weights_only=False) 280 281 @classmethod 282 def from_checkpoint( 283 cls, 284 checkpoint_folder: Union[os.PathLike, str], 285 name: Literal["best", "latest"] = "best", 286 device: Optional[Union[str, torch.device]] = None, 287 ): 288 """@private 289 """ 290 save_path = os.path.join(checkpoint_folder, f"{name}.pt") 291 # make sure the correct device is set if we don't have access to CUDA 292 if not torch.cuda.is_available(): 293 device = "cpu" 294 save_dict = cls._get_save_dict(save_path, device) 295 deserializer = cls.Deserializer(save_dict["init"], save_path, device) 296 297 has_kwargs = False 298 deserialized = [] 299 for name, 
parameter in inspect.signature(cls).parameters.items(): 300 if name == "kwargs": 301 has_kwargs = True 302 continue 303 deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty) 304 deserialized.append(name) 305 306 # to deserialze kwargs we can't rely on inspecting the signature, so we 307 # go through the remaning kwarg names in init data instead 308 if has_kwargs: 309 kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized)) 310 for name in kwarg_names: 311 if name.endswith("_kwargs"): 312 continue 313 elif name.endswith("_dataset"): 314 deserializer.load(name.replace("dataset", "loader"), optional=False) 315 elif name.endswith("_class"): 316 deserializer.load(name.replace("_class", ""), optional=False) 317 else: 318 deserializer.load(name, optional=False) 319 320 trainer = cls(**deserializer.trainer_kwargs) 321 trainer._initialize(0, save_dict) 322 trainer._is_initialized = True 323 return trainer 324 325 class Serializer: 326 """Implements how to serialize trainer kwargs from a trainer instance. 327 328 Examples: 329 To extend the serialization process you can inherite from this Serializer in a derived Trainer class. 330 Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()` 331 called by the `dump()` method when appropriate cover most cases already. 332 333 This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a 334 'the_answer' attribute: 335 >>> class MyTrainer(DefaultTrainer): 336 >>> def __init__(self, *args, the_answer: int, **kwargs): 337 >>> super().__init__(*args, **kwargs) 338 >>> # self.the_answer = the_answer # this would allow the default Serializer to save the new kwarg, 339 >>> # but let's make things more interesting... 
340 >>> self.the = the_answer // 10 341 >>> self.answer = the_answer % 10 342 >>> 343 >>> class Serializer(DefaultTrainer.Serializer): 344 >>> trainer: MyTrainer 345 >>> def dump_the_answer(self, kwarg_name: str) -> None: # custom dump method for 'the_answer' kwarg 346 >>> assert kwarg_name == "the_answer" 347 >>> # populate self.init_data with the serialized data required by Deserializer 348 >>> # to restore the trainer kwargs 349 >>> self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer 350 351 This example with both Serializer and Deserializer adds `the_answer` kwarg, 352 while saving it in two separate entries 'the' and 'answer' 353 >>> class MyTrainer(DefaultTrainer): 354 >>> def __init__(self, *args, the_answer: int, **kwargs): 355 >>> super().__init__(*args, **kwargs) 356 >>> self.the_answer = the_answer 357 >>> 358 >>> class Serializer(DefaultTrainer.Serializer): 359 >>> trainer: MyTrainer 360 >>> def dump_the_answer(self, kwarg_name: str): 361 >>> assert kwarg_name == "the_answer" 362 >>> self.init_data.update({ 363 >>> "the": self.trainer.the_answer // 10, 364 >>> "answer": self.trainer.the_answer % 10 365 >>> }) 366 >>> 367 >>> class Deserializer(DefaultTrainer.Deserializer): 368 >>> def load_the_answer(self, kwarg_name: str, optional: bool): 369 >>> assert kwarg_name == "the_answer" 370 >>> # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name' 371 >>> self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"] 372 373 Args: 374 trainer: The trainer instance. 
375 """ 376 377 def __init__(self, trainer: DefaultTrainer): 378 self.trainer = trainer 379 self.init_data = {} # to be populated during serialization process 380 381 def dump(self, kwarg_name: str) -> None: 382 """@private 383 """ 384 dumper = getattr(self, f"dump_{kwarg_name}", None) 385 if dumper is not None: 386 dumper(kwarg_name) 387 elif kwarg_name.endswith("_loader"): 388 self.dump_data_loader(kwarg_name) 389 elif kwarg_name.endswith("_class"): 390 self.dump_generic_class(kwarg_name) 391 elif not hasattr(self.trainer, kwarg_name): 392 raise AttributeError( 393 f"{self.trainer.__class__} missing attribute '{kwarg_name}' " 394 f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()" 395 ) 396 else: 397 assert hasattr(self.trainer, kwarg_name) 398 obj = getattr(self.trainer, kwarg_name) 399 if obj is None or type(obj) in ( 400 bool, 401 bytearray, 402 bytes, 403 dict, 404 float, 405 frozenset, 406 int, 407 list, 408 set, 409 str, 410 tuple, 411 ): 412 self.dump_generic_builtin(kwarg_name) 413 else: 414 self.dump_generic_instance(kwarg_name) 415 416 def dump_generic_builtin(self, kwarg_name: str) -> None: 417 """@private 418 """ 419 assert hasattr(self.trainer, kwarg_name) 420 self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name) 421 422 def dump_generic_class(self, kwarg_name: str) -> None: 423 """@private 424 """ 425 assert hasattr(self.trainer, kwarg_name) 426 assert kwarg_name.endswith("_class") 427 obj = getattr(self.trainer, kwarg_name) 428 self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}" 429 430 def dump_generic_instance(self, kwarg_name: str) -> None: 431 """@private 432 """ 433 assert hasattr(self.trainer, kwarg_name) 434 instance = getattr(self.trainer, kwarg_name) 435 self.init_data.update( 436 { 437 f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}", 438 f"{kwarg_name}_kwargs": get_constructor_arguments(instance), 439 } 440 ) 441 442 def 
dump_device(self, kwarg_name: str): 443 """@private 444 """ 445 assert hasattr(self.trainer, kwarg_name) 446 self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name)) 447 448 def dump_data_loader(self, kwarg_name: str) -> None: 449 """@private 450 """ 451 assert hasattr(self.trainer, kwarg_name) 452 loader = getattr(self.trainer, kwarg_name) 453 if loader is None: 454 return 455 self.init_data.update( 456 { 457 f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset, 458 f"{kwarg_name}_kwargs": get_constructor_arguments(loader), 459 } 460 ) 461 462 def dump_logger(self, kwarg_name: str): # todo: remove and rename kwarg 'logger' to 'logger_class' 463 """@private 464 """ 465 self.dump_generic_class(f"{kwarg_name}_class") 466 467 def dump_model(self, kwarg_name: str): 468 """@private 469 """ 470 if is_compiled(self.trainer.model): 471 self.init_data.update( 472 {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs} 473 ) 474 else: 475 self.dump_generic_instance("model") 476 477 def _build_init(self) -> Dict[str, Any]: 478 serializer = self.Serializer(self) 479 for name in inspect.signature(self.__class__).parameters: 480 # special rules to serialize kwargs 481 # if a trainer class inherits from DefaultTrainer and has **kwargs 482 # they need to be saved in self._kwargs 483 if name == "kwargs": 484 if not hasattr(self, "_kwargs"): 485 msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. 
" +\ 486 "Please add self._kwargs to its __init__ function" 487 raise RuntimeError(msg) 488 kwargs = getattr(self, "_kwargs") 489 for kwarg_name in kwargs: 490 serializer.dump(kwarg_name) 491 continue 492 serializer.dump(name) 493 494 return serializer.init_data 495 496 def _initialize(self, iterations, load_from_checkpoint, epochs=None): 497 assert self.train_loader is not None 498 assert self.val_loader is not None 499 assert self.model is not None 500 assert self.loss is not None 501 assert self.optimizer is not None 502 assert self.metric is not None 503 assert self.device is not None 504 505 if load_from_checkpoint is not None: 506 self.load_checkpoint(load_from_checkpoint) 507 508 if sum((iterations is not None, epochs is not None)) != 1: 509 raise ValueError( 510 "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer." 511 f"You have passed 'iterations'={iterations} and 'epochs'={epochs}" 512 ) 513 514 if epochs is None: 515 epochs = int(np.ceil(float(iterations) / len(self.train_loader))) 516 else: 517 iterations = epochs * len(self.train_loader) 518 519 self.max_iteration = self._iteration + iterations 520 self.max_epoch = self._epoch + epochs 521 522 if not getattr(self, "_is_initialized", False): 523 # check if we compile the model (only supported by pytorch 2) 524 # to enable (de)serialization of compiled models, we keep track of the model class and kwargs 525 if is_compiled(self.model): 526 warnings.warn( 527 "You have passed a compiled model to the trainer." 528 "It will not be possible to (de)serialize the trainer with it." 529 "If you want to be able to do this please pass the normal model." 
530 "It can be automatically compiled by setting 'compile_model' to True" 531 ) 532 self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}" 533 self._model_kwargs = get_constructor_arguments(self.model) 534 self.model = auto_compile(self.model, self.compile_model) 535 536 self.model.to(self.device) 537 self.loss.to(self.device) 538 539 # this saves all the information that is necessary 540 # to fully load the trainer from the checkpoint 541 self.init_data = self._build_init() 542 543 if self.logger_class is None: 544 self.logger = None 545 else: 546 # may set self.name if self.name is None 547 save_root = getattr(self, "save_root", None) 548 try: 549 self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {})) 550 except (PermissionError, RuntimeError): 551 warnings.warn( 552 f"The checkpoint folder at {self.checkpoint_folder} could not be created." 553 "The most likely reason for this is that you copied the checkpoint somewhere else," 554 "so we skip this error to enable loading the model from this checkpoint." 555 ) 556 557 try: 558 os.makedirs(self.checkpoint_folder, exist_ok=True) 559 except PermissionError: 560 warnings.warn( 561 f"The checkpoint folder at {self.checkpoint_folder} could not be created." 562 "The most likely reason for this is that you copied the checkpoint somewhere else," 563 "so we skip this error to enable loading the model from this checkpoint." 
564 ) 565 pass 566 567 best_metric = np.inf 568 return best_metric 569 570 def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict): 571 """@private 572 """ 573 save_path = os.path.join(self.checkpoint_folder, f"{name}.pt") 574 extra_init_dict = extra_save_dict.pop("init", {}) 575 save_dict = { 576 "iteration": self._iteration, 577 "epoch": self._epoch, 578 "best_epoch": self._best_epoch, 579 "best_metric": best_metric, 580 "current_metric": current_metric, 581 "model_state": self.model.state_dict(), 582 "optimizer_state": self.optimizer.state_dict(), 583 "init": self.init_data | extra_init_dict, 584 "train_time": train_time, 585 "timestamp": datetime.now().strftime("%d-%m-%Y (%H:%M:%S)"), 586 } 587 save_dict.update(**extra_save_dict) 588 if self.scaler is not None: 589 save_dict.update({"scaler_state": self.scaler.state_dict()}) 590 if self.lr_scheduler is not None: 591 save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()}) 592 593 rank = getattr(self, "rank", None) 594 if rank is None or rank == 0: 595 torch.save(save_dict, save_path) 596 597 def load_checkpoint(self, checkpoint="best"): 598 """@private 599 """ 600 if isinstance(checkpoint, str): 601 save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt") 602 if not os.path.exists(save_path): 603 warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.") 604 return 605 save_dict = torch.load(save_path, weights_only=False) 606 elif isinstance(checkpoint, dict): 607 save_dict = checkpoint 608 else: 609 raise RuntimeError 610 611 self._iteration = save_dict["iteration"] 612 self._epoch = save_dict["epoch"] 613 self._best_epoch = save_dict["best_epoch"] 614 self.best_metric = save_dict["best_metric"] 615 self.current_metric = save_dict["current_metric"] 616 self.train_time = save_dict.get("train_time", 0.0) 617 618 model_state = save_dict["model_state"] 619 # to enable loading compiled models 620 compiled_prefix = "_orig_mod." 
621 model_state = OrderedDict( 622 [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()] 623 ) 624 self.model.load_state_dict(model_state) 625 # we need to send the network to the device before loading the optimizer state! 626 self.model.to(self.device) 627 628 self.optimizer.load_state_dict(save_dict["optimizer_state"]) 629 if self.scaler is not None: 630 self.scaler.load_state_dict(save_dict["scaler_state"]) 631 if self.lr_scheduler is not None: 632 self.lr_scheduler.load_state_dict(save_dict["scheduler_state"]) 633 634 return save_dict 635 636 def _verify_if_training_completed(self, checkpoint="latest"): 637 save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt") 638 save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None 639 if save_dict and self.max_iteration == save_dict.get("iteration"): 640 return True 641 return False 642 643 def fit( 644 self, 645 iterations: Optional[int] = None, 646 load_from_checkpoint: Optional[Union[os.PathLike, str]] = None, 647 epochs: Optional[int] = None, 648 save_every_kth_epoch: Optional[int] = None, 649 progress=None, 650 overwrite_training: bool = True, 651 ): 652 """Run neural network training. 653 654 Exactly one of 'iterations' or 'epochs' has to be passed. 655 656 Args: 657 iterations: How long to train, specified in iterations. 658 load_from_checkpoint: Path to a checkpoint from where training should be continued . 659 epochs: How long to train, specified in epochs. 660 save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file. 661 The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. 662 progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface. 663 overwrite_training: Whether to overwrite existing checkpoints in the save directory. 
664 """ 665 best_metric = self._initialize(iterations, load_from_checkpoint, epochs) 666 667 if not overwrite_training: 668 if load_from_checkpoint is not None: 669 raise ValueError( 670 "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time." 671 ) 672 673 if self._verify_if_training_completed(): 674 print( 675 f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs " 676 "and 'overwrite_training' is set to 'False'." 677 ) 678 print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.") 679 return 680 681 print( 682 "Start fitting for", 683 self.max_iteration - self._iteration, 684 "iterations / ", 685 self.max_epoch - self._epoch, 686 "epochs", 687 ) 688 print("with", len(self.train_loader), "iterations per epoch") 689 690 if self.mixed_precision: 691 train_epoch = self._train_epoch_mixed 692 validate = self._validate_mixed 693 print("Training with mixed precision") 694 else: 695 train_epoch = self._train_epoch 696 validate = self._validate 697 print("Training with single precision") 698 699 total_iterations = epochs * len(self.train_loader) if iterations is None else iterations 700 if progress is None: 701 progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True) 702 else: 703 progress.total = total_iterations 704 progress.set_description(f"Epoch {self._epoch}") 705 706 msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f" 707 train_epochs = self.max_epoch - self._epoch 708 t_start = time.time() 709 for epoch in range(train_epochs): 710 711 # Ensure data is shuffled differently at each epoch. 
712 try: 713 self.train_loader.sampler.set_epoch(epoch) 714 except AttributeError: 715 pass 716 717 # Run training and validation for this epoch 718 t_per_iter = train_epoch(progress) 719 current_metric = validate() 720 721 # perform all the post-epoch steps: 722 723 # apply the learning rate scheduler 724 if self.lr_scheduler is not None: 725 self.lr_scheduler.step(current_metric) 726 727 # how long did we train in total? 728 total_train_time = (time.time() - t_start) + self.train_time 729 730 # save this checkpoint as the new best checkpoint if 731 # it has the best overall validation metric 732 if current_metric < best_metric: 733 best_metric = current_metric 734 self._best_epoch = self._epoch 735 self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time) 736 737 # save this checkpoint as the latest checkpoint 738 self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time) 739 740 # if we save after every k-th epoch then check if we need to save now 741 if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0: 742 self.save_checkpoint( 743 f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time 744 ) 745 746 # if early stopping has been specified then check if the stopping condition is met 747 if self.early_stopping is not None: 748 epochs_since_best = self._epoch - self._best_epoch 749 if epochs_since_best > self.early_stopping: 750 print("Stopping training because there has been no improvement for", self.early_stopping, "epochs") 751 break 752 753 self._epoch += 1 754 progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True) 755 756 print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.") 757 print(f"The best epoch is number {self._best_epoch}.") 758 759 if self._generate_name: 760 self.name = None 761 762 # Update the train time 763 self.train_time = total_train_time 764 765 
# TODO save the model to wandb if we have the wandb logger 766 if isinstance(self.logger, WandbLogger): 767 self.logger.get_wandb().finish() 768 769 def _backprop(self, loss): 770 loss.backward() 771 self.optimizer.step() 772 773 def _backprop_mixed(self, loss): 774 self.scaler.scale(loss).backward() 775 self.scaler.step(self.optimizer) 776 self.scaler.update() 777 778 def _train_epoch(self, progress): 779 return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop) 780 781 def _train_epoch_mixed(self, progress): 782 return self._train_epoch_impl( 783 progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"), 784 self._backprop_mixed 785 ) 786 787 def _forward_and_loss(self, x, y): 788 pred = self.model(x) 789 if self._iteration % self.log_image_interval == 0: 790 if pred.requires_grad: 791 pred.retain_grad() 792 793 loss = self.loss(pred, y) 794 return pred, loss 795 796 def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]): 797 self.model.train() 798 799 n_iter = 0 800 t_per_iter = time.time() 801 for x, y in self.train_loader: 802 x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) 803 804 self.optimizer.zero_grad() 805 806 with forward_context(): 807 pred, loss = self._forward_and_loss(x, y) 808 809 backprop(loss) 810 811 lr = [pm["lr"] for pm in self.optimizer.param_groups][0] 812 if self.logger is not None: 813 self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True) 814 815 self._iteration += 1 816 n_iter += 1 817 if self._iteration >= self.max_iteration: 818 break 819 progress.update(1) 820 821 t_per_iter = (time.time() - t_per_iter) / n_iter 822 return t_per_iter 823 824 def _validate(self): 825 return self._validate_impl(contextlib.nullcontext) 826 827 def _validate_mixed(self): 828 return self._validate_impl( 829 partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda") 830 ) 831 
832 def _validate_impl(self, forward_context): 833 self.model.eval() 834 835 metric_val = 0.0 836 loss_val = 0.0 837 838 with torch.no_grad(): 839 for x, y in self.val_loader: 840 x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) 841 with forward_context(): 842 pred, loss = self._forward_and_loss(x, y) 843 metric = self.metric(pred, y) 844 845 loss_val += loss.item() 846 metric_val += metric.item() 847 848 metric_val /= len(self.val_loader) 849 loss_val /= len(self.val_loader) 850 if self.logger is not None: 851 self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred) 852 return metric_val
class DefaultTrainer:
    """Trainer class for training a segmentation network.

    The trainer class implements the core logic for training a network with pytorch.
    It implements a training loop to run training and validation, which is started with `fit`.
    The checkpoints and logs from the training run will be saved in the current working directory,
    or in the directory specified by `save_root`. Training can be continued from a checkpoint
    by passing its location to the `load_from_checkpoint` argument of `fit`.

    A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
    Alternatively, the trainer class can also be instantiated as in this example:
    ```python
    import torch
    from torch_em.loss import DiceLoss
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader
    from torch_em.trainer import DefaultTrainer

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)

    # Create the model and optimizer.
    model = UNet2d(in_channels=1, out_channels=1)
    optimizer = torch.optim.AdamW(model.parameters())

    trainer = DefaultTrainer(
        name="unet-training",
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
        model=model,
        loss=DiceLoss(),  # The loss function.
        optimizer=optimizer,
        metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
        device="cuda",  # The device to use for training.
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        model: The model to train.
        loss: The loss function for training.
        optimizer: The optimizer.
        metric: The metric for validation.
        device: The torch device to use for training. If None, will use a GPU if available.
        lr_scheduler: The learning rate scheduler.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
    """
    def __init__(
        self,
        name: Optional[str],
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        metric: Callable,
        device: Union[str, torch.device],
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        log_image_interval: int = 100,
        mixed_precision: bool = True,
        early_stopping: Optional[int] = None,
        logger=TensorboardLogger,
        logger_kwargs: Optional[Dict[str, Any]] = None,
        id_: Optional[str] = None,
        save_root: Optional[str] = None,
        compile_model: Optional[Union[bool, str]] = None,
        rank: Optional[int] = None,
    ):
        # 'name' may only be None for the WandbLogger, which can generate a run name itself.
        if name is None and not issubclass(logger, WandbLogger):
            raise TypeError("Name cannot be None if not using the WandbLogger")

        self._generate_name = name is None
        self.name = name
        self.id_ = id_ or name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metric
        self.device = torch.device(device)
        self.lr_scheduler = lr_scheduler
        self.log_image_interval = log_image_interval
        self.save_root = save_root
        self.compile_model = compile_model
        self.rank = rank

        self._iteration = 0
        self._epoch = 0
        self._best_epoch = 0

        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.train_time = 0.0

        if mixed_precision:
            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
        else:
            self.scaler = None

        self.logger_class = logger
        self.logger_kwargs = logger_kwargs
        # NOTE(review): log_image_interval is assigned twice in this constructor;
        # the second assignment is redundant.
        self.log_image_interval = log_image_interval

    @property
    def checkpoint_folder(self):
        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
        # Save_root enables saving the checkpoints somewhere else than in the local folder.
        # This is handy for filesystems with limited space, where saving the checkpoints
        # and log files can lead to running out of space.
        save_root = getattr(self, "save_root", None)
        return os.path.join("./checkpoints", self.id_) if save_root is None else\
            os.path.join(save_root, "./checkpoints", self.id_)

    @property
    def iteration(self):
        return self._iteration

    @property
    def epoch(self):
        return self._epoch

    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

        Examples:
            To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>         # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

        Args:
            init_data: The initialization data of the trainer.
            save_path: The path where the checkpoint was saved.
            device: The device.
        """

        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """@private
            """
            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                # Dispatch to a dedicated load_<kwarg_name> method if one exists, else fall back to load_generic.
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            """@private
            """
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch shuffle loader_name to the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args: Dict,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            """@private
            """
            # Plain values are stored directly in init_data.
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            # Otherwise the object was serialized as '<kwarg>_class' (+ optional '<kwarg>_kwargs').
            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The trainer name is recovered from the checkpoint folder name.
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The optimizer needs the (already deserialized) model parameters as constructor argument.
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            """@private
            """
            # The scheduler needs the (already deserialized) optimizer as constructor argument.
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # todo: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            """@private
            """
            assert kwarg_name == "logger"
            self.load_generic("logger", optional=optional, only_class=True)

    @staticmethod
    def _get_save_dict(save_path, device):
        # Load the full (trusted) checkpoint dict; weights_only=False is required to restore init data.
        if not os.path.exists(save_path):
            raise ValueError(f"Cannot find checkpoint {save_path}")
        return torch.load(save_path, map_location=device, weights_only=False)

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_folder: Union[os.PathLike, str],
        name: Literal["best", "latest"] = "best",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """@private
        """
        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
        # make sure the correct device is set if we don't have access to CUDA
        if not torch.cuda.is_available():
            device = "cpu"
        save_dict = cls._get_save_dict(save_path, device)
        deserializer = cls.Deserializer(save_dict["init"], save_path, device)

        has_kwargs = False
        deserialized = []
        for name, parameter in inspect.signature(cls).parameters.items():
            if name == "kwargs":
                has_kwargs = True
                continue
            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
            deserialized.append(name)

        # to deserialize kwargs we can't rely on inspecting the signature, so we
        # go through the remaining kwarg names in init data instead
        if has_kwargs:
            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
            for name in kwarg_names:
                if name.endswith("_kwargs"):
                    continue
                elif name.endswith("_dataset"):
                    deserializer.load(name.replace("dataset", "loader"), optional=False)
                elif name.endswith("_class"):
                    deserializer.load(name.replace("_class", ""), optional=False)
                else:
                    deserializer.load(name, optional=False)

        trainer = cls(**deserializer.trainer_kwargs)
        trainer._initialize(0, save_dict)
        trainer._is_initialized = True
        return trainer

    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance.

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
            called by the `dump()` method when appropriate cover most cases already.

            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
            'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new
            >>>         # kwarg, but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer'
            >>>             assert kwarg_name == "the_answer"
            >>>             # populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer'
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

        Args:
            trainer: The trainer instance.
        """

        def __init__(self, trainer: DefaultTrainer):
            self.trainer = trainer
            self.init_data = {}  # to be populated during serialization process

        def dump(self, kwarg_name: str) -> None:
            """@private
            """
            # Dispatch order: custom dump_<name> method, data loaders, classes,
            # then generic handling by value type.
            dumper = getattr(self, f"dump_{kwarg_name}", None)
            if dumper is not None:
                dumper(kwarg_name)
            elif kwarg_name.endswith("_loader"):
                self.dump_data_loader(kwarg_name)
            elif kwarg_name.endswith("_class"):
                self.dump_generic_class(kwarg_name)
            elif not hasattr(self.trainer, kwarg_name):
                raise AttributeError(
                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
                )
            else:
                assert hasattr(self.trainer, kwarg_name)
                obj = getattr(self.trainer, kwarg_name)
                if obj is None or type(obj) in (
                    bool,
                    bytearray,
                    bytes,
                    dict,
                    float,
                    frozenset,
                    int,
                    list,
                    set,
                    str,
                    tuple,
                ):
                    self.dump_generic_builtin(kwarg_name)
                else:
                    self.dump_generic_instance(kwarg_name)

        def dump_generic_builtin(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

        def dump_generic_class(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            assert kwarg_name.endswith("_class")
            obj = getattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

        def dump_generic_instance(self, kwarg_name: str) -> None:
            """@private
            """
            # Instances are serialized as fully qualified class name + constructor arguments.
            assert hasattr(self.trainer, kwarg_name)
            instance = getattr(self.trainer, kwarg_name)
            self.init_data.update(
                {
                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
                }
            )

        def dump_device(self, kwarg_name: str):
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

        def dump_data_loader(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            loader = getattr(self.trainer, kwarg_name)
            if loader is None:
                return
            # The dataset is stored directly; the loader is restored from its constructor arguments.
            self.init_data.update(
                {
                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
                }
            )

        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
            """@private
            """
            self.dump_generic_class(f"{kwarg_name}_class")

        def dump_model(self, kwarg_name: str):
            """@private
            """
            # For compiled models we dump the original class/kwargs recorded in _initialize.
            if is_compiled(self.trainer.model):
                self.init_data.update(
                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
                )
            else:
                self.dump_generic_instance("model")

    def _build_init(self) -> Dict[str, Any]:
        serializer = self.Serializer(self)
        for name in inspect.signature(self.__class__).parameters:
            # special rules to serialize kwargs
            # if a trainer class inherits from DefaultTrainer and has **kwargs
            # they need to be saved in self._kwargs
            if name == "kwargs":
                if not hasattr(self, "_kwargs"):
                    msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. " +\
                          "Please add self._kwargs to its __init__ function"
                    raise RuntimeError(msg)
                kwargs = getattr(self, "_kwargs")
                for kwarg_name in kwargs:
                    serializer.dump(kwarg_name)
                continue
            serializer.dump(name)

        return serializer.init_data

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        assert self.train_loader is not None
        assert self.val_loader is not None
        assert self.model is not None
        assert self.loss is not None
        assert self.optimizer is not None
        assert self.metric is not None
        assert self.device is not None

        if load_from_checkpoint is not None:
            self.load_checkpoint(load_from_checkpoint)

        if sum((iterations is not None, epochs is not None)) != 1:
            raise ValueError(
                "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer."
                f"You have passed 'iterations'={iterations} and 'epochs'={epochs}"
            )

        if epochs is None:
            epochs = int(np.ceil(float(iterations) / len(self.train_loader)))
        else:
            iterations = epochs * len(self.train_loader)

        self.max_iteration = self._iteration + iterations
        self.max_epoch = self._epoch + epochs

        if not getattr(self, "_is_initialized", False):
            # check if we compile the model (only supported by pytorch 2)
            # to enable (de)serialization of compiled models, we keep track of the model class and kwargs
            if is_compiled(self.model):
                warnings.warn(
                    "You have passed a compiled model to the trainer."
                    "It will not be possible to (de)serialize the trainer with it."
                    "If you want to be able to do this please pass the normal model."
                    "It can be automatically compiled by setting 'compile_model' to True"
                )
            self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}"
            self._model_kwargs = get_constructor_arguments(self.model)
            self.model = auto_compile(self.model, self.compile_model)

            self.model.to(self.device)
            self.loss.to(self.device)

            # this saves all the information that is necessary
            # to fully load the trainer from the checkpoint
            self.init_data = self._build_init()

            if self.logger_class is None:
                self.logger = None
            else:
                # may set self.name if self.name is None
                save_root = getattr(self, "save_root", None)
                try:
                    self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {}))
                except (PermissionError, RuntimeError):
                    warnings.warn(
                        f"The checkpoint folder at {self.checkpoint_folder} could not be created."
                        "The most likely reason for this is that you copied the checkpoint somewhere else,"
                        "so we skip this error to enable loading the model from this checkpoint."
                    )

        try:
            os.makedirs(self.checkpoint_folder, exist_ok=True)
        except PermissionError:
            warnings.warn(
                f"The checkpoint folder at {self.checkpoint_folder} could not be created."
                "The most likely reason for this is that you copied the checkpoint somewhere else,"
                "so we skip this error to enable loading the model from this checkpoint."
            )
            pass

        best_metric = np.inf
        return best_metric

    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
        """@private
        """
        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
        extra_init_dict = extra_save_dict.pop("init", {})
        save_dict = {
            "iteration": self._iteration,
            "epoch": self._epoch,
            "best_epoch": self._best_epoch,
            "best_metric": best_metric,
            "current_metric": current_metric,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "init": self.init_data | extra_init_dict,
            "train_time": train_time,
            "timestamp": datetime.now().strftime("%d-%m-%Y (%H:%M:%S)"),
        }
        save_dict.update(**extra_save_dict)
        if self.scaler is not None:
            save_dict.update({"scaler_state": self.scaler.state_dict()})
        if self.lr_scheduler is not None:
            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})

        # In distributed training only rank 0 writes the checkpoint.
        rank = getattr(self, "rank", None)
        if rank is None or rank == 0:
            torch.save(save_dict, save_path)

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        if isinstance(checkpoint, str):
            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
            if not os.path.exists(save_path):
                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
                return
            save_dict = torch.load(save_path, weights_only=False)
        elif isinstance(checkpoint, dict):
            save_dict = checkpoint
        else:
            raise RuntimeError

        self._iteration = save_dict["iteration"]
        self._epoch = save_dict["epoch"]
        self._best_epoch = save_dict["best_epoch"]
        self.best_metric = save_dict["best_metric"]
        self.current_metric = save_dict["current_metric"]
        self.train_time = save_dict.get("train_time", 0.0)

        model_state = save_dict["model_state"]
        # to enable loading compiled models
        compiled_prefix = "_orig_mod."
        model_state = OrderedDict(
            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
        )
        self.model.load_state_dict(model_state)
        # we need to send the network to the device before loading the optimizer state!
        self.model.to(self.device)

        self.optimizer.load_state_dict(save_dict["optimizer_state"])
        if self.scaler is not None:
            self.scaler.load_state_dict(save_dict["scaler_state"])
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])

        return save_dict

    def _verify_if_training_completed(self, checkpoint="latest"):
        # Training counts as completed when the latest checkpoint is at max_iteration.
        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
        if save_dict and self.max_iteration == save_dict.get("iteration"):
            return True
        return False

    def fit(
        self,
        iterations: Optional[int] = None,
        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
        epochs: Optional[int] = None,
        save_every_kth_epoch: Optional[int] = None,
        progress=None,
        overwrite_training: bool = True,
    ):
        """Run neural network training.

        Exactly one of 'iterations' or 'epochs' has to be passed.

        Args:
            iterations: How long to train, specified in iterations.
            load_from_checkpoint: Path to a checkpoint from where training should be continued.
            epochs: How long to train, specified in epochs.
            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
        """
        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)

        if not overwrite_training:
            if load_from_checkpoint is not None:
                raise ValueError(
                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
                )

            if self._verify_if_training_completed():
                print(
                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
                    "and 'overwrite_training' is set to 'False'."
                )
                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
                return

        print(
            "Start fitting for",
            self.max_iteration - self._iteration,
            "iterations / ",
            self.max_epoch - self._epoch,
            "epochs",
        )
        print("with", len(self.train_loader), "iterations per epoch")

        if self.mixed_precision:
            train_epoch = self._train_epoch_mixed
            validate = self._validate_mixed
            print("Training with mixed precision")
        else:
            train_epoch = self._train_epoch
            validate = self._validate
            print("Training with single precision")

        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
        if progress is None:
            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
        else:
            progress.total = total_iterations
            progress.set_description(f"Epoch {self._epoch}")

        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
        train_epochs = self.max_epoch - self._epoch
        t_start = time.time()
        for epoch in range(train_epochs):

            # Ensure data is shuffled differently at each epoch.
            try:
                self.train_loader.sampler.set_epoch(epoch)
            except AttributeError:
                pass

            # Run training and validation for this epoch
            t_per_iter = train_epoch(progress)
            current_metric = validate()

            # perform all the post-epoch steps:

            # apply the learning rate scheduler
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(current_metric)

            # how long did we train in total?
            total_train_time = (time.time() - t_start) + self.train_time

            # save this checkpoint as the new best checkpoint if
            # it has the best overall validation metric
            if current_metric < best_metric:
                best_metric = current_metric
                self._best_epoch = self._epoch
                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)

            # save this checkpoint as the latest checkpoint
            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)

            # if we save after every k-th epoch then check if we need to save now
            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
                self.save_checkpoint(
                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
                )

            # if early stopping has been specified then check if the stopping condition is met
            if self.early_stopping is not None:
                epochs_since_best = self._epoch - self._best_epoch
                if epochs_since_best > self.early_stopping:
                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
                    break

            self._epoch += 1
            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)

        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
        print(f"The best epoch is number {self._best_epoch}.")

        if self._generate_name:
            self.name = None

        # Update the train time
        self.train_time = total_train_time

        # TODO save the model to wandb if we have the wandb logger
        if isinstance(self.logger, WandbLogger):
            self.logger.get_wandb().finish()

    def _backprop(self, loss):
        loss.backward()
        self.optimizer.step()

    def _backprop_mixed(self, loss):
        # Gradient scaling for mixed precision training.
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

    def _train_epoch(self, progress):
        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)

    def _train_epoch_mixed(self, progress):
        return self._train_epoch_impl(
            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
            self._backprop_mixed
        )

    def _forward_and_loss(self, x, y):
        pred = self.model(x)
        if self._iteration % self.log_image_interval == 0:
            if pred.requires_grad:
                # Retain the prediction's gradient so that it can be logged
                # (log_train is called with log_gradients=True).
                pred.retain_grad()

        loss = self.loss(pred, y)
        return pred, loss

    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()

            with forward_context():
                pred, loss = self._forward_and_loss(x, y)

            backprop(loss)

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate(self):
        return self._validate_impl(contextlib.nullcontext)

    def _validate_mixed(self):
        return self._validate_impl(
            partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda")
        )

    def _validate_impl(self, forward_context):
        self.model.eval()

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                    metric = self.metric(pred, y)

                loss_val += loss.item()
                metric_val += metric.item()

        # Average over the number of validation batches.
        metric_val /= len(self.val_loader)
        loss_val /= len(self.val_loader)
        if self.logger is not None:
            # x, y, pred hold the values of the last validation batch.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
        return metric_val
Trainer class for training a segmentation network.
The trainer class implements the core logic for training a network with pytorch.
It implements a training loop to run training and validation, which is started with fit.
The checkpoints and logs from the training run will be saved in the current working directory,
or in the directory specified by save_root. Training can be continued from a checkpoint
by passing its location to the load_from_checkpoint argument of fit.
A pre-configured instance of the trainer can be obtained from torch_em.default_segmentation_trainer.
Alternatively, the trainer class can also be instantiated as in this example:
import torch
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
from torch_em.trainer import DefaultTrainer
# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
# Create the model and optimizer.
model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.AdamW(model.parameters())
trainer = DefaultTrainer(
name="unet-training",
train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
model=model,
loss=DiceLoss(), # The loss function.
optimizer=optimizer,
metric=DiceLoss(), # The metric. The trainer expects smaller values to represent better results.
device="cuda", # The device to use for training.
)
trainer.fit(iterations=int(2.5e4)) # Train for 25,000 iterations.
Arguments:
- name: The name of the checkpoint that will be created by the trainer.
- train_loader: The data loader containing the training data.
- val_loader: The data loader containing the validation data.
- model: The model to train.
- loss: The loss function for training.
- optimizer: The optimizer.
- metric: The metric for validation.
- device: The torch device to use for training. If None, will use a GPU if available.
- lr_scheduler: The learning rate scheduler.
- log_image_interval: The interval for saving images during logging, in training iterations.
- mixed_precision: Whether to train with mixed precision.
- early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
- logger: The logger class. Will be instantiated for logging.
  By default uses torch_em.training.tensorboard_logger.TensorboardLogger.
- logger_kwargs: The keyword arguments for the logger class.
- id_: Unique identifier for the trainer. If None then name will be used.
- save_root: The root folder for saving the checkpoint and logs.
- compile_model: Whether to compile the model before training.
- rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
def __init__(
    self,
    name: Optional[str],
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    loss: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    metric: Callable,
    device: Union[str, torch.device],
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Initialize the trainer. See the class docstring for the parameter documentation."""
    # 'name' may only be None for the WandbLogger, which can generate a run name itself.
    if name is None and not issubclass(logger, WandbLogger):
        raise TypeError("Name cannot be None if not using the WandbLogger")

    self._generate_name = name is None
    self.name = name
    self.id_ = id_ or name
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.model = model
    self.loss = loss
    self.optimizer = optimizer
    self.metric = metric
    self.device = torch.device(device)
    self.lr_scheduler = lr_scheduler
    self.log_image_interval = log_image_interval
    self.save_root = save_root
    self.compile_model = compile_model
    self.rank = rank

    # Progress counters; updated during training.
    self._iteration = 0
    self._epoch = 0
    self._best_epoch = 0

    self.mixed_precision = mixed_precision
    self.early_stopping = early_stopping
    self.train_time = 0.0

    # Gradient scaler is only needed for mixed precision training.
    if mixed_precision:
        self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
    else:
        self.scaler = None

    self.logger_class = logger
    self.logger_kwargs = logger_kwargs
    # NOTE(review): log_image_interval is assigned twice in this constructor;
    # the second assignment is redundant.
    self.log_image_interval = log_image_interval
@property
def checkpoint_folder(self):
    """The folder where the checkpoints of this trainer are saved."""
    # The id_ may be generated and set by the logger on logger.__init__,
    # so it must exist by the time the checkpoint folder is requested.
    assert self.id_ is not None
    # save_root enables saving the checkpoints somewhere else than in the local folder.
    # This is handy for filesystems with limited space, where saving the checkpoints
    # and log files can lead to running out of space.
    save_root = getattr(self, "save_root", None)
    if save_root is None:
        return os.path.join("./checkpoints", self.id_)
    # NOTE(fix): use "checkpoints" instead of "./checkpoints" so the joined path
    # does not contain a spurious "/./" segment.
    return os.path.join(save_root, "checkpoints", self.id_)
def fit(
    self,
    iterations: Optional[int] = None,
    load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
    epochs: Optional[int] = None,
    save_every_kth_epoch: Optional[int] = None,
    progress=None,
    overwrite_training: bool = True,
):
    """Run neural network training.

    Exactly one of 'iterations' or 'epochs' has to be passed.

    Args:
        iterations: How long to train, specified in iterations.
        load_from_checkpoint: Path to a checkpoint from where training should be continued.
        epochs: How long to train, specified in epochs.
        save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
            The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
        progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
        overwrite_training: Whether to overwrite existing checkpoints in the save directory.
    """
    best_metric = self._initialize(iterations, load_from_checkpoint, epochs)

    if not overwrite_training:
        if load_from_checkpoint is not None:
            raise ValueError(
                "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
            )

        if self._verify_if_training_completed():
            print(
                f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
                "and 'overwrite_training' is set to 'False'."
            )
            print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
            return

    print(
        "Start fitting for",
        self.max_iteration - self._iteration,
        "iterations / ",
        self.max_epoch - self._epoch,
        "epochs",
    )
    print("with", len(self.train_loader), "iterations per epoch")

    # Select the mixed-precision or single-precision implementations of the
    # train / validation epoch.
    if self.mixed_precision:
        train_epoch = self._train_epoch_mixed
        validate = self._validate_mixed
        print("Training with mixed precision")
    else:
        train_epoch = self._train_epoch
        validate = self._validate
        print("Training with single precision")

    total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
    if progress is None:
        progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
    else:
        progress.total = total_iterations
        progress.set_description(f"Epoch {self._epoch}")

    msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
    train_epochs = self.max_epoch - self._epoch
    # NOTE(fix): initialize total_train_time before the loop, so that the
    # 'self.train_time = total_train_time' update below does not raise a
    # NameError when no epochs are run (train_epochs == 0, e.g. when resuming
    # an already-finished training).
    total_train_time = self.train_time
    t_start = time.time()
    for epoch in range(train_epochs):

        # Ensure data is shuffled differently at each epoch.
        try:
            self.train_loader.sampler.set_epoch(epoch)
        except AttributeError:
            pass

        # Run training and validation for this epoch.
        t_per_iter = train_epoch(progress)
        current_metric = validate()

        # Perform all the post-epoch steps:

        # Apply the learning rate scheduler.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(current_metric)

        # How long did we train in total?
        total_train_time = (time.time() - t_start) + self.train_time

        # Save this checkpoint as the new best checkpoint if
        # it has the best overall validation metric.
        if current_metric < best_metric:
            best_metric = current_metric
            self._best_epoch = self._epoch
            self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)

        # Save this checkpoint as the latest checkpoint.
        self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)

        # If we save after every k-th epoch then check if we need to save now.
        if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
            self.save_checkpoint(
                f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
            )

        # If early stopping has been specified then check if the stopping condition is met.
        if self.early_stopping is not None:
            epochs_since_best = self._epoch - self._best_epoch
            if epochs_since_best > self.early_stopping:
                print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
                break

        self._epoch += 1
        progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)

    print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
    print(f"The best epoch is number {self._best_epoch}.")

    if self._generate_name:
        self.name = None

    # Update the train time.
    self.train_time = total_train_time

    # TODO save the model to wandb if we have the wandb logger
    if isinstance(self.logger, WandbLogger):
        self.logger.get_wandb().finish()
Run neural network training.
Exactly one of 'iterations' or 'epochs' has to be passed.
Arguments:
- iterations: How long to train, specified in iterations.
- load_from_checkpoint: Path to a checkpoint from where training should be continued.
- epochs: How long to train, specified in epochs.
- save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
- progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
- overwrite_training: Whether to overwrite existing checkpoints in the save directory.
class Deserializer:
    """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

    Examples:
        To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class.
        Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

        This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
        >>> class MyTrainer(DefaultTrainer):
        >>>     def __init__(self, *args, the_answer: int, **kwargs):
        >>>         super().__init__(*args, **kwargs)
        >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
        >>>                                       # see DefaultTrainer.Serializer
        >>>
        >>>     class Deserializer(DefaultTrainer.Deserializer):
        >>>         def load_the_answer(self):
        >>>             generic_answer = self.init_data["the_answer"]
        >>>             # (device dependent) special deserialization
        >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
        >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
        >>>             else:
        >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

    Args:
        init_data: The initialization data of the trainer.
        save_path: The path where the checkpoint was saved.
        device: The device. If None, the device serialized in 'init_data' is restored.
    """

    # NOTE(fix): 'device' may be None (the code below explicitly handles that case),
    # so the annotation is Optional.
    def __init__(self, init_data: Dict, save_path: str, device: Optional[Union[str, torch.device]]):
        self.init_data = init_data
        self.save_path = save_path
        # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
        self.trainer_kwargs: Dict[str, Any] = dict(
            device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
        )

    def load(self, kwarg_name: str, optional):
        """@private
        """
        # 'optional' is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'.
        if kwarg_name == "device":
            pass  # Deserialized in __init__.
        elif kwarg_name.endswith("_loader"):
            self.load_data_loader(kwarg_name, optional)
        else:
            # Dispatch to a custom load_<kwarg_name> method if one exists.
            load = getattr(self, f"load_{kwarg_name}", self.load_generic)
            load(kwarg_name, optional=optional)

    def load_data_loader(self, loader_name, optional) -> None:
        """@private
        """
        ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
        if ds is None and optional:
            return

        loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
        loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
        # Monkey patch shuffle onto the loader, since it cannot be recovered from the loader otherwise.
        loader.shuffle = loader_kwargs.get("shuffle", False)
        self.trainer_kwargs[loader_name] = loader

    def load_generic(
        self,
        kwarg_name: str,
        *dynamic_args: Dict,
        optional: bool,
        only_class: bool = False,
        dynamic_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """@private
        """
        # Plain values are stored in init_data directly.
        if kwarg_name in self.init_data:
            self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
            return

        # Otherwise the value is reconstructed from '<kwarg_name>_class' (+ '<kwarg_name>_kwargs').
        this_cls = self.init_data.get(f"{kwarg_name}_class", None)
        if this_cls is None:
            if optional:
                return
            else:
                raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

        assert isinstance(this_cls, str), this_cls
        assert "." in this_cls, this_cls
        cls_p, cls_m = this_cls.rsplit(".", 1)
        this_cls = getattr(import_module(cls_p), cls_m)
        if only_class:
            self.trainer_kwargs[kwarg_name] = this_cls
        else:
            self.trainer_kwargs[kwarg_name] = this_cls(
                *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
            )

    def load_name(self, kwarg_name: str, optional: bool):
        """@private
        """
        # The name is the folder the checkpoint was saved in.
        self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

    def load_optimizer(self, kwarg_name: str, optional: bool):
        """@private
        """
        # The optimizer needs the (already deserialized) model parameters on construction.
        self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

    def load_lr_scheduler(self, kwarg_name: str, optional: bool):
        """@private
        """
        # The scheduler needs the (already deserialized) optimizer on construction.
        self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

    # todo: remove and rename kwarg 'logger' to 'logger_class'
    def load_logger(self, kwarg_name: str, optional: bool):
        """@private
        """
        assert kwarg_name == "logger"
        self.load_generic("logger", optional=optional, only_class=True)
Determines how to deserialize the trainer kwargs from serialized 'init_data'.
Examples:
To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class. Note that
`DefaultTrainer.Deserializer.load_generic()` covers most cases already.

This example adds the `the_answer` kwarg, which requires 'calculations' upon initialization:
>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
>>>                                       # see DefaultTrainer.Serializer
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self):
>>>             generic_answer = self.init_data["the_answer"]
>>>             # (device dependent) special deserialization
>>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
>>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
>>>             else:
>>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
Arguments:
- init_data: The initialization data of the trainer.
- save_path: The path where the checkpoint was saved.
- device: The device.
def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
    """Store the serialized 'init_data' and the checkpoint path it was loaded from."""
    self.init_data = init_data
    self.save_path = save_path
    # The device passed here takes precedence; otherwise the serialized device is restored.
    # The remaining trainer kwargs are filled in during deserialization.
    if device is None:
        restored_device = torch.device(self.init_data["device"])
    else:
        restored_device = torch.device(device)
    self.trainer_kwargs: Dict[str, Any] = {"device": restored_device}
class Serializer:
    """Implements how to serialize trainer kwargs from a trainer instance.

    The `dump()` method dispatches every kwarg to a special `dump_<kwarg_name>` method when one
    exists, and otherwise falls back to generic handling for data loaders, classes, builtin
    values and arbitrary class instances. To customize serialization of a new kwarg in a derived
    trainer class, inherit from this Serializer and add a `dump_<kwarg_name>` method that
    populates `self.init_data` with the data the corresponding Deserializer needs to restore it.

    Args:
        trainer: The trainer instance.
    """

    def __init__(self, trainer: DefaultTrainer):
        self.trainer = trainer
        self.init_data = {}  # To be populated during the serialization process.

    def dump(self, kwarg_name: str) -> None:
        """@private
        """
        custom_dump = getattr(self, f"dump_{kwarg_name}", None)
        if custom_dump is not None:
            custom_dump(kwarg_name)
        elif kwarg_name.endswith("_loader"):
            self.dump_data_loader(kwarg_name)
        elif kwarg_name.endswith("_class"):
            self.dump_generic_class(kwarg_name)
        elif not hasattr(self.trainer, kwarg_name):
            raise AttributeError(
                f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
            )
        else:
            value = getattr(self.trainer, kwarg_name)
            # Values of these exact types (and None) are serialized as-is.
            builtin_types = (bool, bytearray, bytes, dict, float, frozenset, int, list, set, str, tuple)
            if value is None or type(value) in builtin_types:
                self.dump_generic_builtin(kwarg_name)
            else:
                self.dump_generic_instance(kwarg_name)

    def dump_generic_builtin(self, kwarg_name: str) -> None:
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

    def dump_generic_class(self, kwarg_name: str) -> None:
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        assert kwarg_name.endswith("_class")
        cls = getattr(self.trainer, kwarg_name)
        # Classes are stored by their fully qualified name.
        self.init_data[kwarg_name] = f"{cls.__module__}.{cls.__name__}" if cls is not None else None

    def dump_generic_instance(self, kwarg_name: str) -> None:
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        instance = getattr(self.trainer, kwarg_name)
        # Instances are stored as fully qualified class name plus constructor kwargs.
        qualified_name = f"{instance.__class__.__module__}.{instance.__class__.__name__}"
        self.init_data[f"{kwarg_name}_class"] = qualified_name
        self.init_data[f"{kwarg_name}_kwargs"] = get_constructor_arguments(instance)

    def dump_device(self, kwarg_name: str):
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        # torch.device is stored via its string representation.
        self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

    def dump_data_loader(self, kwarg_name: str) -> None:
        """@private
        """
        assert hasattr(self.trainer, kwarg_name)
        loader = getattr(self.trainer, kwarg_name)
        if loader is None:
            return
        # Loaders are stored as their dataset plus the loader's constructor kwargs.
        self.init_data[kwarg_name.replace("_loader", "_dataset")] = loader.dataset
        self.init_data[f"{kwarg_name}_kwargs"] = get_constructor_arguments(loader)

    def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
        """@private
        """
        self.dump_generic_class(f"{kwarg_name}_class")

    def dump_model(self, kwarg_name: str):
        """@private
        """
        # For compiled models the original class and kwargs are recorded on the trainer.
        if is_compiled(self.trainer.model):
            self.init_data["model_class"] = self.trainer._model_class
            self.init_data["model_kwargs"] = self.trainer._model_kwargs
        else:
            self.dump_generic_instance("model")
Implements how to serialize trainer kwargs from a trainer instance.
Examples:
To extend the serialization process you can inherit from this Serializer in a derived Trainer class. Note that the methods
`dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()` called by the `dump()` method when appropriate cover most cases already.

This example adds the `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a 'the_answer' attribute:
>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
>>>         # but let's make things more interesting...
>>>         self.the = the_answer // 10
>>>         self.answer = the_answer % 10
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
>>>             assert kwarg_name == "the_answer"
>>>             # populate self.init_data with the serialized data required by Deserializer
>>>             # to restore the trainer kwargs
>>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

This example with both Serializer and Deserializer adds the `the_answer` kwarg, while saving it in two separate entries 'the' and 'answer':
>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str):
>>>             assert kwarg_name == "the_answer"
>>>             self.init_data.update({
>>>                 "the": self.trainer.the_answer // 10,
>>>                 "answer": self.trainer.the_answer % 10
>>>             })
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self, kwarg_name: str, optional: bool):
>>>             assert kwarg_name == "the_answer"
>>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
Arguments:
- trainer: The trainer instance.