torch_em.trainer.spoco_trainer

import time
from copy import deepcopy

import torch

from .default_trainer import DefaultTrainer
from .tensorboard_logger import TensorboardLogger


class SPOCOTrainer(DefaultTrainer):
    def __init__(
        self,
        model,
        momentum=0.999,
        semisupervised_loss=None,
        semisupervised_loader=None,
        logger=TensorboardLogger,
        **kwargs,
    ):
        super().__init__(model=model, logger=logger, **kwargs)
        self.momentum = momentum
        # Copy the model and don't require gradients for it.
        # This copy (the momentum model) is only updated via an exponential moving average.
        self.model2 = deepcopy(self.model)
        for param in self.model2.parameters():
            param.requires_grad = False
        # The semi-supervised loss and loader must either both be given or both be None.
        assert (semisupervised_loss is None) == (semisupervised_loader is None)
        self.semisupervised_loader = semisupervised_loader
        self.semisupervised_loss = semisupervised_loss
        self._kwargs = kwargs

    def _momentum_update(self):
        # Update the momentum model's parameters as an exponential moving average of the trained model.
        for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        # Also serialize the momentum model, so that training can be resumed correctly.
        super().save_checkpoint(
            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        save_dict = super().load_checkpoint(checkpoint)
        self.model2.load_state_dict(save_dict["model2_state"])
        self.model2.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.model2.to(self.device)
        return best_metric

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()
        self.model2.train()
        progress.set_description(
            f"Run semi-supervised training for {len(self.semisupervised_loader)} iterations", refresh=True
        )

        # The semi-supervised epoch only needs images, no labels:
        # the momentum model's prediction serves as the training target.
        for x in self.semisupervised_loader:
            x = x.to(self.device)
            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                with torch.no_grad():
                    prediction2 = self.model2(x)
                loss = self.semisupervised_loss(prediction, prediction2)
            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()
        self.model2.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                # The momentum model's prediction enters the loss, but receives no gradients.
                with torch.no_grad():
                    prediction2 = self.model2(x)
                loss = self.loss((prediction, prediction2), y)

            if self._iteration % self.log_image_interval == 0:
                prediction.retain_grad()

            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr, x, y, prediction, log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        # After the supervised epoch, run a semi-supervised epoch if a loader was given.
        if self.semisupervised_loader is not None:
            self._train_epoch_semisupervised(progress, forward_context, backprop)
        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()
        self.model2.eval()

        metric = 0.0
        loss = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                with forward_context():
                    prediction = self.model(x)
                    prediction2 = self.model2(x)
                loss += self.loss((prediction, prediction2), y).item()
                # The metric is computed on the trained model's prediction only.
                metric += self.metric(prediction, y).item()

        metric /= len(self.val_loader)
        loss /= len(self.val_loader)
        if self.logger is not None:
            self.logger.log_validation(self._iteration, metric, loss, x, y, prediction)
        return metric
class SPOCOTrainer(torch_em.trainer.default_trainer.DefaultTrainer):

Trainer class for 2d/3d training on a single GPU. In addition to the regular training logic, SPOCOTrainer maintains a second momentum model (model2) that is updated as an exponential moving average of the trained model, and it can optionally run a semi-supervised consistency epoch on unlabeled data after each supervised epoch.

SPOCOTrainer(model, momentum=0.999, semisupervised_loss=None, semisupervised_loader=None, logger=TensorboardLogger, **kwargs)
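For orientation, a minimal construction sketch. Everything after the SPOCO-specific arguments is forwarded to DefaultTrainer via **kwargs; the keyword names used for that (name, train_loader, val_loader, loss, metric, optimizer, device) and the placeholder objects (net, the loaders, spoco_loss, metric) are assumptions for illustration, not guaranteed by this module.

    # Hypothetical usage sketch: the DefaultTrainer keywords below are assumed,
    # and net, train_loader, val_loader, spoco_loss and metric are placeholder
    # objects that you need to define yourself.
    import torch
    from torch_em.trainer.spoco_trainer import SPOCOTrainer

    trainer = SPOCOTrainer(
        model=net,                # network to train
        momentum=0.999,           # EMA coefficient for the momentum model
        # pass both or neither of the semi-supervised arguments:
        semisupervised_loss=None,
        semisupervised_loader=None,
        # forwarded to DefaultTrainer via **kwargs (assumed parameter names):
        name="spoco-model",
        train_loader=train_loader,
        val_loader=val_loader,
        loss=spoco_loss,          # must accept ((prediction, prediction2), y)
        metric=metric,
        optimizer=torch.optim.Adam(net.parameters(), lr=1e-4),
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    trainer.fit(iterations=10_000)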
momentum: weight of the exponential moving average that updates model2 (default 0.999).
model2: the momentum model, a gradient-free deep copy of model that trails it via the EMA update.
semisupervised_loader: optional loader with unlabeled images for the semi-supervised epoch.
semisupervised_loss: optional consistency loss between the predictions of model and model2.
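The momentum update in _momentum_update is a standard exponential moving average: with momentum m, every parameter pair is updated as

    theta2 <- m * theta2 + (1 - m) * theta1

so with the default m = 0.999 each parameter of the momentum model moves only 0.1% of the way towards the corresponding parameter of the trained model per iteration, yielding a smoothed, slowly trailing copy.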
def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):

Saves a checkpoint via DefaultTrainer.save_checkpoint, additionally storing the state of the momentum model under the key "model2_state".
def load_checkpoint(self, checkpoint='best'):

Loads a checkpoint via DefaultTrainer.load_checkpoint and additionally restores the momentum model from "model2_state", moving it to the trainer's device.
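A short resumption sketch, continuing the construction example above. That DefaultTrainer.fit accepts a load_from_checkpoint argument is an assumption suggested by the _initialize signature in the source, not something this module guarantees.

    # Resume training from the latest checkpoint (assumed fit signature).
    trainer.fit(iterations=20_000, load_from_checkpoint="latest")

    # Reload the best checkpoint for evaluation; this also restores the
    # momentum model (trainer.model2) from the stored "model2_state".
    save_dict = trainer.load_checkpoint("best")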