torch_em.trainer.spoco_trainer
import time
from copy import deepcopy

import torch

from .default_trainer import DefaultTrainer
from .tensorboard_logger import TensorboardLogger


class SPOCOTrainer(DefaultTrainer):
    """Trainer that keeps a frozen, momentum-updated copy of the model.

    A second network (``self.model2``) starts as an exact copy of the trained
    model and is updated after every step as an exponential moving average of
    its weights. Both predictions are passed to the loss as a pair. An optional
    semi-supervised loss / loader pair enables an additional consistency pass
    over unlabeled data at the end of each training epoch.
    """

    def __init__(
        self,
        model,
        momentum=0.999,
        semisupervised_loss=None,
        semisupervised_loader=None,
        logger=TensorboardLogger,
        **kwargs,
    ):
        super().__init__(model=model, logger=logger, **kwargs)
        self.momentum = momentum
        # The teacher copy never receives gradients; it is only updated
        # through the momentum (EMA) update below.
        self.model2 = deepcopy(self.model)
        for teacher_param in self.model2.parameters():
            teacher_param.requires_grad = False
        # A semi-supervised loss only makes sense together with a loader for
        # the unlabeled data, and vice versa.
        assert (semisupervised_loss is None) == (semisupervised_loader is None)
        self.semisupervised_loader = semisupervised_loader
        self.semisupervised_loss = semisupervised_loss
        self._kwargs = kwargs

    def _momentum_update(self):
        # EMA update: teacher <- m * teacher + (1 - m) * student.
        m = self.momentum
        for student_param, teacher_param in zip(self.model.parameters(), self.model2.parameters()):
            teacher_param.data = m * teacher_param.data + (1.0 - m) * student_param.data

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """Save a checkpoint that additionally stores the teacher weights."""
        super().save_checkpoint(
            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        """Load a checkpoint and restore the teacher model from it."""
        save_dict = super().load_checkpoint(checkpoint)
        self.model2.load_state_dict(save_dict["model2_state"])
        self.model2.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        # Also move the teacher to the training device.
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.model2.to(self.device)
        return best_metric

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        """Run one consistency pass over the semi-supervised loader."""
        self.model.train()
        self.model2.train()
        progress.set_description(
            f"Run semi-supervised training for {len(self.semisupervised_loader)} iterations", refresh=True
        )

        # NOTE(review): assumes the loader yields plain input batches without
        # targets — confirm against the loader construction.
        for batch in self.semisupervised_loader:
            batch = batch.to(self.device)
            self.optimizer.zero_grad()

            with forward_context():
                student_pred = self.model(batch)
                # The teacher prediction is a fixed target, so no gradients.
                with torch.no_grad():
                    teacher_pred = self.model2(batch)
                loss = self.semisupervised_loss(student_pred, teacher_pred)
            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()
        self.model2.train()

        iterations_done = 0
        epoch_start = time.time()
        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                student_pred = self.model(inputs)
                with torch.no_grad():
                    teacher_pred = self.model2(inputs)
                loss = self.loss((student_pred, teacher_pred), targets)

            # Keep the prediction gradients around so the logger can show them.
            if self._iteration % self.log_image_interval == 0:
                student_pred.retain_grad()

            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

            lr = self.optimizer.param_groups[0]["lr"]
            if self.logger is not None:
                self.logger.log_train(
                    self._iteration, loss, lr, inputs, targets, student_pred, log_gradients=True
                )

            self._iteration += 1
            iterations_done += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        if self.semisupervised_loader is not None:
            self._train_epoch_semisupervised(progress, forward_context, backprop)
        return (time.time() - epoch_start) / iterations_done

    def _validate_impl(self, forward_context):
        self.model.eval()
        self.model2.eval()

        metric_sum = 0.0
        loss_sum = 0.0

        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                with forward_context():
                    student_pred = self.model(inputs)
                    teacher_pred = self.model2(inputs)
                    loss_sum += self.loss((student_pred, teacher_pred), targets).item()
                    metric_sum += self.metric(student_pred, targets).item()

        n_batches = len(self.val_loader)
        metric_avg = metric_sum / n_batches
        loss_avg = loss_sum / n_batches
        if self.logger is not None:
            self.logger.log_validation(self._iteration, metric_avg, loss_avg, inputs, targets, student_pred)
        return metric_avg
class SPOCOTrainer(DefaultTrainer):
    """Trainer that maintains a momentum-updated copy of the model (``model2``).

    The copy is frozen (no gradients) and updated after each optimization step
    as an exponential moving average of the trained model's weights; both
    predictions are passed to the loss as a pair. Optionally, a semi-supervised
    loss and loader can be given for an extra pass over unlabeled data.
    """

    def __init__(
        self,
        model,
        momentum=0.999,
        semisupervised_loss=None,
        semisupervised_loader=None,
        logger=TensorboardLogger,
        **kwargs,
    ):
        super().__init__(model=model, logger=logger, **kwargs)
        self.momentum = momentum
        # copy the model and don't require gradients for it
        self.model2 = deepcopy(self.model)
        for param in self.model2.parameters():
            param.requires_grad = False
        # do we have a semi-supervised loss and loader?
        # (they must be given together or not at all)
        assert (semisupervised_loss is None) == (semisupervised_loader is None)
        self.semisupervised_loader = semisupervised_loader
        self.semisupervised_loss = semisupervised_loss
        self._kwargs = kwargs

    def _momentum_update(self):
        # EMA update: teacher = momentum * teacher + (1 - momentum) * student.
        for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """Save a checkpoint; additionally stores the copy under 'model2_state'."""
        super().save_checkpoint(
            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        """Load a checkpoint and restore the model copy from 'model2_state'."""
        save_dict = super().load_checkpoint(checkpoint)
        self.model2.load_state_dict(save_dict["model2_state"])
        self.model2.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        # In addition to the base initialization, move the copy to the device.
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.model2.to(self.device)
        return best_metric

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        """Run one pass over the semi-supervised loader with the consistency loss."""
        self.model.train()
        self.model2.train()
        progress.set_description(
            f"Run semi-supervised training for {len(self.semisupervised_loader)} iterations", refresh=True
        )

        # NOTE(review): the loader appears to yield bare inputs without labels
        # (only x is unpacked) — confirm against the loader construction.
        for x in self.semisupervised_loader:
            x = x.to(self.device)
            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                # the copy's prediction serves as a fixed target, hence no grad
                with torch.no_grad():
                    prediction2 = self.model2(x)
                loss = self.semisupervised_loss(prediction, prediction2)
            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Train for one epoch; returns the average time per iteration."""
        self.model.train()
        self.model2.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                with torch.no_grad():
                    prediction2 = self.model2(x)
                # the loss consumes the (student, teacher) prediction pair
                loss = self.loss((prediction, prediction2), y)

            # keep gradients on the prediction so they can be logged
            if self._iteration % self.log_image_interval == 0:
                prediction.retain_grad()

            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr,
                                      x, y, prediction,
                                      log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        if self.semisupervised_loader is not None:
            self._train_epoch_semisupervised(progress, forward_context, backprop)
        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        """Validate on the full val loader; returns the averaged metric."""
        self.model.eval()
        self.model2.eval()

        metric = 0.0
        loss = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    prediction = self.model(x)
                    prediction2 = self.model2(x)
                    loss += self.loss((prediction, prediction2), y).item()
                    # the metric only sees the trained model's prediction
                    metric += self.metric(prediction, y).item()

        metric /= len(self.val_loader)
        loss /= len(self.val_loader)
        if self.logger is not None:
            self.logger.log_validation(self._iteration, metric, loss, x, y, prediction)
        return metric
Trainer class for 2D/3D training on a single GPU.
SPOCOTrainer( model, momentum=0.999, semisupervised_loss=None, semisupervised_loader=None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, **kwargs)
def __init__(
    self,
    model,
    momentum=0.999,
    semisupervised_loss=None,
    semisupervised_loader=None,
    logger=TensorboardLogger,
    **kwargs,
):
    """Set up the trainer and create the frozen model copy.

    Args:
        model: the model to train.
        momentum: EMA momentum used to update the frozen copy of the model.
        semisupervised_loss: optional loss for unlabeled data; must be passed
            together with ``semisupervised_loader``.
        semisupervised_loader: optional loader providing the unlabeled data.
        logger: the logger class, forwarded to the base trainer.
        kwargs: further keyword arguments for ``DefaultTrainer``.
    """
    super().__init__(model=model, logger=logger, **kwargs)
    self.momentum = momentum
    # Frozen copy of the model: no gradients, updated only via the EMA.
    self.model2 = deepcopy(self.model)
    for frozen_param in self.model2.parameters():
        frozen_param.requires_grad = False
    # The semi-supervised loss and loader must be given as a pair.
    assert (semisupervised_loss is None) == (semisupervised_loader is None)
    self.semisupervised_loader = semisupervised_loader
    self.semisupervised_loss = semisupervised_loss
    self._kwargs = kwargs
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- mixed_precision
- early_stopping
- train_time
- scaler
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- from_checkpoint
- Serializer
- fit