torch_em.self_training.mean_teacher

  1import time
  2from copy import deepcopy
  3from typing import Callable, Optional
  4
  5import torch
  6import torch_em
  7from torch_em.util import get_constructor_arguments
  8
  9from .logger import SelfTrainingTensorboardLogger
 10
 11
 12class Dummy(torch.nn.Module):
 13    init_kwargs = {}
 14
 15
 16class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
 17    """Trainer for semi-supervised learning and domain adaptation following the Mean Teacher approach.
 18
 19    Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780.
 20    It uses a teacher model derived from the student model via EMA of weights
 21    to predict pseudo-labels on unlabeled data. We support two training strategies:
 22    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
 23    - Training only on the unsupervised data.
 24
 25    This class expects the following data loaders:
 26    - unsupervised_train_loader: Returns two augmentations of the same input.
 27    - supervised_train_loader (optional): Returns input and labels.
 28    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
 29    - supervised_val_loader (optional): Same as supervised_train_loader.
 30    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 31
 32    And the following elements to customize the pseudo labeling:
 33    - pseudo_labeler: to compute the pseudo-labels
 34        - Parameters: teacher, teacher_input
 35        - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
 36    - unsupervised_loss: the loss between model predictions and pseudo labels
 37        - Parameters: model, model_input, pseudo_labels, label_filter
 38        - Returns: loss
 39    - supervised_loss (optional): the supervised loss function
 40        - Parameters: model, input, labels
 41        - Returns: loss
 42    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 43        - Parameters: model, model_input, pseudo_labels, label_filter
 44        - Returns: loss, metric
 45    - supervised_loss_and_metric (optional): the supervised loss function and metric
 46        - Parameters: model, input, labels
 47        - Returns: loss, metric
 48    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 49
 50    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
 51    If it is None, the most appropriate initialization scheme for the training approach is chosen:
 52    - semi-supervised training -> reinit, because we usually train a model from scratch
 53    - unsupervised training -> do not reinit, because we usually fine-tune a model
 54
 55    Note: adjust the ratio of the batch sizes of 'unsupervised_train_loader' and 'supervised_train_loader'
 56    to set the ratio between unsupervised and supervised training samples.
 57
 58    Args:
 59        model: The model to be trained.
 60        unsupervised_train_loader: The loader for unsupervised training.
 61        unsupervised_loss: The loss for unsupervised training.
 62        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
 63        supervised_train_loader: The loader for supervised training.
 64        supervised_loss: The loss for supervised training.
 65        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
 66        supervised_loss_and_metric: The loss and metric for supervised training.
 67        logger: The logger.
 68        momentum: The momentum value for the exponential moving weight average of the teacher model.
 69        reinit_teacher: Whether to reinit the teacher model before starting the training.
 70        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
 71        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
 72    """
 73
 74    def __init__(
 75        self,
 76        model: torch.nn.Module,
 77        unsupervised_train_loader: torch.utils.data.DataLoader,
 78        unsupervised_loss: Callable,
 79        pseudo_labeler: Callable,
 80        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
 81        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 82        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 83        supervised_loss: Optional[Callable] = None,
 84        unsupervised_loss_and_metric: Optional[Callable] = None,
 85        supervised_loss_and_metric: Optional[Callable] = None,
 86        logger=SelfTrainingTensorboardLogger,
 87        momentum: float = 0.999,
 88        reinit_teacher: Optional[bool] = None,
 89        sampler: Optional[Callable] = None,
 90        **kwargs,
 91    ):
 92        self.sampler = sampler
 93        # Do we have supervised data or not?
 94        if supervised_train_loader is None:
 95            # No. -> We use the unsupervised training logic.
 96            train_loader = unsupervised_train_loader
 97            self._train_epoch_impl = self._train_epoch_unsupervised
 98        else:
 99            # Yes. -> We use the semi-supervised training logic.
100            assert supervised_loss is not None
101            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
102                else unsupervised_train_loader
103            self._train_epoch_impl = self._train_epoch_semisupervised
104
105        self.unsupervised_train_loader = unsupervised_train_loader
106        self.supervised_train_loader = supervised_train_loader
107
108        # Check that we have at least one of supervised / unsupervised val loader.
109        assert sum((
110            supervised_val_loader is not None,
111            unsupervised_val_loader is not None,
112        )) > 0
113        self.supervised_val_loader = supervised_val_loader
114        self.unsupervised_val_loader = unsupervised_val_loader
115
116        if self.unsupervised_val_loader is None:
117            val_loader = self.supervised_val_loader
118        else:
119            val_loader = self.unsupervised_val_loader
120
121        # Check that we have at least one of supervised / unsupervised loss and metric.
122        assert sum((
123            supervised_loss_and_metric is not None,
124            unsupervised_loss_and_metric is not None,
125        )) > 0
126        self.supervised_loss_and_metric = supervised_loss_and_metric
127        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
128
129        # train_loader, val_loader, loss and metric may come from deserialized kwargs; this trainer sets them itself.
130        kwargs.pop("train_loader", None)
131        kwargs.pop("val_loader", None)
132        kwargs.pop("metric", None)
133        kwargs.pop("loss", None)
134        super().__init__(
135            model=model, train_loader=train_loader, val_loader=val_loader,
136            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
137        )
138
139        self.unsupervised_loss = unsupervised_loss
140        self.supervised_loss = supervised_loss
141
142        self.pseudo_labeler = pseudo_labeler
143        self.momentum = momentum
144
145        # determine how we initialize the teacher weights (copy or reinitialization)
146        if reinit_teacher is None:
147            # semisupervised training: reinitialize
148            # unsupervised training: copy
149            self.reinit_teacher = supervised_train_loader is not None
150        else:
151            self.reinit_teacher = reinit_teacher
152
153        with torch.no_grad():
154            self.teacher = deepcopy(self.model)
155            if self.reinit_teacher:
156                for layer in self.teacher.children():
157                    if hasattr(layer, "reset_parameters"):
158                        layer.reset_parameters()
159            for param in self.teacher.parameters():
160                param.requires_grad = False
161
162        self._kwargs = kwargs
163
164    def _momentum_update(self):
165        # if we reinit the teacher we perform much faster updates (low momentum) in the first iterations
166        # to avoid a large gap between teacher and student weights, leading to inconsistent predictions
167        # if we don't reinit this is not necessary
168        if self.reinit_teacher:
169            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
170        else:
171            current_momentum = self.momentum
172
173        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
174            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)
175
176    #
177    # functionality for saving checkpoints and initialization
178    #
179
180    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
181        """@private
182        """
183        train_loader_kwargs = get_constructor_arguments(self.train_loader)
184        val_loader_kwargs = get_constructor_arguments(self.val_loader)
185        extra_state = {
186            "teacher_state": self.teacher.state_dict(),
187            "init": {
188                "train_loader_kwargs": train_loader_kwargs,
189                "train_dataset": self.train_loader.dataset,
190                "val_loader_kwargs": val_loader_kwargs,
191                "val_dataset": self.val_loader.dataset,
192                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
193                "loss_kwargs": {},
194                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
195                "metric_kwargs": {},
196            },
197        }
198        extra_state.update(**extra_save_dict)
199        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
200
201    def load_checkpoint(self, checkpoint="best"):
202        """@private
203        """
204        save_dict = super().load_checkpoint(checkpoint)
205        self.teacher.load_state_dict(save_dict["teacher_state"])
206        self.teacher.to(self.device)
207        return save_dict
208
209    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
210        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
211        self.teacher.to(self.device)
212        return best_metric
213
214    #
215    # training and validation functionality
216    #
217
218    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
219        self.model.train()
220
221        n_iter = 0
222        t_per_iter = time.time()
223
224        # Sample from the unsupervised loader.
225        for xu1, xu2 in self.unsupervised_train_loader:
226            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
227
228            teacher_input, model_input = xu1, xu2
229
230            with forward_context(), torch.no_grad():
231                # Compute the pseudo labels.
232                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
233
234            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
235            if self.sampler is not None:
236                keep_batch = self.sampler(pseudo_labels, label_filter)
237                if not keep_batch:
238                    continue
239
240            self.optimizer.zero_grad()
241            # Perform unsupervised training
242            with forward_context():
243                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
244            backprop(loss)
245
246            if self.logger is not None:
247                with torch.no_grad(), forward_context():
248                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
249                self.logger.log_train_unsupervised(
250                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
251                )
252                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
253                self.logger.log_lr(self._iteration, lr)
254                if self.pseudo_labeler.confidence_threshold is not None:
255                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
256
257            with torch.no_grad():
258                self._momentum_update()
259
260            self._iteration += 1
261            n_iter += 1
262            if self._iteration >= self.max_iteration:
263                break
264            progress.update(1)
265
266        t_per_iter = (time.time() - t_per_iter) / n_iter
267        return t_per_iter
268
269    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
270        self.model.train()
271
272        n_iter = 0
273        t_per_iter = time.time()
274
275        # Sample from both the supervised and unsupervised loader.
276        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
277            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
278            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
279
280            # Perform supervised training.
281            self.optimizer.zero_grad()
282            with forward_context():
283                # We pass the model, the input and the labels to the supervised loss function,
284                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
285                supervised_loss = self.supervised_loss(self.model, xs, ys)
286
287            teacher_input, model_input = xu1, xu2
288
289            with forward_context(), torch.no_grad():
290                # Compute the pseudo labels.
291                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
292
293            # Perform unsupervised training
294            with forward_context():
295                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
296
297            loss = (supervised_loss + unsupervised_loss) / 2
298            backprop(loss)
299
300            if self.logger is not None:
301                with torch.no_grad(), forward_context():
302                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
303                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None
304
305                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
306                self.logger.log_train_unsupervised(
307                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
308                )
309
310                self.logger.log_combined_loss(self._iteration, loss)
311                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
312                self.logger.log_lr(self._iteration, lr)
313
314            with torch.no_grad():
315                self._momentum_update()
316
317            self._iteration += 1
318            n_iter += 1
319            if self._iteration >= self.max_iteration:
320                break
321            progress.update(1)
322
323        t_per_iter = (time.time() - t_per_iter) / n_iter
324        return t_per_iter
325
326    def _validate_supervised(self, forward_context):
327        metric_val = 0.0
328        loss_val = 0.0
329
330        for x, y in self.supervised_val_loader:
331            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
332            with forward_context():
333                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
334            loss_val += loss.item()
335            metric_val += metric.item()
336
337        metric_val /= len(self.supervised_val_loader)
338        loss_val /= len(self.supervised_val_loader)
339
340        if self.logger is not None:
341            with forward_context():
342                pred = self.model(x)
343            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)
344
345        return metric_val
346
347    def _validate_unsupervised(self, forward_context):
348        metric_val = 0.0
349        loss_val = 0.0
350
351        for x1, x2 in self.unsupervised_val_loader:
352            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
353            teacher_input, model_input = x1, x2
354            with forward_context():
355                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
356                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
357            loss_val += loss.item()
358            metric_val += metric.item()
359
360        metric_val /= len(self.unsupervised_val_loader)
361        loss_val /= len(self.unsupervised_val_loader)
362
363        if self.logger is not None:
364            with forward_context():
365                pred = self.model(model_input)
366            self.logger.log_validation_unsupervised(
367                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
368            )
369
370        self.pseudo_labeler.step(metric_val, self._epoch)  # NOTE: the pseudo labeler's scheduler is stepped in the validation step.
371
372        return metric_val
373
374    def _validate_impl(self, forward_context):
375        self.model.eval()
376
377        with torch.no_grad():
378
379            if self.supervised_val_loader is None:
380                supervised_metric = None
381            else:
382                supervised_metric = self._validate_supervised(forward_context)
383
384            if self.unsupervised_val_loader is None:
385                unsupervised_metric = None
386            else:
387                unsupervised_metric = self._validate_unsupervised(forward_context)
388
389        if unsupervised_metric is None:
390            metric = supervised_metric
391        elif supervised_metric is None:
392            metric = unsupervised_metric
393        else:
394            metric = (supervised_metric + unsupervised_metric) / 2
395
396        return metric
class Dummy(torch.nn.modules.module.Module):

Placeholder module that is passed as the loss and metric to torch_em.trainer.DefaultTrainer, since MeanTeacherTrainer computes the actual losses and metrics itself.
init_kwargs = {}
class MeanTeacherTrainer(torch_em.trainer.default_trainer.DefaultTrainer):

Trainer for semi-supervised learning and domain adaptation following the Mean Teacher approach.

Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780. It uses a teacher model derived from the student model via EMA of weights to predict pseudo-labels on unlabeled data. We support two training strategies:

  • Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
  • Training only on the unsupervised data.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns two augmentations of the same input (see the loader sketch after this list).
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader.
  • supervised_val_loader (optional): Same as supervised_train_loader.

At least one of unsupervised_val_loader and supervised_val_loader must be given.
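To make the unsupervised_train_loader contract concrete, the following is a minimal sketch (not part of torch_em) of a dataset that returns two independently augmented views of the same unlabeled image; wrapping it in a torch.utils.data.DataLoader yields the expected (xu1, xu2) batches. The additive-noise augmentation and the image shapes are placeholder assumptions; in practice you would use your own augmentations and data.

    import torch
    from torch.utils.data import DataLoader, Dataset


    class TwoAugmentationDataset(Dataset):
        """Yield two independently augmented views of the same unlabeled image."""

        def __init__(self, images, augmentation=None):
            self.images = images
            self.augmentation = augmentation

        def __len__(self):
            return len(self.images)

        def _augment(self, image):
            if self.augmentation is not None:
                return self.augmentation(image)
            # Placeholder augmentation: additive Gaussian noise.
            return image + 0.1 * torch.randn_like(image)

        def __getitem__(self, index):
            image = self.images[index]
            return self._augment(image), self._augment(image)


    # Example: eight random single-channel 64x64 images.
    images = [torch.rand(1, 64, 64) for _ in range(8)]
    unsupervised_train_loader = DataLoader(TwoAugmentationDataset(images), batch_size=2, shuffle=True)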

And the following elements to customize the pseudo labeling:

  • pseudo_labeler: to compute the pseudo-labels (see the sketch after this list)
    • Parameters: teacher, teacher_input
    • Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
  • unsupervised_loss: the loss between model predictions and pseudo labels
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: model, input, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: model, input, labels
    • Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
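As a purely illustrative example of the pseudo_labeler and unsupervised_loss contracts, here is a minimal sketch for a binary segmentation setting. It is not the pseudo labeler shipped with torch_em.self_training; the sigmoid activation, the confidence thresholding, and the masked MSE loss are assumptions. The confidence_threshold attribute and the step method are included because the trainer accesses them for logging and during validation (see the source above).

    import torch


    class ThresholdedPseudoLabeler:
        """Predict pseudo-labels with the teacher and mask out low-confidence pixels."""

        def __init__(self, confidence_threshold=0.9):
            self.confidence_threshold = confidence_threshold

        def __call__(self, teacher, teacher_input):
            pseudo_labels = torch.sigmoid(teacher(teacher_input))
            # Keep only pixels that are confidently foreground or confidently background.
            label_filter = torch.logical_or(
                pseudo_labels >= self.confidence_threshold,
                pseudo_labels <= (1.0 - self.confidence_threshold),
            ).float()
            return pseudo_labels, label_filter

        def step(self, metric, epoch):
            # Hook for threshold schedules; this sketch keeps the threshold fixed.
            pass


    def masked_unsupervised_loss(model, model_input, pseudo_labels, label_filter):
        """MSE between the student prediction and the pseudo-labels, restricted by the label filter."""
        prediction = torch.sigmoid(model(model_input))
        loss = (prediction - pseudo_labels) ** 2
        if label_filter is not None:
            loss = loss * label_filter
        return loss.mean()


    def masked_unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
        """Use the unsupervised loss itself as a (proxy) validation metric."""
        loss = masked_unsupervised_loss(model, model_input, pseudo_labels, label_filter)
        return loss, loss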

If the parameter reinit_teacher is set to True, the teacher weights are re-initialized. If it is None, the most appropriate initialization scheme for the training approach is chosen:

  • semi-supervised training -> reinit, because we usually train a model from scratch
  • unsupervised training -> do not reinit, because we usually fine-tune a model

Note: adjust the ratio of the batch sizes of 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between unsupervised and supervised training samples.

Arguments:
  • model: The model to be trained.
  • unsupervised_train_loader: The loader for unsupervised training.
  • unsupervised_loss: The loss for unsupervised training.
  • pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
  • supervised_train_loader: The loader for supervised training.
  • supervised_loss: The loss for supervised training.
  • unsupervised_loss_and_metric: The loss and metric for unsupervised training.
  • supervised_loss_and_metric: The loss and metric for supervised training.
  • logger: The logger.
  • momentum: The momentum value for the exponential moving weight average of the teacher model.
  • reinit_teacher: Whether to reinit the teacher model before starting the training.
  • sampler: A sampler for rejecting pseudo-labels according to a defined criterion (see the sketch after this list).
  • kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
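The sampler, if given, is called as sampler(pseudo_labels, label_filter) and must return a boolean that decides whether the batch is used for training (see _train_epoch_unsupervised in the source above). A minimal hypothetical example that rejects batches with too few confident pixels:

    def min_confidence_sampler(pseudo_labels, label_filter, min_fraction=0.05):
        # Reject the batch if less than 5% of the pixels passed the confidence filter.
        if label_filter is None:
            return True
        return label_filter.float().mean().item() >= min_fraction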
MeanTeacherTrainer( model: torch.nn.modules.module.Module, unsupervised_train_loader: torch.utils.data.dataloader.DataLoader, unsupervised_loss: Callable, pseudo_labeler: Callable, supervised_train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, unsupervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_loss: Optional[Callable] = None, unsupervised_loss_and_metric: Optional[Callable] = None, supervised_loss_and_metric: Optional[Callable] = None, logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>, momentum: float = 0.999, reinit_teacher: Optional[bool] = None, sampler: Optional[Callable] = None, **kwargs)
sampler
unsupervised_train_loader
supervised_train_loader
supervised_val_loader
unsupervised_val_loader
supervised_loss_and_metric
unsupervised_loss_and_metric
unsupervised_loss
supervised_loss
pseudo_labeler
momentum
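Finally, a minimal usage sketch for purely unsupervised training (e.g. for domain adaptation), reusing the hypothetical TwoAugmentationDataset, ThresholdedPseudoLabeler, and masked loss helpers from the sketches above (they are assumed to be defined in the same script). The U-Net configuration, the tiny in-memory data, the optimizer settings, and the number of iterations are placeholder assumptions, not a recommended configuration; any torch.nn.Module segmentation model can be used instead of torch_em.model.UNet2d.

    import torch
    from torch.utils.data import DataLoader

    from torch_em.model import UNet2d
    from torch_em.self_training.mean_teacher import MeanTeacherTrainer

    model = UNet2d(in_channels=1, out_channels=1, initial_features=8)

    images = [torch.rand(1, 64, 64) for _ in range(16)]
    unsupervised_train_loader = DataLoader(TwoAugmentationDataset(images[:12]), batch_size=2, shuffle=True)
    unsupervised_val_loader = DataLoader(TwoAugmentationDataset(images[12:]), batch_size=2)

    trainer = MeanTeacherTrainer(
        name="mean-teacher-sketch",
        model=model,
        unsupervised_train_loader=unsupervised_train_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        unsupervised_loss=masked_unsupervised_loss,
        unsupervised_loss_and_metric=masked_unsupervised_loss_and_metric,
        pseudo_labeler=ThresholdedPseudoLabeler(confidence_threshold=0.9),
        optimizer=torch.optim.Adam(model.parameters(), lr=1.0e-4),
        device="cpu",
        mixed_precision=False,
        logger=None,
    )
    trainer.fit(iterations=10)

Since no supervised_train_loader is given, the trainer uses the unsupervised training logic, and because reinit_teacher defaults to None in this case, the teacher is initialized as a copy of the student rather than re-initialized.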