torch_em.self_training.mean_teacher

import time
from copy import deepcopy
from typing import Callable, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from ..transform.invertible_augmentations import InvertibleAugmenter

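# Serializable placeholder that is passed to DefaultTrainer in place of the real loss and metric;
# the actual losses are applied directly by the trainers below.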
class Dummy(torch.nn.Module):
    init_kwargs = {}


class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
 18    """Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.
 19
 20    Mean Teacher was introduced by Tarvainen & Vapola in https://arxiv.org/abs/1703.01780.
 21    It uses a teacher model derived from the student model via EMA of weights
 22    to predict pseudo-labels on unlabeled data. We support two training strategies:
 23    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
 24    - Training only on the unsupervised data.
 25
 26    This class expects the following data loaders:
 27    - unsupervised_train_loader: Returns two augmentations of the same input.
 28    - supervised_train_loader (optional): Returns input and labels.
 29    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
 30    - supervised_val_loader (optional): Same as supervised_train_loader
 31    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 32
 33    And the following elements to customize the pseudo labeling:
 34    - pseudo_labeler: to compute the psuedo-labels
 35        - Parameters: teacher, teacher_input
 36        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
 37    - unsupervised_loss: the loss between model predictions and pseudo labels
 38        - Parameters: model, model_input, pseudo_labels, label_filter
 39        - Returns: loss
 40    - supervised_loss (optional): the supervised loss function
 41        - Parameters: model, input, labels
 42        - Returns: loss
 43    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 44        - Parameters: model, model_input, pseudo_labels, label_filter
 45        - Returns: loss, metric
 46    - supervised_loss_and_metric (optional): the supervised loss function and metric
 47        - Parameters: model, input, labels
 48        - Returns: loss, metric
 49    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 50
 51    If the parameter reinit_teacher is set to true, the teacher weights are re-initialized.
 52    If it is None, the most appropriate initialization scheme for the training approach is chosen:
 53    - semi-supervised training -> reinit, because we usually train a model from scratch
 54    - unsupervised training -> do not reinit, because we usually fine-tune a model
 55
 56    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
 57    for setting the ratio between supervised and unsupervised training samples
 58
 59    Args:
 60        model: The model to be trained.
 61        unsupervised_train_loader: The loader for unsupervised training.
 62        unsupervised_loss: The loss for unsupervised training.
 63        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
 64        supervised_train_loader: The loader for supervised training.
 65        supervised_loss: The loss for supervised training.
 66        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
 67        supervised_loss_and_metric: The loss and metrhic for supervised training.
 68        logger: The logger.
 69        momentum: The momentum value for the exponential moving weight average of the teacher model.
 70        reinit_teacher: Whether to reinit the teacher model before starting the training.
 71        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
 72        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
 73    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        momentum: float = 0.999,
        reinit_teacher: Optional[bool] = None,
        sampler: Optional[Callable] = None,
        **kwargs,
    ):
        self.sampler = sampler
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler
        self.momentum = momentum

        # Determine how we initialize the teacher weights (copy or reinitialization).
        if reinit_teacher is None:
            # Semi-supervised training: reinitialize.
            # Unsupervised training: copy.
            self.reinit_teacher = supervised_train_loader is not None
        else:
            self.reinit_teacher = reinit_teacher

        with torch.no_grad():
            self.teacher = deepcopy(self.model)
            if self.reinit_teacher:
                for layer in self.teacher.children():
                    if hasattr(layer, "reset_parameters"):
                        layer.reset_parameters()
            for param in self.teacher.parameters():
                param.requires_grad = False

        self._kwargs = kwargs

    def _momentum_update(self):
        # If we reinitialize the teacher, we perform much faster updates (low momentum) in the first iterations,
        # to avoid a large gap between teacher and student weights, which would lead to inconsistent predictions.
        # If we don't reinitialize, this is not necessary.
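        # With the default momentum of 0.999 the warm-up schedule 1 - 1 / (iteration + 1) is active for
        # roughly the first 1000 iterations, e.g. momentum 0.9 at iteration 9 and 0.99 at iteration 99.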
        if self.reinit_teacher:
            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
        else:
            current_momentum = self.momentum

        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
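        # The real loss and metric are callables invoked by this trainer itself, so the Dummy
        # placeholders are recorded here to keep the checkpoint format of DefaultTrainer intact.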
        extra_state = {
            "teacher_state": self.teacher.state_dict(),
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        save_dict = super().load_checkpoint(checkpoint)
        self.teacher.load_state_dict(save_dict["teacher_state"])
        self.teacher.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.teacher.to(self.device)
        return best_metric

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
            if self.sampler is not None:
                keep_batch = self.sampler(pseudo_labels, label_filter)
                if not keep_batch:
                    continue

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

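            # Combine both objectives with equal weight.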
            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        self.pseudo_labeler.step(metric_val, self._epoch)  # NOTE: scheduler added in validation step

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric


class MeanTeacherTrainerWithInvertibleAugmentations(MeanTeacherTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.

    Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780.
    It uses a teacher model, derived from the student model via an exponential moving average (EMA)
    of the weights, to predict pseudo-labels on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and an unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns a single input, from which teacher and student views are generated.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The augmenter defines separate invertible transforms for teacher and student inputs.
    Teacher and student views are generated independently, and the corresponding inverse
    transforms map predictions and pseudo-labels back into a shared reference frame.
    Consequently, and unlike in MeanTeacherTrainer, the loss functions here receive the model
    predictions instead of the model and its input, because the predictions have to be mapped
    back through the inverse transforms before the loss is computed.

    And the following elements to customize the pseudo-labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: teacher, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
    - unsupervised_loss: the loss between model predictions and pseudo-labels
        - Parameters: pred, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: pred, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: pred, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: pred, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    If the parameter reinit_teacher is set to true, the teacher weights are re-initialized.
    If it is None, the most appropriate initialization scheme for the training approach is chosen:
    - semi-supervised training -> reinitialize, because we usually train a model from scratch
    - unsupervised training -> do not reinitialize, because we usually fine-tune a model

    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo-labeler that predicts labels in unsupervised training.
        augmenter: Invertible augmenter providing separate teacher and student transforms,
            including corresponding inverse transforms to align predictions in a common frame.
        supervised_train_loader: The loader for supervised training.
        unsupervised_val_loader: The loader for unsupervised validation.
        supervised_val_loader: The loader for supervised validation.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        momentum: The momentum value for the exponential moving weight average of the teacher model.
        reinit_teacher: Whether to reinitialize the teacher model before starting the training.
        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        augmenter: InvertibleAugmenter,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        momentum: float = 0.999,
        reinit_teacher: Optional[bool] = None,
        sampler: Optional[Callable] = None,
        **kwargs,
    ):
        super().__init__(
            model=model,
            unsupervised_train_loader=unsupervised_train_loader,
            unsupervised_loss=unsupervised_loss,
            pseudo_labeler=pseudo_labeler,
            supervised_train_loader=supervised_train_loader,
            unsupervised_val_loader=unsupervised_val_loader,
            supervised_val_loader=supervised_val_loader,
            supervised_loss=supervised_loss,
            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
            supervised_loss_and_metric=supervised_loss_and_metric,
            logger=logger,
            momentum=momentum,
            reinit_teacher=reinit_teacher,
            sampler=sampler,
            **kwargs,
        )
        self.augmenter = augmenter

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu in self.unsupervised_train_loader:
            self.augmenter.reset_all()
            xu = xu.to(self.device, non_blocking=True)

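            # reset_all (above) re-draws the augmentation parameters for this batch; teacher and
            # student then each get an independently transformed view of the same input.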
            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
            if self.sampler is not None:
                keep_batch = self.sampler(pseudo_labels, label_filter)
                if not keep_batch:
                    continue

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
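                # Prediction and pseudo-labels now both live in the un-augmented reference frame,
                # so the loss can compare them directly.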
                loss = self.unsupervised_loss(pred_inv, pseudo_labels_inv, label_filter_inv)
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = pred if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu, xu, pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, pred,
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), xu in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            self.augmenter.reset_all()
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu = xu.to(self.device, non_blocking=True)

            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # Here the prediction is computed explicitly and passed to the supervised loss function
                # together with the labels (unlike in MeanTeacherTrainer, where the loss receives the
                # model and its input).
                supervised_pred = self.model(xs)
                supervised_loss = self.supervised_loss(supervised_pred, ys)

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # Perform unsupervised training.
            with forward_context():
                unsup_pred = self.model(model_input)
                unsup_pred_inv = self.augmenter.student.reverse_transform(unsup_pred)
                unsupervised_loss = self.unsupervised_loss(unsup_pred_inv, pseudo_labels_inv, label_filter_inv)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = unsup_pred if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = supervised_pred if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu, xu, unsup_pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, unsup_pred,
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                pred = self.model(x)
                loss, metric = self.supervised_loss_and_metric(pred, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x in self.unsupervised_val_loader:
            self.augmenter.reset_all()
            x = x.to(self.device, non_blocking=True)
            x1, x2 = self.augmenter.teacher.transform(x), self.augmenter.student.transform(x)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss, metric = self.unsupervised_loss_and_metric(pred_inv, pseudo_labels_inv, label_filter_inv)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv
            )
            self.logger.log_validation_augmentations(
                self._iteration, x1, x2, pseudo_labels, pred,
            )

        self.pseudo_labeler.step(metric_val, self._epoch)

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric
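
Below is a minimal usage sketch for purely unsupervised training (e.g. domain adaptation). Everything
named here apart from MeanTeacherTrainer itself (ThresholdPseudoLabeler, TwoViewDataset, the loss
functions) is a hypothetical stand-in that merely satisfies the contracts documented above, including
the confidence_threshold attribute and the step method that the trainer accesses on the pseudo-labeler.
The remaining keyword arguments (name, optimizer, device, mixed_precision) are forwarded to
torch_em.trainer.DefaultTrainer.

import torch
from torch_em.self_training import MeanTeacherTrainer


class ThresholdPseudoLabeler:
    """Hypothetical pseudo-labeler: thresholds the teacher's sigmoid output at 0.5."""
    confidence_threshold = None  # Accessed by the trainer when logging.

    def __call__(self, teacher, teacher_input):
        pseudo_labels = (torch.sigmoid(teacher(teacher_input)) > 0.5).float()
        return pseudo_labels, None  # No label filter.

    def step(self, metric, epoch):  # Called by the trainer after validation.
        pass


def unsupervised_loss(model, model_input, pseudo_labels, label_filter):
    # Consistency loss between the student prediction and the teacher's pseudo-labels.
    return torch.nn.functional.mse_loss(torch.sigmoid(model(model_input)), pseudo_labels)


def unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    loss = unsupervised_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss


class TwoViewDataset(torch.utils.data.Dataset):
    """Returns two noisy views of the same image, as the unsupervised loaders expect."""
    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        x = self.images[i]
        return x + 0.1 * torch.randn_like(x), x + 0.1 * torch.randn_like(x)


model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(8, 1, 3, padding=1)
)
train_loader = torch.utils.data.DataLoader(TwoViewDataset(torch.rand(16, 1, 64, 64)), batch_size=4)
val_loader = torch.utils.data.DataLoader(TwoViewDataset(torch.rand(4, 1, 64, 64)), batch_size=4)

trainer = MeanTeacherTrainer(
    name="mean-teacher-example",
    model=model,
    unsupervised_train_loader=train_loader,
    unsupervised_val_loader=val_loader,
    unsupervised_loss=unsupervised_loss,
    unsupervised_loss_and_metric=unsupervised_loss_and_metric,
    pseudo_labeler=ThresholdPseudoLabeler(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1.0e-4),
    logger=None,
    mixed_precision=False,
    device="cpu",
)
trainer.fit(iterations=10)

Since no supervised_train_loader is given, the trainer picks the unsupervised training logic and, with
reinit_teacher left at None, initializes the teacher as a copy of the student, as described above.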
class Dummy(torch.nn.modules.module.Module):
14class Dummy(torch.nn.Module):
15    init_kwargs = {}

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

init_kwargs = {}
class MeanTeacherTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
 18class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
 19    """Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.
 20
 21    Mean Teacher was introduced by Tarvainen & Vapola in https://arxiv.org/abs/1703.01780.
 22    It uses a teacher model derived from the student model via EMA of weights
 23    to predict pseudo-labels on unlabeled data. We support two training strategies:
 24    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
 25    - Training only on the unsupervised data.
 26
 27    This class expects the following data loaders:
 28    - unsupervised_train_loader: Returns two augmentations of the same input.
 29    - supervised_train_loader (optional): Returns input and labels.
 30    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
 31    - supervised_val_loader (optional): Same as supervised_train_loader
 32    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 33
 34    And the following elements to customize the pseudo labeling:
 35    - pseudo_labeler: to compute the psuedo-labels
 36        - Parameters: teacher, teacher_input
 37        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
 38    - unsupervised_loss: the loss between model predictions and pseudo labels
 39        - Parameters: model, model_input, pseudo_labels, label_filter
 40        - Returns: loss
 41    - supervised_loss (optional): the supervised loss function
 42        - Parameters: model, input, labels
 43        - Returns: loss
 44    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 45        - Parameters: model, model_input, pseudo_labels, label_filter
 46        - Returns: loss, metric
 47    - supervised_loss_and_metric (optional): the supervised loss function and metric
 48        - Parameters: model, input, labels
 49        - Returns: loss, metric
 50    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 51
 52    If the parameter reinit_teacher is set to true, the teacher weights are re-initialized.
 53    If it is None, the most appropriate initialization scheme for the training approach is chosen:
 54    - semi-supervised training -> reinit, because we usually train a model from scratch
 55    - unsupervised training -> do not reinit, because we usually fine-tune a model
 56
 57    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
 58    for setting the ratio between supervised and unsupervised training samples
 59
 60    Args:
 61        model: The model to be trained.
 62        unsupervised_train_loader: The loader for unsupervised training.
 63        unsupervised_loss: The loss for unsupervised training.
 64        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
 65        supervised_train_loader: The loader for supervised training.
 66        supervised_loss: The loss for supervised training.
 67        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
 68        supervised_loss_and_metric: The loss and metrhic for supervised training.
 69        logger: The logger.
 70        momentum: The momentum value for the exponential moving weight average of the teacher model.
 71        reinit_teacher: Whether to reinit the teacher model before starting the training.
 72        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
 73        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
 74    """
 75
 76    def __init__(
 77        self,
 78        model: torch.nn.Module,
 79        unsupervised_train_loader: torch.utils.data.DataLoader,
 80        unsupervised_loss: torch.utils.data.DataLoader,
 81        pseudo_labeler: Callable,
 82        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
 83        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 84        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 85        supervised_loss: Optional[Callable] = None,
 86        unsupervised_loss_and_metric: Optional[Callable] = None,
 87        supervised_loss_and_metric: Optional[Callable] = None,
 88        logger=SelfTrainingTensorboardLogger,
 89        momentum: float = 0.999,
 90        reinit_teacher: Optional[bool] = None,
 91        sampler: Optional[Callable] = None,
 92        **kwargs,
 93    ):
 94        self.sampler = sampler
 95        # Do we have supervised data or not?
 96        if supervised_train_loader is None:
 97            # No. -> We use the unsupervised training logic.
 98            train_loader = unsupervised_train_loader
 99            self._train_epoch_impl = self._train_epoch_unsupervised
100        else:
101            # Yes. -> We use the semi-supervised training logic.
102            assert supervised_loss is not None
103            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
104                else unsupervised_train_loader
105            self._train_epoch_impl = self._train_epoch_semisupervised
106
107        self.unsupervised_train_loader = unsupervised_train_loader
108        self.supervised_train_loader = supervised_train_loader
109
110        # Check that we have at least one of supvervised / unsupervised val loader.
111        assert sum((
112            supervised_val_loader is not None,
113            unsupervised_val_loader is not None,
114        )) > 0
115        self.supervised_val_loader = supervised_val_loader
116        self.unsupervised_val_loader = unsupervised_val_loader
117
118        if self.unsupervised_val_loader is None:
119            val_loader = self.supervised_val_loader
120        else:
121            val_loader = self.unsupervised_train_loader
122
123        # Check that we have at least one of supvervised / unsupervised loss and metric.
124        assert sum((
125            supervised_loss_and_metric is not None,
126            unsupervised_loss_and_metric is not None,
127        )) > 0
128        self.supervised_loss_and_metric = supervised_loss_and_metric
129        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
130
131        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
132        kwargs.pop("train_loader", None)
133        kwargs.pop("val_loader", None)
134        kwargs.pop("metric", None)
135        kwargs.pop("loss", None)
136        super().__init__(
137            model=model, train_loader=train_loader, val_loader=val_loader,
138            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
139        )
140
141        self.unsupervised_loss = unsupervised_loss
142        self.supervised_loss = supervised_loss
143
144        self.pseudo_labeler = pseudo_labeler
145        self.momentum = momentum
146
147        # determine how we initialize the teacher weights (copy or reinitialization)
148        if reinit_teacher is None:
149            # semisupervised training: reinitialize
150            # unsupervised training: copy
151            self.reinit_teacher = supervised_train_loader is not None
152        else:
153            self.reinit_teacher = reinit_teacher
154
155        with torch.no_grad():
156            self.teacher = deepcopy(self.model)
157            if self.reinit_teacher:
158                for layer in self.teacher.children():
159                    if hasattr(layer, "reset_parameters"):
160                        layer.reset_parameters()
161            for param in self.teacher.parameters():
162                param.requires_grad = False
163
164        self._kwargs = kwargs
165
166    def _momentum_update(self):
167        # if we reinit the teacher we perform much faster updates (low momentum) in the first iterations
168        # to avoid a large gap between teacher and student weights, leading to inconsistent predictions
169        # if we don't reinit this is not necessary
170        if self.reinit_teacher:
171            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
172        else:
173            current_momentum = self.momentum
174
175        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
176            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)
177
178    #
179    # functionality for saving checkpoints and initialization
180    #
181
182    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
183        """@private
184        """
185        train_loader_kwargs = get_constructor_arguments(self.train_loader)
186        val_loader_kwargs = get_constructor_arguments(self.val_loader)
187        extra_state = {
188            "teacher_state": self.teacher.state_dict(),
189            "init": {
190                "train_loader_kwargs": train_loader_kwargs,
191                "train_dataset": self.train_loader.dataset,
192                "val_loader_kwargs": val_loader_kwargs,
193                "val_dataset": self.val_loader.dataset,
194                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
195                "loss_kwargs": {},
196                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
197                "metric_kwargs": {},
198            },
199        }
200        extra_state.update(**extra_save_dict)
201        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
202
203    def load_checkpoint(self, checkpoint="best"):
204        """@private
205        """
206        save_dict = super().load_checkpoint(checkpoint)
207        self.teacher.load_state_dict(save_dict["teacher_state"])
208        self.teacher.to(self.device)
209        return save_dict
210
211    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
212        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
213        self.teacher.to(self.device)
214        return best_metric
215
216    #
217    # training and validation functionality
218    #
219
220    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
221        self.model.train()
222
223        n_iter = 0
224        t_per_iter = time.time()
225
226        # Sample from both the supervised and unsupervised loader.
227        for xu1, xu2 in self.unsupervised_train_loader:
228            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
229
230            teacher_input, model_input = xu1, xu2
231
232            with forward_context(), torch.no_grad():
233                # Compute the pseudo labels.
234                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
235
236            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
237            if self.sampler is not None:
238                keep_batch = self.sampler(pseudo_labels, label_filter)
239                if not keep_batch:
240                    continue
241
242            self.optimizer.zero_grad()
243            # Perform unsupervised training
244            with forward_context():
245                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
246            backprop(loss)
247
248            if self.logger is not None:
249                with torch.no_grad(), forward_context():
250                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
251                self.logger.log_train_unsupervised(
252                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
253                )
254                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
255                self.logger.log_lr(self._iteration, lr)
256                if self.pseudo_labeler.confidence_threshold is not None:
257                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
258
259            with torch.no_grad():
260                self._momentum_update()
261
262            self._iteration += 1
263            n_iter += 1
264            if self._iteration >= self.max_iteration:
265                break
266            progress.update(1)
267
268        t_per_iter = (time.time() - t_per_iter) / n_iter
269        return t_per_iter
270
271    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
272        self.model.train()
273
274        n_iter = 0
275        t_per_iter = time.time()
276
277        # Sample from both the supervised and unsupervised loader.
278        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
279            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
280            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
281
282            # Perform supervised training.
283            self.optimizer.zero_grad()
284            with forward_context():
285                # We pass the model, the input and the labels to the supervised loss function,
286                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
287                supervised_loss = self.supervised_loss(self.model, xs, ys)
288
289            teacher_input, model_input = xu1, xu2
290
291            with forward_context(), torch.no_grad():
292                # Compute the pseudo labels.
293                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
294
295            # Perform unsupervised training
296            with forward_context():
297                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
298
299            loss = (supervised_loss + unsupervised_loss) / 2
300            backprop(loss)
301
302            if self.logger is not None:
303                with torch.no_grad(), forward_context():
304                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
305                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None
306
307                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
308                self.logger.log_train_unsupervised(
309                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
310                )
311
312                self.logger.log_combined_loss(self._iteration, loss)
313                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
314                self.logger.log_lr(self._iteration, lr)
315                if self.pseudo_labeler.confidence_threshold is not None:
316                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
317
318            with torch.no_grad():
319                self._momentum_update()
320
321            self._iteration += 1
322            n_iter += 1
323            if self._iteration >= self.max_iteration:
324                break
325            progress.update(1)
326
327        t_per_iter = (time.time() - t_per_iter) / n_iter
328        return t_per_iter
329
330    def _validate_supervised(self, forward_context):
331        metric_val = 0.0
332        loss_val = 0.0
333
334        for x, y in self.supervised_val_loader:
335            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
336            with forward_context():
337                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
338            loss_val += loss.item()
339            metric_val += metric.item()
340
341        metric_val /= len(self.supervised_val_loader)
342        loss_val /= len(self.supervised_val_loader)
343
344        if self.logger is not None:
345            with forward_context():
346                pred = self.model(x)
347            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)
348
349        return metric_val
350
351    def _validate_unsupervised(self, forward_context):
352        metric_val = 0.0
353        loss_val = 0.0
354
355        for x1, x2 in self.unsupervised_val_loader:
356            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
357            teacher_input, model_input = x1, x2
358            with forward_context():
359                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
360                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
361            loss_val += loss.item()
362            metric_val += metric.item()
363
364        metric_val /= len(self.unsupervised_val_loader)
365        loss_val /= len(self.unsupervised_val_loader)
366
367        if self.logger is not None:
368            with forward_context():
369                pred = self.model(model_input)
370            self.logger.log_validation_unsupervised(
371                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
372            )
373
374        self.pseudo_labeler.step(metric_val, self._epoch)  # NOTE: scheduler added in validation step
375
376        return metric_val
377
378    def _validate_impl(self, forward_context):
379        self.model.eval()
380
381        with torch.no_grad():
382
383            if self.supervised_val_loader is None:
384                supervised_metric = None
385            else:
386                supervised_metric = self._validate_supervised(forward_context)
387
388            if self.unsupervised_val_loader is None:
389                unsupervised_metric = None
390            else:
391                unsupervised_metric = self._validate_unsupervised(forward_context)
392
393        if unsupervised_metric is None:
394            metric = supervised_metric
395        elif supervised_metric is None:
396            metric = unsupervised_metric
397        else:
398            metric = (supervised_metric + unsupervised_metric) / 2
399
400        return metric

Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.

Mean Teacher was introduced by Tarvainen & Vapola in https://arxiv.org/abs/1703.01780. It uses a teacher model derived from the student model via EMA of weights to predict pseudo-labels on unlabeled data. We support two training strategies:

  • Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
  • Training only on the unsupervised data.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns two augmentations of the same input.
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader
  • supervised_val_loader (optional): Same as supervised_train_loader At least one of unsupervised_val_loader and supervised_val_loader must be given.

And the following elements to customize the pseudo labeling:

  • pseudo_labeler: to compute the psuedo-labels
    • Parameters: teacher, teacher_input
    • Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
  • unsupervised_loss: the loss between model predictions and pseudo labels
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: model, input, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: model, input, labels
    • Returns: loss, metric At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

If the parameter reinit_teacher is set to true, the teacher weights are re-initialized. If it is None, the most appropriate initialization scheme for the training approach is chosen:

  • semi-supervised training -> reinit, because we usually train a model from scratch
  • unsupervised training -> do not reinit, because we usually fine-tune a model

Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' for setting the ratio between supervised and unsupervised training samples

Arguments:
  • model: The model to be trained.
  • unsupervised_train_loader: The loader for unsupervised training.
  • unsupervised_loss: The loss for unsupervised training.
  • pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
  • supervised_train_loader: The loader for supervised training.
  • supervised_loss: The loss for supervised training.
  • unsupervised_loss_and_metric: The loss and metric for unsupervised training.
  • supervised_loss_and_metric: The loss and metric for supervised training.
  • logger: The logger class; defaults to SelfTrainingTensorboardLogger.
  • momentum: The momentum value for the exponential moving average of the teacher model weights.
  • reinit_teacher: Whether to reinitialize the teacher model before starting the training.
  • sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
  • kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
MeanTeacherTrainer( model: torch.nn.modules.module.Module, unsupervised_train_loader: torch.utils.data.dataloader.DataLoader, unsupervised_loss: Callable, pseudo_labeler: Callable, supervised_train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, unsupervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_loss: Optional[Callable] = None, unsupervised_loss_and_metric: Optional[Callable] = None, supervised_loss_and_metric: Optional[Callable] = None, logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>, momentum: float = 0.999, reinit_teacher: Optional[bool] = None, sampler: Optional[Callable] = None, **kwargs)
 76    def __init__(
 77        self,
 78        model: torch.nn.Module,
 79        unsupervised_train_loader: torch.utils.data.DataLoader,
 80        unsupervised_loss: Callable,
 81        pseudo_labeler: Callable,
 82        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
 83        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 84        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 85        supervised_loss: Optional[Callable] = None,
 86        unsupervised_loss_and_metric: Optional[Callable] = None,
 87        supervised_loss_and_metric: Optional[Callable] = None,
 88        logger=SelfTrainingTensorboardLogger,
 89        momentum: float = 0.999,
 90        reinit_teacher: Optional[bool] = None,
 91        sampler: Optional[Callable] = None,
 92        **kwargs,
 93    ):
 94        self.sampler = sampler
 95        # Do we have supervised data or not?
 96        if supervised_train_loader is None:
 97            # No. -> We use the unsupervised training logic.
 98            train_loader = unsupervised_train_loader
 99            self._train_epoch_impl = self._train_epoch_unsupervised
100        else:
101            # Yes. -> We use the semi-supervised training logic.
102            assert supervised_loss is not None
103            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
104                else unsupervised_train_loader
105            self._train_epoch_impl = self._train_epoch_semisupervised
106
107        self.unsupervised_train_loader = unsupervised_train_loader
108        self.supervised_train_loader = supervised_train_loader
109
110        # Check that we have at least one of the supervised / unsupervised val loaders.
111        assert sum((
112            supervised_val_loader is not None,
113            unsupervised_val_loader is not None,
114        )) > 0
115        self.supervised_val_loader = supervised_val_loader
116        self.unsupervised_val_loader = unsupervised_val_loader
117
118        if self.unsupervised_val_loader is None:
119            val_loader = self.supervised_val_loader
120        else:
121            val_loader = self.unsupervised_val_loader
122
123        # Check that we have at least one of the supervised / unsupervised loss-and-metric functions.
124        assert sum((
125            supervised_loss_and_metric is not None,
126            unsupervised_loss_and_metric is not None,
127        )) > 0
128        self.supervised_loss_and_metric = supervised_loss_and_metric
129        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
130
131        # train_loader, val_loader, loss and metric may still be present in kwargs (e.g. from deserialization); remove them so they are not passed twice.
132        kwargs.pop("train_loader", None)
133        kwargs.pop("val_loader", None)
134        kwargs.pop("metric", None)
135        kwargs.pop("loss", None)
136        super().__init__(
137            model=model, train_loader=train_loader, val_loader=val_loader,
138            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
139        )
140
141        self.unsupervised_loss = unsupervised_loss
142        self.supervised_loss = supervised_loss
143
144        self.pseudo_labeler = pseudo_labeler
145        self.momentum = momentum
146
147        # determine how we initialize the teacher weights (copy or reinitialization)
148        if reinit_teacher is None:
149            # semisupervised training: reinitialize
150            # unsupervised training: copy
151            self.reinit_teacher = supervised_train_loader is not None
152        else:
153            self.reinit_teacher = reinit_teacher
154
155        with torch.no_grad():
156            self.teacher = deepcopy(self.model)
157            if self.reinit_teacher:
158                for layer in self.teacher.children():
159                    if hasattr(layer, "reset_parameters"):
160                        layer.reset_parameters()
161            for param in self.teacher.parameters():
162                param.requires_grad = False
163
164        self._kwargs = kwargs
sampler
unsupervised_train_loader
supervised_train_loader
supervised_val_loader
unsupervised_val_loader
supervised_loss_and_metric
unsupervised_loss_and_metric
unsupervised_loss
supervised_loss
pseudo_labeler
momentum
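The momentum value listed above controls how closely the teacher tracks the student. The _momentum_update method called from the training loops is not part of this excerpt; a minimal sketch of the standard mean-teacher EMA update it corresponds to (an assumption, not the verbatim implementation) is:

    import torch

    def momentum_update(teacher, student, momentum=0.999):
        # EMA update: teacher <- momentum * teacher + (1 - momentum) * student.
        # With momentum=0.999 the teacher effectively averages on the order of
        # the last 1000 student parameter states. The teacher parameters are
        # frozen (requires_grad=False in __init__), so this runs without gradients.
        with torch.no_grad():
            for t_param, s_param in zip(teacher.parameters(), student.parameters()):
                t_param.mul_(momentum).add_(s_param, alpha=1.0 - momentum)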
class MeanTeacherTrainerWithInvertibleAugmentations(MeanTeacherTrainer):
403class MeanTeacherTrainerWithInvertibleAugmentations(MeanTeacherTrainer):
404    """Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.
405
406    Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780.
407    It uses a teacher model derived from the student model via EMA of weights
408    to predict pseudo-labels on unlabeled data. We support two training strategies:
409    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
410    - Training only on the unsupervised data.
411
412    This class expects the following data loaders:
413    - unsupervised_train_loader: Returns the raw input; the augmenter creates the two augmented views.
414    - supervised_train_loader (optional): Returns input and labels.
415    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
416    - supervised_val_loader (optional): Same as supervised_train_loader
417    At least one of unsupervised_val_loader and supervised_val_loader must be given.
418
419    The augmenter defines separate invertible transforms for teacher and student inputs.
420    Teacher and student views are generated independently, and the corresponding inverse
421    transforms map predictions and pseudo-labels back into a shared reference frame.
422
423    And the following elements to customize the pseudo labeling:
424    - pseudo_labeler: to compute the pseudo-labels
425        - Parameters: teacher, teacher_input
426        - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
427    - unsupervised_loss: the loss between model predictions and pseudo labels
428        - Parameters: predictions, pseudo_labels, label_filter
429        - Returns: loss
430    - supervised_loss (optional): the supervised loss function
431        - Parameters: predictions, labels
432        - Returns: loss
433    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
434        - Parameters: predictions, pseudo_labels, label_filter
435        - Returns: loss, metric
436    - supervised_loss_and_metric (optional): the supervised loss function and metric
437        - Parameters: predictions, labels
438        - Returns: loss, metric
439    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
440
441    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
442    If it is None, the most appropriate initialization scheme for the training approach is chosen:
443    - semi-supervised training -> reinit, because we usually train a model from scratch
444    - unsupervised training -> do not reinit, because we usually fine-tune a model
445
446    Note: adjust the batch sizes of the 'unsupervised_train_loader' and 'supervised_train_loader'
447    to set the ratio between unsupervised and supervised training samples.
448
449    Args:
450        model: The model to be trained.
451        unsupervised_train_loader: The loader for unsupervised training.
452        unsupervised_loss: The loss for unsupervised training.
453        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
454        augmenter: Invertible augmenter providing separate teacher and student transforms,
455            including corresponding inverse transforms to align predictions in a common frame.
456        supervised_train_loader: The loader for supervised training.
457        supervised_loss: The loss for supervised training.
458        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
459        supervised_loss_and_metric: The loss and metric for supervised training.
460        logger: The logger class; defaults to SelfTrainingTensorboardLogger.
461        momentum: The momentum value for the exponential moving average of the teacher model weights.
462        reinit_teacher: Whether to reinitialize the teacher model before starting the training.
463        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
464        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
465    """
466
467    def __init__(
468        self,
469        model: torch.nn.Module,
470        unsupervised_train_loader: torch.utils.data.DataLoader,
471        unsupervised_loss: Callable,
472        pseudo_labeler: Callable,
473        augmenter: InvertibleAugmenter,
474        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
475        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
476        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
477        supervised_loss: Optional[Callable] = None,
478        unsupervised_loss_and_metric: Optional[Callable] = None,
479        supervised_loss_and_metric: Optional[Callable] = None,
480        logger=SelfTrainingTensorboardLogger,
481        momentum: float = 0.999,
482        reinit_teacher: Optional[bool] = None,
483        sampler: Optional[Callable] = None,
484        **kwargs,
485    ):
486        super().__init__(
487            model=model,
488            unsupervised_train_loader=unsupervised_train_loader,
489            unsupervised_loss=unsupervised_loss,
490            pseudo_labeler=pseudo_labeler,
491            supervised_train_loader=supervised_train_loader,
492            unsupervised_val_loader=unsupervised_val_loader,
493            supervised_val_loader=supervised_val_loader,
494            supervised_loss=supervised_loss,
495            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
496            supervised_loss_and_metric=supervised_loss_and_metric,
497            logger=logger,
498            momentum=momentum,
499            reinit_teacher=reinit_teacher,
500            sampler=sampler,
501            **kwargs,
502        )
503        self.augmenter = augmenter
504
505    #
506    # training and validation functionality
507    #
508
509    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
510        self.model.train()
511
512        n_iter = 0
513        t_per_iter = time.time()
514
515        # Sample from the unsupervised loader.
516        for xu in self.unsupervised_train_loader:
517            self.augmenter.reset_all()
518            xu = xu.to(self.device, non_blocking=True)
519
520            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
521            teacher_input, model_input = xu1, xu2
522
523            with forward_context(), torch.no_grad():
524                # Compute the pseudo labels.
525                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
526                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
527                label_filter_inv = (
528                    self.augmenter.teacher.reverse_transform(label_filter)
529                    if label_filter is not None else None
530                )
531
532            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
533            if self.sampler is not None:
534                keep_batch = self.sampler(pseudo_labels, label_filter)
535                if not keep_batch:
536                    continue
537
538            self.optimizer.zero_grad()
539            # Perform unsupervised training
540            with forward_context():
541                pred = self.model(model_input)
542                pred_inv = self.augmenter.student.reverse_transform(pred)
543                loss = self.unsupervised_loss(pred_inv, pseudo_labels_inv, label_filter_inv)
544            backprop(loss)
545
546            if self.logger is not None:
547                with torch.no_grad(), forward_context():
548                    pred = pred if self._iteration % self.log_image_interval == 0 else None
549                self.logger.log_train_unsupervised(
550                    self._iteration, loss, xu, xu, pred_inv, pseudo_labels_inv, label_filter_inv
551                )
552                self.logger.log_train_augmentations(
553                    self._iteration, xu1, xu2, pseudo_labels, pred,
554                )
555                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
556                self.logger.log_lr(self._iteration, lr)
557                if self.pseudo_labeler.confidence_threshold is not None:
558                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
559
560            with torch.no_grad():
561                self._momentum_update()
562
563            self._iteration += 1
564            n_iter += 1
565            if self._iteration >= self.max_iteration:
566                break
567            progress.update(1)
568
569        t_per_iter = (time.time() - t_per_iter) / n_iter
570        return t_per_iter
571
572    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
573        self.model.train()
574
575        n_iter = 0
576        t_per_iter = time.time()
577
578        # Sample from both the supervised and unsupervised loader.
579        for (xs, ys), xu in zip(self.supervised_train_loader, self.unsupervised_train_loader):
580            self.augmenter.reset_all()
581            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
582            xu = xu.to(self.device, non_blocking=True)
583
584            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
585            teacher_input, model_input = xu1, xu2
586
587            # Perform supervised training.
588            self.optimizer.zero_grad()
589            with forward_context():
590                # Compute the prediction and pass it together with the labels to the
591                # supervised loss function; the prediction is also used for logging below.
592                supervised_pred = self.model(xs)
593                supervised_loss = self.supervised_loss(supervised_pred, ys)
594
595            with forward_context(), torch.no_grad():
596                # Compute the pseudo labels.
597                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
598                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
599                label_filter_inv = (
600                    self.augmenter.teacher.reverse_transform(label_filter)
601                    if label_filter is not None else None
602                )
603
604            # Perform unsupervised training
605            with forward_context():
606                unsup_pred = self.model(model_input)
607                unsup_pred_inv = self.augmenter.student.reverse_transform(unsup_pred)
608                unsupervised_loss = self.unsupervised_loss(unsup_pred_inv, pseudo_labels_inv, label_filter_inv)
609
610            loss = (supervised_loss + unsupervised_loss) / 2
611            backprop(loss)
612
613            if self.logger is not None:
614                with torch.no_grad(), forward_context():
615                    unsup_pred = unsup_pred if self._iteration % self.log_image_interval == 0 else None
616                    supervised_pred = supervised_pred if self._iteration % self.log_image_interval == 0 else None
617
618                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
619                self.logger.log_train_unsupervised(
620                    self._iteration, unsupervised_loss, xu, xu, unsup_pred_inv, pseudo_labels_inv, label_filter_inv
621                )
622                self.logger.log_train_augmentations(
623                    self._iteration, xu1, xu2, pseudo_labels, unsup_pred,
624                )
625
626                self.logger.log_combined_loss(self._iteration, loss)
627                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
628                self.logger.log_lr(self._iteration, lr)
629                if self.pseudo_labeler.confidence_threshold is not None:
630                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
631
632            with torch.no_grad():
633                self._momentum_update()
634
635            self._iteration += 1
636            n_iter += 1
637            if self._iteration >= self.max_iteration:
638                break
639            progress.update(1)
640
641        t_per_iter = (time.time() - t_per_iter) / n_iter
642        return t_per_iter
643
644    def _validate_supervised(self, forward_context):
645        metric_val = 0.0
646        loss_val = 0.0
647
648        for x, y in self.supervised_val_loader:
649            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
650            with forward_context():
651                pred = self.model(x)
652                loss, metric = self.supervised_loss_and_metric(pred, y)
653            loss_val += loss.item()
654            metric_val += metric.item()
655
656        metric_val /= len(self.supervised_val_loader)
657        loss_val /= len(self.supervised_val_loader)
658
659        if self.logger is not None:
660            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)
661
662        return metric_val
663
664    def _validate_unsupervised(self, forward_context):
665        metric_val = 0.0
666        loss_val = 0.0
667
668        for x in self.unsupervised_val_loader:
669            self.augmenter.reset_all()
670            x = x.to(self.device, non_blocking=True)
671            x1, x2 = self.augmenter.teacher.transform(x), self.augmenter.student.transform(x)
672            teacher_input, model_input = x1, x2
673            with forward_context():
674                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
675                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
676                label_filter_inv = (
677                    self.augmenter.teacher.reverse_transform(label_filter)
678                    if label_filter is not None else None
679                )
680
681                pred = self.model(model_input)
682                pred_inv = self.augmenter.student.reverse_transform(pred)
683                loss, metric = self.unsupervised_loss_and_metric(pred_inv, pseudo_labels_inv, label_filter_inv)
684            loss_val += loss.item()
685            metric_val += metric.item()
686
687        metric_val /= len(self.unsupervised_val_loader)
688        loss_val /= len(self.unsupervised_val_loader)
689
690        if self.logger is not None:
691            self.logger.log_validation_unsupervised(
692                self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv
693            )
694            self.logger.log_validation_augmentations(
695                self._iteration, x1, x2, pseudo_labels, pred,
696            )
697
698        self.pseudo_labeler.step(metric_val, self._epoch)
699
700        return metric_val
701
702    def _validate_impl(self, forward_context):
703        self.model.eval()
704
705        with torch.no_grad():
706
707            if self.supervised_val_loader is None:
708                supervised_metric = None
709            else:
710                supervised_metric = self._validate_supervised(forward_context)
711
712            if self.unsupervised_val_loader is None:
713                unsupervised_metric = None
714            else:
715                unsupervised_metric = self._validate_unsupervised(forward_context)
716
717        if unsupervised_metric is None:
718            metric = supervised_metric
719        elif supervised_metric is None:
720            metric = unsupervised_metric
721        else:
722            metric = (supervised_metric + unsupervised_metric) / 2
723
724        return metric

Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.

Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780. It uses a teacher model, derived from the student model via an exponential moving average (EMA) of its weights, to predict pseudo-labels on unlabeled data. We support two training strategies:

  • Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
  • Training only on the unsupervised data.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns the raw input; the augmenter creates the two augmented views.
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader.
  • supervised_val_loader (optional): Same as supervised_train_loader.

At least one of unsupervised_val_loader and supervised_val_loader must be given.

The augmenter defines separate invertible transforms for teacher and student inputs. Teacher and student views are generated independently, and the corresponding inverse transforms map predictions and pseudo-labels back into a shared reference frame.
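The trainer only relies on a small augmenter interface, visible in the source above: reset_all() on the augmenter, and transform / reverse_transform on its teacher and student branches. The sketch below implements that interface with random flips, which are their own inverse; it is an illustration only, and the actual InvertibleAugmenter in torch_em.transform.invertible_augmentations is more general.

    import torch

    class FlipBranch:
        """Sketch of one augmentation branch: an invertible random flip."""

        def __init__(self):
            self.flip = False

        def reset(self):
            # Sample a new random state for the next batch.
            self.flip = bool(torch.rand(1).item() < 0.5)

        def transform(self, x):
            return torch.flip(x, dims=(-1,)) if self.flip else x

        def reverse_transform(self, x):
            # A flip is its own inverse, so predictions and pseudo-labels
            # map back exactly into the un-augmented reference frame.
            return torch.flip(x, dims=(-1,)) if self.flip else x

    class FlipAugmenter:
        """Sketch of the augmenter interface: independent teacher and student branches."""

        def __init__(self):
            self.teacher = FlipBranch()
            self.student = FlipBranch()

        def reset_all(self):
            self.teacher.reset()
            self.student.reset()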

And the following elements to customize the pseudo labeling:

  • pseudo_labeler: computes the pseudo-labels
    • Parameters: teacher, teacher_input
    • Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
  • unsupervised_loss: the loss between model predictions and pseudo labels
    • Parameters: predictions, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: predictions, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: predictions, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: predictions, labels
    • Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given. Unlike in MeanTeacherTrainer, these callables receive predictions that the trainer has already mapped back into the common reference frame, as sketched below.
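Because this trainer computes the predictions and applies the inverse transforms itself (see _train_epoch_unsupervised above), a conforming loss is just a function of predictions, pseudo labels, and the optional filter. For instance (a sketch; the masked MSE and the loss-as-metric choice are assumptions):

    import torch

    def unsupervised_loss(pred, pseudo_labels, label_filter):
        # pred and pseudo_labels have already been mapped back to the common frame.
        loss = torch.nn.functional.mse_loss(pred, pseudo_labels, reduction="none")
        if label_filter is not None:
            loss = loss * label_filter
        return loss.mean()

    def unsupervised_loss_and_metric(pred, pseudo_labels, label_filter):
        # Used during validation; here the metric is simply the loss value itself.
        loss = unsupervised_loss(pred, pseudo_labels, label_filter)
        return loss, loss.detach()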

If the parameter reinit_teacher is set to True, the teacher weights are re-initialized. If it is None, the most appropriate initialization scheme for the training approach is chosen:

  • semi-supervised training -> reinit, because we usually train a model from scratch
  • unsupervised training -> do not reinit, because we usually fine-tune a model

Note: adjust the batch sizes of the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between unsupervised and supervised training samples. For example, an unsupervised batch size of 6 and a supervised batch size of 2 trains on three unsupervised samples per supervised sample.

Arguments:
  • model: The model to be trained.
  • unsupervised_train_loader: The loader for unsupervised training.
  • unsupervised_loss: The loss for unsupervised training.
  • pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
  • augmenter: Invertible augmenter providing separate teacher and student transforms, including corresponding inverse transforms to align predictions in a common frame.
  • supervised_train_loader: The loader for supervised training.
  • supervised_loss: The loss for supervised training.
  • unsupervised_loss_and_metric: The loss and metric for unsupervised training.
  • supervised_loss_and_metric: The loss and metric for supervised training.
  • logger: The logger class; defaults to SelfTrainingTensorboardLogger.
  • momentum: The momentum value for the exponential moving average of the teacher model weights.
  • reinit_teacher: Whether to reinitialize the teacher model before starting the training.
  • sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
  • kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
MeanTeacherTrainerWithInvertibleAugmentations( model: torch.nn.modules.module.Module, unsupervised_train_loader: torch.utils.data.dataloader.DataLoader, unsupervised_loss: Callable, pseudo_labeler: Callable, augmenter: torch_em.transform.invertible_augmentations.InvertibleAugmenter, supervised_train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, unsupervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_loss: Optional[Callable] = None, unsupervised_loss_and_metric: Optional[Callable] = None, supervised_loss_and_metric: Optional[Callable] = None, logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>, momentum: float = 0.999, reinit_teacher: Optional[bool] = None, sampler: Optional[Callable] = None, **kwargs)
467    def __init__(
468        self,
469        model: torch.nn.Module,
470        unsupervised_train_loader: torch.utils.data.DataLoader,
471        unsupervised_loss: Callable,
472        pseudo_labeler: Callable,
473        augmenter: InvertibleAugmenter,
474        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
475        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
476        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
477        supervised_loss: Optional[Callable] = None,
478        unsupervised_loss_and_metric: Optional[Callable] = None,
479        supervised_loss_and_metric: Optional[Callable] = None,
480        logger=SelfTrainingTensorboardLogger,
481        momentum: float = 0.999,
482        reinit_teacher: Optional[bool] = None,
483        sampler: Optional[Callable] = None,
484        **kwargs,
485    ):
486        super().__init__(
487            model=model,
488            unsupervised_train_loader=unsupervised_train_loader,
489            unsupervised_loss=unsupervised_loss,
490            pseudo_labeler=pseudo_labeler,
491            supervised_train_loader=supervised_train_loader,
492            unsupervised_val_loader=unsupervised_val_loader,
493            supervised_val_loader=supervised_val_loader,
494            supervised_loss=supervised_loss,
495            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
496            supervised_loss_and_metric=supervised_loss_and_metric,
497            logger=logger,
498            momentum=momentum,
499            reinit_teacher=reinit_teacher,
500            sampler=sampler,
501            **kwargs,
502        )
503        self.augmenter = augmenter
augmenter
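Putting it together, a hypothetical construction of the trainer. All lower-case names (model, the loaders, pseudo_labeler, the loss callables, augmenter, optimizer) are user-supplied placeholders, and the keyword arguments name and optimizer as well as the fit() call are forwarded to torch_em.trainer.DefaultTrainer and assumed here rather than shown in this excerpt.

    from torch_em.self_training.mean_teacher import MeanTeacherTrainerWithInvertibleAugmentations

    trainer = MeanTeacherTrainerWithInvertibleAugmentations(
        name="mean-teacher-domain-adaptation",   # forwarded to DefaultTrainer
        model=model,
        unsupervised_train_loader=unsupervised_train_loader,
        unsupervised_loss=unsupervised_loss,     # called as loss(pred, pseudo_labels, label_filter)
        unsupervised_loss_and_metric=unsupervised_loss_and_metric,
        pseudo_labeler=pseudo_labeler,
        augmenter=augmenter,
        unsupervised_val_loader=unsupervised_val_loader,
        optimizer=optimizer,                     # forwarded to DefaultTrainer
        momentum=0.999,
        reinit_teacher=False,  # unsupervised fine-tuning: keep the pretrained weights in the teacher
    )
    trainer.fit(iterations=10000)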