torch_em.trainer.spoco_trainer

import time
from copy import deepcopy
from typing import Optional

import torch
from .default_trainer import DefaultTrainer
from .tensorboard_logger import TensorboardLogger


class SPOCOTrainer(DefaultTrainer):
    """Trainer for a SPOCO model.

    For details check out "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings":
    https://arxiv.org/abs/2103.14572

    Args:
        model: The model to train.
        momentum: The momentum value for exponential moving weight averaging.
        semisupervised_loss: Optional loss for semi-supervised learning.
        semisupervised_loader: Optional data loader for semi-supervised learning.
        logger: The logger.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """
    def __init__(
        self,
        model: torch.nn.Module,
        momentum: float = 0.999,
        semisupervised_loss: Optional[torch.nn.Module] = None,
        semisupervised_loader: Optional[torch.utils.data.DataLoader] = None,
        logger=TensorboardLogger,
        **kwargs,
    ):
        super().__init__(model=model, logger=logger, **kwargs)
        self.momentum = momentum
        # Copy the model and don't require gradients for it.
        self.model2 = deepcopy(self.model)
        for param in self.model2.parameters():
            param.requires_grad = False
        # The semi-supervised loss and loader must either both be given or both be None.
        assert (semisupervised_loss is None) == (semisupervised_loader is None)
        self.semisupervised_loader = semisupervised_loader
        self.semisupervised_loss = semisupervised_loss
        self._kwargs = kwargs

    def _momentum_update(self):
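        # Exponential moving average (EMA) of the weights: the teacher (model2)
        # tracks the student (model). With momentum=0.999, each update keeps
        # 99.9% of the teacher weights and mixes in 0.1% of the student weights.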
        for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
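        # Also store the momentum model's weights, under the key "model2_state".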
        super().save_checkpoint(
            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        save_dict = super().load_checkpoint(checkpoint)
        self.model2.load_state_dict(save_dict["model2_state"])
        self.model2.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.model2.to(self.device)
        return best_metric

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()
        self.model2.train()
        progress.set_description(
            f"Run semi-supervised training for {len(self.semisupervised_loader)} iterations", refresh=True
        )

        for x in self.semisupervised_loader:
            x = x.to(self.device, non_blocking=True)
            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                with torch.no_grad():
                    prediction2 = self.model2(x)
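                # Consistency objective on unlabeled data: compare the student
                # prediction with the momentum model's prediction. Only the
                # student receives gradients.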
                loss = self.semisupervised_loss(prediction, prediction2)
            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()
        self.model2.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()

            with forward_context():
                prediction = self.model(x)
                with torch.no_grad():
                    prediction2 = self.model2(x)
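                # The loss receives both embeddings: the student prediction and
                # the gradient-free prediction of the momentum model.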
                loss = self.loss((prediction, prediction2), y)

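            # Retain gradients of the (non-leaf) prediction tensor, so the
            # logger can visualize them at image logging steps.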
            if self._iteration % self.log_image_interval == 0:
                prediction.retain_grad()

            backprop(loss)

            with torch.no_grad():
                self._momentum_update()

            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr,
                                      x, y, prediction,
                                      log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        if self.semisupervised_loader is not None:
            self._train_epoch_semisupervised(progress, forward_context, backprop)
        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()
        self.model2.eval()

        metric = 0.0
        loss = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
                with forward_context():
                    prediction = self.model(x)
                    prediction2 = self.model2(x)
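                # The validation loss uses both predictions, while the metric
                # is computed on the student prediction alone.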
                loss += self.loss((prediction, prediction2), y).item()
                metric += self.metric(prediction, y).item()

        metric /= len(self.val_loader)
        loss /= len(self.val_loader)
        if self.logger is not None:
            self.logger.log_validation(self._iteration, metric, loss, x, y, prediction)
        return metric
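
For orientation, a minimal usage sketch (not part of the module itself): `model`, `spoco_loss`, `embedding_metric` and the two loaders below are placeholders you have to provide, and the construction assumes the usual `torch_em.trainer.DefaultTrainer` keyword arguments (`name`, `train_loader`, `val_loader`, `loss`, `metric`, `optimizer`, `device`).

import torch
from torch_em.trainer.spoco_trainer import SPOCOTrainer

# Placeholders (not defined here): an embedding network, a SPOCO-style loss
# that accepts ((prediction, prediction2), y), a validation metric, and
# train/val data loaders.
trainer = SPOCOTrainer(
    name="spoco-model",
    model=model,
    loss=spoco_loss,
    metric=embedding_metric,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    train_loader=train_loader,
    val_loader=val_loader,
    device=torch.device("cuda"),
    momentum=0.999,  # EMA momentum for the teacher copy (self.model2)
)
trainer.fit(iterations=100000)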
