torch_em.trainer.spoco_trainer
```python
import time
from copy import deepcopy
from typing import Optional

import torch
from .default_trainer import DefaultTrainer
from .tensorboard_logger import TensorboardLogger


class SPOCOTrainer(DefaultTrainer):
    """Trainer for a SPOCO model.

    For details check out "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings":
    https://arxiv.org/abs/2103.14572

    Args:
        model: The model to train.
        momentum: The momentum value for exponential moving weight averaging.
        semisupervised_loss: Optional loss for semi-supervised learning.
        semisupervised_loader: Optional data loader for semi-supervised learning.
        logger: The logger.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """
    def __init__(
        self,
        model: torch.nn.Module,
        momentum: float = 0.999,
        semisupervised_loss: Optional[torch.nn.Module] = None,
        semisupervised_loader: Optional[torch.utils.data.DataLoader] = None,
        logger=TensorboardLogger,
        **kwargs,
    ):
        super().__init__(model=model, logger=logger, **kwargs)
        self.momentum = momentum
        # Copy the model and don't require gradients for it.
        self.model2 = deepcopy(self.model)
        for param in self.model2.parameters():
            param.requires_grad = False
        # Do we have a semi-supervised loss and loader?
        assert (semisupervised_loss is None) == (semisupervised_loader is None)
        self.semisupervised_loader = semisupervised_loader
        self.semisupervised_loss = semisupervised_loss
        self._kwargs = kwargs

    def _momentum_update(self):
        for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        super().save_checkpoint(
            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        save_dict = super().load_checkpoint(checkpoint)
        self.model2.load_state_dict(save_dict["model2_state"])
        self.model2.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.model2.to(self.device)
        return best_metric

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()
        self.model2.train()
        progress.set_description(
            f"Run semi-supervised training for {len(self.semisupervised_loader)} iterations", refresh=True
        )

        for x in self.semisupervised_loader:
            x = x.to(self.device, non_blocking=True)
            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                with torch.no_grad():
                    prediction2 = self.model2(x)
                loss = self.semisupervised_loss(prediction, prediction2)
            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()
        self.model2.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                with torch.no_grad():
                    prediction2 = self.model2(x)
                loss = self.loss((prediction, prediction2), y)

            if self._iteration % self.log_image_interval == 0:
                prediction.retain_grad()

            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(
                    self._iteration, loss, lr, x, y, prediction, log_gradients=True
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        if self.semisupervised_loader is not None:
            self._train_epoch_semisupervised(progress, forward_context, backprop)
        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()
        self.model2.eval()

        metric = 0.0
        loss = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
                with forward_context():
                    prediction = self.model(x)
                    prediction2 = self.model2(x)
                    loss += self.loss((prediction, prediction2), y).item()
                    metric += self.metric(prediction, y).item()

        metric /= len(self.val_loader)
        loss /= len(self.val_loader)
        if self.logger is not None:
            self.logger.log_validation(self._iteration, metric, loss, x, y, prediction)
        return metric
```
class SPOCOTrainer(DefaultTrainer):
Trainer for a SPOCO model.
For details check out "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings": https://arxiv.org/abs/2103.14572
Arguments:
- model: The model to train.
- momentum: The momentum value for exponential moving weight averaging.
- semisupervised_loss: Optional loss for semi-supervised learning.
- semisupervised_loader: Optional data loader for semi-supervised learning.
- logger: The logger.
- kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
SPOCOTrainer( model: torch.nn.modules.module.Module, momentum: float = 0.999, semisupervised_loss: Optional[torch.nn.modules.module.Module] = None, semisupervised_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, **kwargs)
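Constructing the trainer follows the usual `DefaultTrainer` pattern; only `momentum`, the optional semi-supervised arguments, and the fact that the loss receives the pair `(prediction, prediction2)` are SPOCO-specific. A minimal usage sketch follows: the model, data loaders, loss, and metric are hypothetical placeholders, the keyword arguments are assumed from the inherited `DefaultTrainer` members listed below, and `fit` is assumed to accept an iteration count (as suggested by `_initialize(iterations, ...)`):

```python
import torch
from torch_em.trainer.spoco_trainer import SPOCOTrainer

# Placeholders: any embedding model, data loaders, loss and metric of your own.
# The loss must accept a pair of predictions, i.e. loss((prediction, prediction2), target),
# since that is how SPOCOTrainer calls it in _train_epoch_impl and _validate_impl.
model = MyEmbeddingNet()                 # hypothetical
spoco_loss = MySPOCOLoss()               # hypothetical
embedding_metric = MyEmbeddingMetric()   # hypothetical

trainer = SPOCOTrainer(
    name="spoco-example",
    model=model,
    momentum=0.999,
    train_loader=train_loader,           # hypothetical DataLoader
    val_loader=val_loader,               # hypothetical DataLoader
    loss=spoco_loss,
    metric=embedding_metric,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    log_image_interval=100,
)
trainer.fit(iterations=10_000)
```

If both `semisupervised_loss` and `semisupervised_loader` are passed, each epoch additionally runs `_train_epoch_semisupervised`, which trains on unlabeled batches by comparing the model's prediction against the momentum model's prediction.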
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
  - name
  - id_
  - train_loader
  - val_loader
  - model
  - loss
  - optimizer
  - metric
  - device
  - lr_scheduler
  - log_image_interval
  - save_root
  - compile_model
  - rank
  - mixed_precision
  - early_stopping
  - train_time
  - logger_class
  - logger_kwargs
  - checkpoint_folder
  - iteration
  - epoch
  - Deserializer
  - Serializer
  - fit
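Checkpoints written by this trainer additionally contain a `model2_state` entry (see `save_checkpoint` in the source above), and `load_checkpoint` restores the momentum model alongside whatever the parent `DefaultTrainer` restores. A short sketch, assuming `trainer` was constructed as in the example above:

```python
# Restore the best checkpoint; the returned dictionary also contains "model2_state",
# which load_checkpoint has already loaded into trainer.model2.
save_dict = trainer.load_checkpoint("best")
```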