torch_em.trainer.default_trainer
from __future__ import annotations

import os
import time
import inspect
import warnings
import contextlib
from tqdm import tqdm
from functools import partial
from datetime import datetime
from collections import OrderedDict
from importlib import import_module
from typing import Any, Callable, Dict, Optional, Union, Literal

import numpy as np

import torch

from .wandb_logger import WandbLogger
from .tensorboard_logger import TensorboardLogger
from ..util import auto_compile, get_constructor_arguments, is_compiled


class DefaultTrainer:
    """Trainer class for training a segmentation network.

    The trainer class implements the core logic for training a network with pytorch.
    It implements a training loop to run training and validation, which is started with `fit`.
    The checkpoints and logs from the training run will be saved in the current working directory,
    or in the directory specified by `save_root`. Training can be continued from a checkpoint
    by passing its location to the `load_from_checkpoint` argument of `fit`.

    A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
    Alternatively, the trainer class can also be instantiated as in this example:
    ```python
    import torch
    from torch_em.loss import DiceLoss
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader
    from torch_em.trainer import DefaultTrainer

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)

    # Create the model and optimizer.
    model = UNet2d(in_channels=1, out_channels=1)
    optimizer = torch.optim.AdamW(model.parameters())

    trainer = DefaultTrainer(
        name="unet-training",
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
        model=model,
        loss=DiceLoss(),  # The loss function.
        optimizer=optimizer,
        metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
        device="cuda",  # The device to use for training.
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        model: The model to train.
        loss: The loss function for training.
        optimizer: The optimizer.
        metric: The metric for validation.
        device: The torch device to use for training. If None, will use a GPU if available.
        lr_scheduler: The learning rate scheduler.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
    """
    def __init__(
        self,
        name: Optional[str],
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        metric: Callable,
        device: Union[str, torch.device],
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        log_image_interval: int = 100,
        mixed_precision: bool = True,
        early_stopping: Optional[int] = None,
        logger=TensorboardLogger,
        logger_kwargs: Optional[Dict[str, Any]] = None,
        id_: Optional[str] = None,
        save_root: Optional[str] = None,
        compile_model: Optional[Union[bool, str]] = None,
        rank: Optional[int] = None,
    ):
        if name is None and not issubclass(logger, WandbLogger):
            raise TypeError("Name cannot be None if not using the WandbLogger")

        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
            raise ValueError(f"{self.__class__} requires each dataloader to have 'shuffle' attribute.")

        self._generate_name = name is None
        self.name = name
        self.id_ = id_ or name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metric
        self.device = torch.device(device)
        self.lr_scheduler = lr_scheduler
        self.log_image_interval = log_image_interval
        self.save_root = save_root
        self.compile_model = compile_model
        self.rank = rank

        self._iteration = 0
        self._epoch = 0
        self._best_epoch = 0

        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.train_time = 0.0

        if mixed_precision:
            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
        else:
            self.scaler = None

        self.logger_class = logger
        self.logger_kwargs = logger_kwargs
        self.log_image_interval = log_image_interval

    @property
    def checkpoint_folder(self):
        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
        # save_root enables saving the checkpoints somewhere else than in the local folder.
        # This is handy for filesystems with limited space, where saving the checkpoints
        # and log files can lead to running out of space.
        save_root = getattr(self, "save_root", None)
        return os.path.join("./checkpoints", self.id_) if save_root is None else\
            os.path.join(save_root, "./checkpoints", self.id_)

    @property
    def iteration(self):
        return self._iteration

    @property
    def epoch(self):
        return self._epoch

    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

        Examples:
            To extend the initialization process you can inherit from this Deserializer in a derived trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>                                       # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

        Args:
            init_data: The initialization data of the trainer.
            save_path: The path where the checkpoint was saved.
            device: The device.
        """

        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """@private
            """
            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            """@private
            """
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch the shuffle attribute onto the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args: Dict,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            """@private
            """
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # todo: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            """@private
            """
            assert kwarg_name == "logger"
            self.load_generic("logger", optional=optional, only_class=True)

    @staticmethod
    def _get_save_dict(save_path, device):
        if not os.path.exists(save_path):
            raise ValueError(f"Cannot find checkpoint {save_path}")
        return torch.load(save_path, map_location=device, weights_only=False)

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_folder: Union[os.PathLike, str],
        name: Literal["best", "latest"] = "best",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """@private
        """
        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
        # make sure the correct device is set if we don't have access to CUDA
        if not torch.cuda.is_available():
            device = "cpu"
        save_dict = cls._get_save_dict(save_path, device)
        deserializer = cls.Deserializer(save_dict["init"], save_path, device)

        has_kwargs = False
        deserialized = []
        for name, parameter in inspect.signature(cls).parameters.items():
            if name == "kwargs":
                has_kwargs = True
                continue
            deserializer.load(name, optional=parameter.default is not inspect.Parameter.empty)
            deserialized.append(name)

        # to deserialize kwargs we can't rely on inspecting the signature, so we
        # go through the remaining kwarg names in init data instead
        if has_kwargs:
            kwarg_names = list(set(deserializer.init_data.keys()) - set(deserialized))
            for name in kwarg_names:
                if name.endswith("_kwargs"):
                    continue
                elif name.endswith("_dataset"):
                    deserializer.load(name.replace("dataset", "loader"), optional=False)
                elif name.endswith("_class"):
                    deserializer.load(name.replace("_class", ""), optional=False)
                else:
                    deserializer.load(name, optional=False)

        trainer = cls(**deserializer.trainer_kwargs)
        trainer._initialize(0, save_dict)
        trainer._is_initialized = True
        return trainer

    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance.

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
            called by the `dump()` method when appropriate cover most cases already.

            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
            'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
            >>>         # but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
            >>>             assert kwarg_name == "the_answer"
            >>>             # populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer'
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

        Args:
            trainer: The trainer instance.
378 """ 379 380 def __init__(self, trainer: DefaultTrainer): 381 self.trainer = trainer 382 self.init_data = {} # to be populated during serialization process 383 384 def dump(self, kwarg_name: str) -> None: 385 """@private 386 """ 387 dumper = getattr(self, f"dump_{kwarg_name}", None) 388 if dumper is not None: 389 dumper(kwarg_name) 390 elif kwarg_name.endswith("_loader"): 391 self.dump_data_loader(kwarg_name) 392 elif kwarg_name.endswith("_class"): 393 self.dump_generic_class(kwarg_name) 394 elif not hasattr(self.trainer, kwarg_name): 395 raise AttributeError( 396 f"{self.trainer.__class__} missing attribute '{kwarg_name}' " 397 f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()" 398 ) 399 else: 400 assert hasattr(self.trainer, kwarg_name) 401 obj = getattr(self.trainer, kwarg_name) 402 if obj is None or type(obj) in ( 403 bool, 404 bytearray, 405 bytes, 406 dict, 407 float, 408 frozenset, 409 int, 410 list, 411 set, 412 str, 413 tuple, 414 ): 415 self.dump_generic_builtin(kwarg_name) 416 else: 417 self.dump_generic_instance(kwarg_name) 418 419 def dump_generic_builtin(self, kwarg_name: str) -> None: 420 """@private 421 """ 422 assert hasattr(self.trainer, kwarg_name) 423 self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name) 424 425 def dump_generic_class(self, kwarg_name: str) -> None: 426 """@private 427 """ 428 assert hasattr(self.trainer, kwarg_name) 429 assert kwarg_name.endswith("_class") 430 obj = getattr(self.trainer, kwarg_name) 431 self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}" 432 433 def dump_generic_instance(self, kwarg_name: str) -> None: 434 """@private 435 """ 436 assert hasattr(self.trainer, kwarg_name) 437 instance = getattr(self.trainer, kwarg_name) 438 self.init_data.update( 439 { 440 f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}", 441 f"{kwarg_name}_kwargs": get_constructor_arguments(instance), 442 } 443 ) 444 445 def dump_device(self, kwarg_name: str): 446 """@private 447 """ 448 assert hasattr(self.trainer, kwarg_name) 449 self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name)) 450 451 def dump_data_loader(self, kwarg_name: str) -> None: 452 """@private 453 """ 454 assert hasattr(self.trainer, kwarg_name) 455 loader = getattr(self.trainer, kwarg_name) 456 if loader is None: 457 return 458 self.init_data.update( 459 { 460 f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset, 461 f"{kwarg_name}_kwargs": get_constructor_arguments(loader), 462 } 463 ) 464 465 def dump_logger(self, kwarg_name: str): # todo: remove and rename kwarg 'logger' to 'logger_class' 466 """@private 467 """ 468 self.dump_generic_class(f"{kwarg_name}_class") 469 470 def dump_model(self, kwarg_name: str): 471 """@private 472 """ 473 if is_compiled(self.trainer.model): 474 self.init_data.update( 475 {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs} 476 ) 477 else: 478 self.dump_generic_instance("model") 479 480 def _build_init(self) -> Dict[str, Any]: 481 serializer = self.Serializer(self) 482 for name in inspect.signature(self.__class__).parameters: 483 # special rules to serialize kwargs 484 # if a trainer class inherits from DefaultTrainer and has **kwargs 485 # they need to be saved in self._kwargs 486 if name == "kwargs": 487 if not hasattr(self, "_kwargs"): 488 msg = "The trainer class has **kwargs in its signature, but is missing the _kwargs attribute. 
" +\ 489 "Please add self._kwargs to its __init__ function" 490 raise RuntimeError(msg) 491 kwargs = getattr(self, "_kwargs") 492 for kwarg_name in kwargs: 493 serializer.dump(kwarg_name) 494 continue 495 serializer.dump(name) 496 497 return serializer.init_data 498 499 def _initialize(self, iterations, load_from_checkpoint, epochs=None): 500 assert self.train_loader is not None 501 assert self.val_loader is not None 502 assert self.model is not None 503 assert self.loss is not None 504 assert self.optimizer is not None 505 assert self.metric is not None 506 assert self.device is not None 507 508 if load_from_checkpoint is not None: 509 self.load_checkpoint(load_from_checkpoint) 510 511 if sum((iterations is not None, epochs is not None)) != 1: 512 raise ValueError( 513 "Exactly one of 'iterations' or 'epochs' has to be specified to initialize the trainer." 514 f"You have passed 'iterations'={iterations} and 'epochs'={epochs}" 515 ) 516 517 if epochs is None: 518 epochs = int(np.ceil(float(iterations) / len(self.train_loader))) 519 else: 520 iterations = epochs * len(self.train_loader) 521 522 self.max_iteration = self._iteration + iterations 523 self.max_epoch = self._epoch + epochs 524 525 if not getattr(self, "_is_initialized", False): 526 # check if we compile the model (only supported by pytorch 2) 527 # to enable (de)serialization of compiled models, we keep track of the model class and kwargs 528 if is_compiled(self.model): 529 warnings.warn( 530 "You have passed a compiled model to the trainer." 531 "It will not be possible to (de)serialize the trainer with it." 532 "If you want to be able to do this please pass the normal model." 533 "It can be automatically compiled by setting 'compile_model' to True" 534 ) 535 self._model_class = f"{self.model.__class__.__module__}.{self.model.__class__.__name__}" 536 self._model_kwargs = get_constructor_arguments(self.model) 537 self.model = auto_compile(self.model, self.compile_model) 538 539 self.model.to(self.device) 540 self.loss.to(self.device) 541 542 # this saves all the information that is necessary 543 # to fully load the trainer from the checkpoint 544 self.init_data = self._build_init() 545 546 if self.logger_class is None: 547 self.logger = None 548 else: 549 # may set self.name if self.name is None 550 save_root = getattr(self, "save_root", None) 551 try: 552 self.logger = self.logger_class(self, save_root, **(self.logger_kwargs or {})) 553 except PermissionError: 554 warnings.warn( 555 f"The checkpoint folder at {self.checkpoint_folder} could not be created." 556 "The most likely reason for this is that you copied the checkpoint somewhere else," 557 "so we skip this error to enable loading the model from this checkpoint." 558 ) 559 560 try: 561 os.makedirs(self.checkpoint_folder, exist_ok=True) 562 except PermissionError: 563 warnings.warn( 564 f"The checkpoint folder at {self.checkpoint_folder} could not be created." 565 "The most likely reason for this is that you copied the checkpoint somewhere else," 566 "so we skip this error to enable loading the model from this checkpoint." 
            )
            pass

        best_metric = np.inf
        return best_metric

    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
        """@private
        """
        save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
        extra_init_dict = extra_save_dict.pop("init", {})
        save_dict = {
            "iteration": self._iteration,
            "epoch": self._epoch,
            "best_epoch": self._best_epoch,
            "best_metric": best_metric,
            "current_metric": current_metric,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "init": self.init_data | extra_init_dict,
            "train_time": train_time,
            "timestamp": datetime.now().strftime("%d-%m-%Y (%H:%M:%S)"),
        }
        save_dict.update(**extra_save_dict)
        if self.scaler is not None:
            save_dict.update({"scaler_state": self.scaler.state_dict()})
        if self.lr_scheduler is not None:
            save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()})

        rank = getattr(self, "rank", None)
        if rank is None or rank == 0:
            torch.save(save_dict, save_path)

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        if isinstance(checkpoint, str):
            save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
            if not os.path.exists(save_path):
                warnings.warn(f"Cannot load checkpoint. {save_path} does not exist.")
                return
            save_dict = torch.load(save_path, weights_only=False)
        elif isinstance(checkpoint, dict):
            save_dict = checkpoint
        else:
            raise RuntimeError

        self._iteration = save_dict["iteration"]
        self._epoch = save_dict["epoch"]
        self._best_epoch = save_dict["best_epoch"]
        self.best_metric = save_dict["best_metric"]
        self.current_metric = save_dict["current_metric"]
        self.train_time = save_dict.get("train_time", 0.0)

        model_state = save_dict["model_state"]
        # to enable loading compiled models
        compiled_prefix = "_orig_mod."
        model_state = OrderedDict(
            [(k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k, v) for k, v in model_state.items()]
        )
        self.model.load_state_dict(model_state)
        # we need to send the network to the device before loading the optimizer state!
        self.model.to(self.device)

        self.optimizer.load_state_dict(save_dict["optimizer_state"])
        if self.scaler is not None:
            self.scaler.load_state_dict(save_dict["scaler_state"])
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(save_dict["scheduler_state"])

        return save_dict

    def _verify_if_training_completed(self, checkpoint="latest"):
        save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
        save_dict = torch.load(save_path, weights_only=False) if os.path.exists(save_path) else None
        if save_dict and self.max_iteration == save_dict.get("iteration"):
            return True
        return False

    def fit(
        self,
        iterations: Optional[int] = None,
        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
        epochs: Optional[int] = None,
        save_every_kth_epoch: Optional[int] = None,
        progress=None,
        overwrite_training: bool = True,
    ):
        """Run neural network training.

        Exactly one of 'iterations' or 'epochs' has to be passed.

        Args:
            iterations: How long to train, specified in iterations.
            load_from_checkpoint: Path to a checkpoint from where training should be continued.
            epochs: How long to train, specified in epochs.
            save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
                The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
            progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
            overwrite_training: Whether to overwrite existing checkpoints in the save directory.
        """
        best_metric = self._initialize(iterations, load_from_checkpoint, epochs)

        if not overwrite_training:
            if load_from_checkpoint is not None:
                raise ValueError(
                    "We do not support 'overwrite_training=False' and 'load_from_checkpoint' at the same time."
                )

            if self._verify_if_training_completed():
                print(
                    f"The model is trained for {self.max_iteration} iterations / {self.max_epoch} epochs "
                    "and 'overwrite_training' is set to 'False'."
                )
                print(f"The checkpoints are located at '{os.path.abspath(self.checkpoint_folder)}'.")
                return

        print(
            "Start fitting for",
            self.max_iteration - self._iteration,
            "iterations / ",
            self.max_epoch - self._epoch,
            "epochs",
        )
        print("with", len(self.train_loader), "iterations per epoch")

        if self.mixed_precision:
            train_epoch = self._train_epoch_mixed
            validate = self._validate_mixed
            print("Training with mixed precision")
        else:
            train_epoch = self._train_epoch
            validate = self._validate
            print("Training with single precision")

        total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
        if progress is None:
            progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
        else:
            progress.total = total_iterations
            progress.set_description(f"Epoch {self._epoch}")

        msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
        train_epochs = self.max_epoch - self._epoch
        t_start = time.time()
        for epoch in range(train_epochs):

            # Ensure data is shuffled differently at each epoch.
            try:
                self.train_loader.sampler.set_epoch(epoch)
            except AttributeError:
                pass

            # Run training and validation for this epoch
            t_per_iter = train_epoch(progress)
            current_metric = validate()

            # perform all the post-epoch steps:

            # apply the learning rate scheduler
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(current_metric)

            # how long did we train in total?
            total_train_time = (time.time() - t_start) + self.train_time

            # save this checkpoint as the new best checkpoint if
            # it has the best overall validation metric
            if current_metric < best_metric:
                best_metric = current_metric
                self._best_epoch = self._epoch
                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)

            # save this checkpoint as the latest checkpoint
            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)

            # if we save after every k-th epoch then check if we need to save now
            if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
                self.save_checkpoint(
                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
                )

            # if early stopping has been specified then check if the stopping condition is met
            if self.early_stopping is not None:
                epochs_since_best = self._epoch - self._best_epoch
                if epochs_since_best > self.early_stopping:
                    print("Stopping training because there has been no improvement for", self.early_stopping, "epochs")
                    break

            self._epoch += 1
            progress.set_description(msg % (self._epoch, t_per_iter, current_metric, best_metric), refresh=True)

        print(f"Finished training after {self._epoch} epochs / {self._iteration} iterations.")
        print(f"The best epoch is number {self._best_epoch}.")

        if self._generate_name:
            self.name = None

        # Update the train time
        self.train_time = total_train_time

        # TODO save the model to wandb if we have the wandb logger
        if isinstance(self.logger, WandbLogger):
            self.logger.get_wandb().finish()

    def _backprop(self, loss):
        loss.backward()
        self.optimizer.step()

    def _backprop_mixed(self, loss):
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

    def _train_epoch(self, progress):
        return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop)

    def _train_epoch_mixed(self, progress):
        return self._train_epoch_impl(
            progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
            self._backprop_mixed
        )

    def _forward_and_loss(self, x, y):
        pred = self.model(x)
        if self._iteration % self.log_image_interval == 0:
            if pred.requires_grad:
                pred.retain_grad()

        loss = self.loss(pred, y)
        return pred, loss

    def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch.Tensor], None]):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()

            with forward_context():
                pred, loss = self._forward_and_loss(x, y)

            backprop(loss)

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate(self):
        return self._validate_impl(contextlib.nullcontext)

    def _validate_mixed(self):
        return self._validate_impl(
"cuda") 833 ) 834 835 def _validate_impl(self, forward_context): 836 self.model.eval() 837 838 metric_val = 0.0 839 loss_val = 0.0 840 841 with torch.no_grad(): 842 for x, y in self.val_loader: 843 x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) 844 with forward_context(): 845 pred, loss = self._forward_and_loss(x, y) 846 metric = self.metric(pred, y) 847 848 loss_val += loss.item() 849 metric_val += metric.item() 850 851 metric_val /= len(self.val_loader) 852 loss_val /= len(self.val_loader) 853 if self.logger is not None: 854 self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred) 855 return metric_val
"cuda") 834 ) 835 836 def _validate_impl(self, forward_context): 837 self.model.eval() 838 839 metric_val = 0.0 840 loss_val = 0.0 841 842 with torch.no_grad(): 843 for x, y in self.val_loader: 844 x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) 845 with forward_context(): 846 pred, loss = self._forward_and_loss(x, y) 847 metric = self.metric(pred, y) 848 849 loss_val += loss.item() 850 metric_val += metric.item() 851 852 metric_val /= len(self.val_loader) 853 loss_val /= len(self.val_loader) 854 if self.logger is not None: 855 self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred) 856 return metric_val
Trainer class for training a segmentation network.

The trainer class implements the core logic for training a network with PyTorch.
It implements a training loop to run training and validation, which is started with `fit`.
The checkpoints and logs from the training run will be saved in the current working directory,
or in the directory specified by `save_root`. Training can be continued from a checkpoint
by passing its location to the `load_from_checkpoint` argument of `fit`.

A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
Alternatively, the trainer class can also be instantiated as in this example:
```python
import torch
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
from torch_em.trainer import DefaultTrainer

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)

# Create the model and optimizer.
model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.AdamW(model.parameters())

trainer = DefaultTrainer(
    name="unet-training",
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    model=model,
    loss=DiceLoss(),  # The loss function.
    optimizer=optimizer,
    metric=DiceLoss(),  # The metric. The trainer expects smaller values to represent better results.
    device="cuda",  # The device to use for training.
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
```
Arguments:
- name: The name of the checkpoint that will be created by the trainer.
- train_loader: The data loader containing the training data.
- val_loader: The data loader containing the validation data.
- model: The model to train.
- loss: The loss function for training.
- optimizer: The optimizer.
- metric: The metric for validation.
- device: The torch device to use for training. If None, will use a GPU if available.
- lr_scheduler: The learning rate scheduler.
- log_image_interval: The interval for saving images during logging, in training iterations.
- mixed_precision: Whether to train with mixed precision.
- early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
- logger: The logger class. Will be instantiated for logging.
  By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
- logger_kwargs: The keyword arguments for the logger class.
- id_: Unique identifier for the trainer. If None then `name` will be used.
- save_root: The root folder for saving the checkpoint and logs.
- compile_model: Whether to compile the model before training.
- rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
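As a complement to the example above, the following sketch shows how some of the optional arguments could be combined. It reuses the `model`, `optimizer`, loss, and loader setup from the example; the scheduler, patience, and path values are illustrative assumptions, not recommendations.

```python
# Reduce the learning rate when the validation metric plateaus.
# The trainer calls lr_scheduler.step(metric) with "smaller is better" metrics, so mode="min" fits.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

trainer = DefaultTrainer(
    name="unet-training",
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    model=model,
    loss=DiceLoss(),
    optimizer=optimizer,
    metric=DiceLoss(),
    device="cuda",
    lr_scheduler=scheduler,
    early_stopping=10,                 # stop if the metric does not improve for 10 epochs
    save_root="/path/to/checkpoints",  # checkpoints and logs go below <save_root>/checkpoints/<name>
    mixed_precision=True,
)
```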
```python
    def __init__(
        self,
        name: Optional[str],
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        metric: Callable,
        device: Union[str, torch.device],
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        log_image_interval: int = 100,
        mixed_precision: bool = True,
        early_stopping: Optional[int] = None,
        logger=TensorboardLogger,
        logger_kwargs: Optional[Dict[str, Any]] = None,
        id_: Optional[str] = None,
        save_root: Optional[str] = None,
        compile_model: Optional[Union[bool, str]] = None,
        rank: Optional[int] = None,
    ):
        if name is None and not issubclass(logger, WandbLogger):
            raise TypeError("Name cannot be None if not using the WandbLogger")

        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
            raise ValueError(f"{self.__class__} requires each dataloader to have 'shuffle' attribute.")

        self._generate_name = name is None
        self.name = name
        self.id_ = id_ or name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.metric = metric
        self.device = torch.device(device)
        self.lr_scheduler = lr_scheduler
        self.log_image_interval = log_image_interval
        self.save_root = save_root
        self.compile_model = compile_model
        self.rank = rank

        self._iteration = 0
        self._epoch = 0
        self._best_epoch = 0

        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.train_time = 0.0

        if mixed_precision:
            self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda")
        else:
            self.scaler = None

        self.logger_class = logger
        self.logger_kwargs = logger_kwargs
        self.log_image_interval = log_image_interval
```
```python
    @property
    def checkpoint_folder(self):
        assert self.id_ is not None  # Because the logger may generate and set trainer.id on logger.__init__.
        # save_root enables saving the checkpoints somewhere else than in the local folder.
        # This is handy for filesystems with limited space, where saving the checkpoints
        # and log files can lead to running out of space.
        save_root = getattr(self, "save_root", None)
        return os.path.join("./checkpoints", self.id_) if save_root is None else\
            os.path.join(save_root, "./checkpoints", self.id_)
```
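The checkpoints written into this folder by `save_checkpoint` (see the class source above) are plain `torch.save` dictionaries, so they can be inspected outside of the trainer. A minimal sketch, assuming a run named `unet-training` with the default checkpoint layout; the path is illustrative:

```python
import torch

# Hypothetical path following the default layout: ./checkpoints/<name>/best.pt
checkpoint = torch.load("./checkpoints/unet-training/best.pt", weights_only=False)

# A few of the keys written by save_checkpoint.
print(checkpoint["iteration"], checkpoint["epoch"], checkpoint["best_metric"])
model_state = checkpoint["model_state"]  # can be passed to model.load_state_dict
init_data = checkpoint["init"]           # serialized constructor arguments, see the Serializer below
```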
```python
    def fit(
        self,
        iterations: Optional[int] = None,
        load_from_checkpoint: Optional[Union[os.PathLike, str]] = None,
        epochs: Optional[int] = None,
        save_every_kth_epoch: Optional[int] = None,
        progress=None,
        overwrite_training: bool = True,
    ):
```
Run neural network training.
Exactly one of 'iterations' or 'epochs' has to be passed.
Arguments:
- iterations: How long to train, specified in iterations.
- load_from_checkpoint: Path to a checkpoint from where training should be continued.
- epochs: How long to train, specified in epochs.
- save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
- progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
- overwrite_training: Whether to overwrite existing checkpoints in the save directory.
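A short usage sketch for these arguments, using the trainer instance from the class example above. The value shown for `load_from_checkpoint` is the checkpoint name resolved by `load_checkpoint` (e.g. 'latest'); treating the argument this way is an assumption, since `_initialize` is not shown in this section.

```python
# Train for 50 epochs and additionally keep a snapshot every 10 epochs
# (written as epoch-10.pt, epoch-20.pt, ... next to best.pt and latest.pt).
trainer.fit(epochs=50, save_every_kth_epoch=10)

# Continue training from the latest checkpoint for another 10,000 iterations.
trainer.fit(iterations=int(1e4), load_from_checkpoint="latest")
```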
```python
    class Deserializer:
        """Determines how to deserialize the trainer kwargs from serialized 'init_data'.

        Examples:
            To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class.
            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
            >>>         # see DefaultTrainer.Serializer
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self):
            >>>             generic_answer = self.init_data["the_answer"]
            >>>             # (device dependent) special deserialization
            >>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
            >>>             else:
            >>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2

        Args:
            init_data: The initialization data of the trainer.
            save_path: The path where the checkpoint was saved.
            device: The device.
        """

        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )

        def load(self, kwarg_name: str, optional):
            """@private
            """
            # `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
            if kwarg_name == "device":
                pass  # deserialized in __init__
            elif kwarg_name.endswith("_loader"):
                self.load_data_loader(kwarg_name, optional)
            else:
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                load(kwarg_name, optional=optional)

        def load_data_loader(self, loader_name, optional) -> None:
            """@private
            """
            ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
            if ds is None and optional:
                return

            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch shuffle loader_name to the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            self.trainer_kwargs[loader_name] = loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args: Dict,
            optional: bool,
            only_class: bool = False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            """@private
            """
            if kwarg_name in self.init_data:
                self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
                return

            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                self.trainer_kwargs[kwarg_name] = this_cls
            else:
                self.trainer_kwargs[kwarg_name] = this_cls(
                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
                )

        def load_name(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

        def load_optimizer(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

        def load_lr_scheduler(self, kwarg_name: str, optional: bool):
            """@private
            """
            self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

        # todo: remove and rename kwarg 'logger' to 'logger_class'
        def load_logger(self, kwarg_name: str, optional: bool):
            """@private
            """
            assert kwarg_name == "logger"
            self.load_generic("logger", optional=optional, only_class=True)
```
Determines how to deserialize the trainer kwargs from serialized 'init_data'.
Examples:
To extend the initialization process you can inherit from this Deserializer in an inherited Trainer class.
Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

This example adds a `the_answer` kwarg, which requires 'calculations' upon initialization:
```python
>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
>>>         # see DefaultTrainer.Serializer
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self):
>>>             generic_answer = self.init_data["the_answer"]
>>>             # (device dependent) special deserialization
>>>             if self.trainer_kwargs["device"].type == "cpu":  # accessing previously deserialized kwarg
>>>                 self.trainer_kwargs["the_answer"] = generic_answer + 1
>>>             else:
>>>                 self.trainer_kwargs["the_answer"] = generic_answer * 2
```
Arguments:
- init_data: The initialization data of the trainer.
- save_path: The path where the checkpoint was saved.
- device: The device.
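To make the mechanics concrete, the class-path round trip performed by `load_generic` (shown in the source above) boils down to the following sketch. The `init_data` entries are hypothetical examples of what the Serializer might have written for the 'loss' kwarg:

```python
from importlib import import_module

# Hypothetical serialized entries for the 'loss' kwarg.
init_data = {"loss_class": "torch_em.loss.dice.DiceLoss", "loss_kwargs": {}}

# Split "module.path.ClassName", import the module and look up the class ...
module_name, class_name = init_data["loss_class"].rsplit(".", 1)
loss_cls = getattr(import_module(module_name), class_name)

# ... then re-instantiate it with the stored keyword arguments.
loss = loss_cls(**init_data.get("loss_kwargs", {}))
```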
```python
        def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
            self.init_data = init_data
            self.save_path = save_path
            # Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
            self.trainer_kwargs: Dict[str, Any] = dict(
                device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
            )
```
```python
    class Serializer:
        """Implements how to serialize trainer kwargs from a trainer instance.

        Examples:
            To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
            Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`
            called by the `dump()` method when appropriate cover most cases already.

            This example adds `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
            'the_answer' attribute:
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
            >>>         # but let's make things more interesting...
            >>>         self.the = the_answer // 10
            >>>         self.answer = the_answer % 10
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
            >>>             assert kwarg_name == "the_answer"
            >>>             # populate self.init_data with the serialized data required by Deserializer
            >>>             # to restore the trainer kwargs
            >>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer

            This example with both Serializer and Deserializer adds `the_answer` kwarg,
            while saving it in two separate entries 'the' and 'answer'
            >>> class MyTrainer(DefaultTrainer):
            >>>     def __init__(self, *args, the_answer: int, **kwargs):
            >>>         super().__init__(*args, **kwargs)
            >>>         self.the_answer = the_answer
            >>>
            >>>     class Serializer(DefaultTrainer.Serializer):
            >>>         trainer: MyTrainer
            >>>         def dump_the_answer(self, kwarg_name: str):
            >>>             assert kwarg_name == "the_answer"
            >>>             self.init_data.update({
            >>>                 "the": self.trainer.the_answer // 10,
            >>>                 "answer": self.trainer.the_answer % 10
            >>>             })
            >>>
            >>>     class Deserializer(DefaultTrainer.Deserializer):
            >>>         def load_the_answer(self, kwarg_name: str, optional: bool):
            >>>             assert kwarg_name == "the_answer"
            >>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
            >>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]

        Args:
            trainer: The trainer instance.
        """

        def __init__(self, trainer: DefaultTrainer):
            self.trainer = trainer
            self.init_data = {}  # to be populated during serialization process

        def dump(self, kwarg_name: str) -> None:
            """@private
            """
            dumper = getattr(self, f"dump_{kwarg_name}", None)
            if dumper is not None:
                dumper(kwarg_name)
            elif kwarg_name.endswith("_loader"):
                self.dump_data_loader(kwarg_name)
            elif kwarg_name.endswith("_class"):
                self.dump_generic_class(kwarg_name)
            elif not hasattr(self.trainer, kwarg_name):
                raise AttributeError(
                    f"{self.trainer.__class__} missing attribute '{kwarg_name}' "
                    f"or special dump method {self.trainer.__class__}.Serializer.dump_{kwarg_name}()"
                )
            else:
                assert hasattr(self.trainer, kwarg_name)
                obj = getattr(self.trainer, kwarg_name)
                if obj is None or type(obj) in (
                    bool,
                    bytearray,
                    bytes,
                    dict,
                    float,
                    frozenset,
                    int,
                    list,
                    set,
                    str,
                    tuple,
                ):
                    self.dump_generic_builtin(kwarg_name)
                else:
                    self.dump_generic_instance(kwarg_name)

        def dump_generic_builtin(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

        def dump_generic_class(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            assert kwarg_name.endswith("_class")
            obj = getattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

        def dump_generic_instance(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            instance = getattr(self.trainer, kwarg_name)
            self.init_data.update(
                {
                    f"{kwarg_name}_class": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
                    f"{kwarg_name}_kwargs": get_constructor_arguments(instance),
                }
            )

        def dump_device(self, kwarg_name: str):
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

        def dump_data_loader(self, kwarg_name: str) -> None:
            """@private
            """
            assert hasattr(self.trainer, kwarg_name)
            loader = getattr(self.trainer, kwarg_name)
            if loader is None:
                return
            self.init_data.update(
                {
                    f"{kwarg_name.replace('_loader', '_dataset')}": loader.dataset,
                    f"{kwarg_name}_kwargs": get_constructor_arguments(loader),
                }
            )

        def dump_logger(self, kwarg_name: str):  # todo: remove and rename kwarg 'logger' to 'logger_class'
            """@private
            """
            self.dump_generic_class(f"{kwarg_name}_class")

        def dump_model(self, kwarg_name: str):
            """@private
            """
            if is_compiled(self.trainer.model):
                self.init_data.update(
                    {"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
                )
            else:
                self.dump_generic_instance("model")
```
Implements how to serialize trainer kwargs from a trainer instance.
Examples:
To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
Note that the methods `dump_generic_builtin()`, `dump_generic_class()` and `dump_generic_instance()`,
called by the `dump()` method when appropriate, cover most cases already.

This example adds a `the_answer` kwarg, which requires extra steps on dumping only because we don't keep a
'the_answer' attribute:
```python
>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         # self.the_answer = the_answer  # this would allow the default Serializer to save the new kwarg,
>>>         # but let's make things more interesting...
>>>         self.the = the_answer // 10
>>>         self.answer = the_answer % 10
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str) -> None:  # custom dump method for 'the_answer' kwarg
>>>             assert kwarg_name == "the_answer"
>>>             # populate self.init_data with the serialized data required by Deserializer
>>>             # to restore the trainer kwargs
>>>             self.init_data["the_answer"] = self.trainer.the * 10 + self.trainer.answer
```

This example with both Serializer and Deserializer adds a `the_answer` kwarg,
while saving it in two separate entries 'the' and 'answer':
```python
>>> class MyTrainer(DefaultTrainer):
>>>     def __init__(self, *args, the_answer: int, **kwargs):
>>>         super().__init__(*args, **kwargs)
>>>         self.the_answer = the_answer
>>>
>>>     class Serializer(DefaultTrainer.Serializer):
>>>         trainer: MyTrainer
>>>         def dump_the_answer(self, kwarg_name: str):
>>>             assert kwarg_name == "the_answer"
>>>             self.init_data.update({
>>>                 "the": self.trainer.the_answer // 10,
>>>                 "answer": self.trainer.the_answer % 10
>>>             })
>>>
>>>     class Deserializer(DefaultTrainer.Deserializer):
>>>         def load_the_answer(self, kwarg_name: str, optional: bool):
>>>             assert kwarg_name == "the_answer"
>>>             # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>>             self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
```
Arguments:
- trainer: The trainer instance.
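Conversely, a rough sketch of the entries that `dump_generic_builtin`, `dump_generic_instance`, and `dump_device` would leave in `init_data` for a trainer built with an `AdamW` optimizer. The kwargs shown are an illustrative subset, since the real values come from `get_constructor_arguments`:

```python
# What init_data could look like after serializing a few kwargs (illustrative only).
init_data = {
    # dump_generic_builtin: plain built-in values are stored directly.
    "name": "unet-training",
    "log_image_interval": 100,
    # dump_generic_instance: class path plus recovered constructor arguments.
    "optimizer_class": "torch.optim.adamw.AdamW",  # f"{cls.__module__}.{cls.__name__}"
    "optimizer_kwargs": {"lr": 1e-3},
    # dump_device: the device is stored as a string.
    "device": "cuda",
}
```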