torch_em.self_training.mean_teacher

import time
from copy import deepcopy
from typing import Callable, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger


class Dummy(torch.nn.Module):
    # Placeholder passed to DefaultTrainer as "loss" and "metric"; the actual losses are handled by MeanTeacherTrainer.
    init_kwargs = {}

class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the Mean Teacher approach.

    Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780.
    It uses a teacher model, derived from the student model via an exponential moving average (EMA)
    of the weights, to predict pseudo-labels on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and an unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    And the following elements to customize the pseudo-labeling:
    - pseudo_labeler: Computes the pseudo-labels.
        - Parameters: teacher, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
    - unsupervised_loss: The loss between model predictions and pseudo-labels.
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): The supervised loss function.
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): The unsupervised loss function and metric.
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): The supervised loss function and metric.
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
    If it is None, the most appropriate initialization scheme for the training approach is chosen:
    - Semi-supervised training -> reinitialize, because we usually train a model from scratch.
    - Unsupervised training -> do not reinitialize, because we usually fine-tune a model.

    Note: Adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts pseudo-labels for unsupervised training.
        supervised_train_loader: The loader for supervised training.
        unsupervised_val_loader: The loader for unsupervised validation.
        supervised_val_loader: The loader for supervised validation.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        momentum: The momentum value for the exponential moving average of the teacher model weights.
        reinit_teacher: Whether to re-initialize the teacher model before starting the training.
        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        momentum: float = 0.999,
        reinit_teacher: Optional[bool] = None,
        sampler: Optional[Callable] = None,
        **kwargs,
    ):
        self.sampler = sampler
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler
        self.momentum = momentum

        # determine how we initialize the teacher weights (copy or reinitialization)
        if reinit_teacher is None:
            # semisupervised training: reinitialize
            # unsupervised training: copy
            self.reinit_teacher = supervised_train_loader is not None
        else:
            self.reinit_teacher = reinit_teacher

        with torch.no_grad():
            self.teacher = deepcopy(self.model)
            if self.reinit_teacher:
                for layer in self.teacher.children():
                    if hasattr(layer, "reset_parameters"):
                        layer.reset_parameters()
            for param in self.teacher.parameters():
                param.requires_grad = False

        self._kwargs = kwargs

    def _momentum_update(self):
        # If we re-initialize the teacher, we perform much faster updates (low momentum) in the first iterations
        # to avoid a large gap between teacher and student weights, which would lead to inconsistent predictions.
        # If we don't re-initialize, this is not necessary.
        if self.reinit_teacher:
            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
        else:
            current_momentum = self.momentum

        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "teacher_state": self.teacher.state_dict(),
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        save_dict = super().load_checkpoint(checkpoint)
        self.teacher.load_state_dict(save_dict["teacher_state"])
        self.teacher.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.teacher.to(self.device)
        return best_metric

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
            if self.sampler is not None:
                keep_batch = self.sampler(pseudo_labels, label_filter)
                if not keep_batch:
                    continue

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and the unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric
class Dummy(torch.nn.modules.module.Module):

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): Whether this module is in training or evaluation mode.

init_kwargs = {}
class MeanTeacherTrainer(torch_em.trainer.default_trainer.DefaultTrainer):

Trainer for semi-supervised learning and domain adaptation following the Mean Teacher approach.

Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780. It uses a teacher model, derived from the student model via an exponential moving average (EMA) of the weights (see the update formula after the following list), to predict pseudo-labels on unlabeled data. We support two training strategies:

  • Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
  • Training only on the unsupervised data.
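
The teacher weights are updated by an exponential moving average of the student weights (see _momentum_update in the source listing above). With momentum m, the update is

    \theta_{teacher} \leftarrow m \cdot \theta_{teacher} + (1 - m) \cdot \theta_{student}

When the teacher is re-initialized, the momentum is additionally capped at 1 - 1/(iteration + 1), so the teacher follows the student closely in the first iterations before settling to the configured momentum (0.999 by default).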

This class expects the following data loaders:

  • unsupervised_train_loader: Returns two augmentations of the same input (see the loader sketch after this list).
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader.
  • supervised_val_loader (optional): Same as supervised_train_loader.

At least one of unsupervised_val_loader and supervised_val_loader must be given.
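
The unsupervised_train_loader is expected to yield pairs of differently augmented views of the same unlabeled input. The following is a minimal, self-contained sketch of such a loader; the dataset class, the example images, and the noise augmentation are illustrative stand-ins, not part of torch_em (which provides its own dataset and augmentation utilities).

import torch
from torch.utils.data import DataLoader, Dataset


class TwoAugmentationDataset(Dataset):
    """Yields two independently augmented views of the same unlabeled input."""
    def __init__(self, raw_images, augmentation):
        self.raw_images = raw_images      # list of unlabeled input tensors
        self.augmentation = augmentation  # callable that returns a randomly augmented view

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, index):
        x = self.raw_images[index]
        return self.augmentation(x), self.augmentation(x)


# Random Gaussian noise as a stand-in augmentation for illustration.
def noise_augmentation(x):
    return x + 0.1 * torch.randn_like(x)


example_images = [torch.rand(1, 64, 64) for _ in range(8)]
unsupervised_train_loader = DataLoader(
    TwoAugmentationDataset(example_images, noise_augmentation), batch_size=2, shuffle=True
)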

And the following elements to customize the pseudo labeling:

  • pseudo_labeler: Computes the pseudo-labels (a minimal sketch follows after this list).
    • Parameters: teacher, teacher_input
    • Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
  • unsupervised_loss: the loss between model predictions and pseudo labels
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: model, input, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: model, input, labels
    • Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
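
A minimal sketch of a pseudo_labeler and the corresponding unsupervised losses following the signatures documented above. It assumes a single-channel binary segmentation model whose raw outputs are logits; the 0.5 decision threshold and the 0.4 confidence margin are arbitrary illustrative choices, not values prescribed by torch_em.

import torch
import torch.nn.functional as F


def pseudo_labeler(teacher, teacher_input):
    # Predict with the teacher and turn the predictions into hard pseudo-labels.
    predictions = torch.sigmoid(teacher(teacher_input))
    pseudo_labels = (predictions > 0.5).float()
    # Only keep confident pixels; the filter is used as a mask in the loss.
    label_filter = (torch.abs(predictions - 0.5) > 0.4).float()
    return pseudo_labels, label_filter


def unsupervised_loss(model, model_input, pseudo_labels, label_filter):
    # Masked binary cross-entropy between student predictions and pseudo-labels.
    predictions = torch.sigmoid(model(model_input))
    loss = F.binary_cross_entropy(predictions, pseudo_labels, reduction="none")
    if label_filter is not None:
        loss = loss * label_filter
    return loss.mean()


def unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    # For simplicity the loss doubles as the validation metric here.
    loss = unsupervised_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss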

If the parameter reinit_teacher is set to True, the teacher weights are re-initialized. If it is None, the most appropriate initialization scheme for the training approach is chosen:

  • semi-supervised training -> reinit, because we usually train a model from scratch
  • unsupervised training -> do not reinit, because we usually fine-tune a model

Note: Adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between supervised and unsupervised training samples. For example, an unsupervised batch size of 4 and a supervised batch size of 1 corresponds to four unsupervised samples per supervised sample in each iteration of semi-supervised training.

Arguments:
  • model: The model to be trained.
  • unsupervised_train_loader: The loader for unsupervised training.
  • unsupervised_loss: The loss for unsupervised training.
  • pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
  • supervised_train_loader: The loader for supervised training.
  • unsupervised_val_loader: The loader for unsupervised validation.
  • supervised_val_loader: The loader for supervised validation.
  • supervised_loss: The loss for supervised training.
  • unsupervised_loss_and_metric: The loss and metric for unsupervised training.
  • supervised_loss_and_metric: The loss and metric for supervised training.
  • logger: The logger.
  • momentum: The momentum value for the exponential moving average of the teacher model weights.
  • reinit_teacher: Whether to reinit the teacher model before starting the training.
  • sampler: A sampler for rejecting pseudo-labels according to a defined criterion (see the sketch after this list).
  • kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
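
The optional sampler is called as sampler(pseudo_labels, label_filter) and the batch is skipped when it returns False (see _train_epoch_unsupervised in the source listing above). A minimal sketch, with an arbitrary 5% threshold chosen purely for illustration:

def sampler(pseudo_labels, label_filter):
    # Reject batches in which too few pixels pass the confidence filter.
    if label_filter is None:
        return True
    return (label_filter > 0).float().mean().item() > 0.05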
MeanTeacherTrainer( model: torch.nn.modules.module.Module, unsupervised_train_loader: torch.utils.data.dataloader.DataLoader, unsupervised_loss: Callable, pseudo_labeler: Callable, supervised_train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, unsupervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_loss: Optional[Callable] = None, unsupervised_loss_and_metric: Optional[Callable] = None, supervised_loss_and_metric: Optional[Callable] = None, logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>, momentum: float = 0.999, reinit_teacher: Optional[bool] = None, sampler: Optional[Callable] = None, **kwargs)
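
Putting it together, a hedged usage sketch that reuses the loader and the callables sketched above for purely unsupervised training (domain adaptation). The remaining keyword arguments (name, optimizer, device) are assumed to be forwarded to torch_em.trainer.DefaultTrainer, and trainer.fit(iterations=...) is assumed to behave as in the other torch_em trainers; check your torch_em version for the exact DefaultTrainer signature.

import torch
from torch_em.model import UNet2d
from torch_em.self_training import MeanTeacherTrainer

model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

trainer = MeanTeacherTrainer(
    name="mean-teacher-example",
    model=model,
    optimizer=optimizer,
    unsupervised_train_loader=unsupervised_train_loader,  # pairs of augmentations, see the loader sketch above
    unsupervised_val_loader=unsupervised_train_loader,    # for illustration only; use a separate validation loader in practice
    unsupervised_loss=unsupervised_loss,
    unsupervised_loss_and_metric=unsupervised_loss_and_metric,
    pseudo_labeler=pseudo_labeler,
    sampler=sampler,
    momentum=0.999,
    reinit_teacher=False,  # fine-tuning setting: start the teacher as a copy of the student weights
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
trainer.fit(iterations=10000)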
sampler
unsupervised_train_loader
supervised_train_loader
supervised_val_loader
unsupervised_val_loader
supervised_loss_and_metric
unsupervised_loss_and_metric
unsupervised_loss
supervised_loss
pseudo_labeler
momentum