torch_em.self_training.mean_teacher

  1import time
  2from copy import deepcopy
  3
  4import torch
  5import torch_em
  6from torch_em.util import get_constructor_arguments
  7
  8from .logger import SelfTrainingTensorboardLogger
  9
 10
 11class Dummy(torch.nn.Module):
 12    init_kwargs = {}
 13
 14
 15class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
 16    """This trainer implements self-training for semi-supervised learning and domain adaptation, following the
 17    'MeanTeacher' approach of Tarvainen & Valpola (https://arxiv.org/abs/1703.01780). This approach uses a teacher
 18    model derived from the student model via an exponential moving average (EMA) of the weights to predict
 19    pseudo-labels on unlabeled data. We support two training strategies: joint training on labeled and unlabeled
 20    data (with a supervised and an unsupervised loss function), and training only on the unlabeled data.
 21
 22    This class expects the following data loaders:
 23    - unsupervised_train_loader: Returns two augmentations of the same input.
 24    - supervised_train_loader (optional): Returns input and labels.
 25    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
 26    - supervised_val_loader (optional): Same as supervised_train_loader
 27    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 28
 29    And the following elements to customize the pseudo labeling:
 30    - pseudo_labeler: to compute the pseudo-labels
 31        - Parameters: teacher, teacher_input
 32        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
 33    - unsupervised_loss: the loss between model predictions and pseudo labels
 34        - Parameters: model, model_input, pseudo_labels, label_filter
 35        - Returns: loss
 36    - supervised_loss (optional): the supervised loss function
 37        - Parameters: model, input, labels
 38        - Returns: loss
 39    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 40        - Parameters: model, model_input, pseudo_labels, label_filter
 41        - Returns: loss, metric
 42    - supervised_loss_and_metric (optional): the supervised loss function and metric
 43        - Parameters: model, input, labels
 44        - Returns: loss, metric
 45    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 46
 47    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
 48    If it is None, the most appropriate initialization scheme for the training approach is chosen:
 49    - semi-supervised training -> reinit, because we usually train a model from scratch
 50    - unsupervised training -> do not reinit, because we usually fine-tune a model
 51
 52    Note: adjust the batch sizes of the 'unsupervised_train_loader' and the 'supervised_train_loader'
 53    to set the ratio between unsupervised and supervised training samples per iteration.
 54
 55    Parameters:
 56        model [nn.Module] -
 57        unsupervised_train_loader [torch.DataLoader] -
 58        unsupervised_loss [callable] -
 59        pseudo_labeler [callable] -
 60        supervised_train_loader [torch.DataLoader] - (default: None)
 61        supervised_loss [callable] - (default: None)
 62        unsupervised_loss_and_metric [callable] - (default: None)
 63        supervised_loss_and_metric [callable] - (default: None)
 64        logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger)
 65        momentum [float] - (default: 0.999)
 66        reinit_teacher [bool] - (default: None)
 67        **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer
 68    """
 69
 70    def __init__(
 71        self,
 72        model,
 73        unsupervised_train_loader,
 74        unsupervised_loss,
 75        pseudo_labeler,
 76        supervised_train_loader=None,
 77        unsupervised_val_loader=None,
 78        supervised_val_loader=None,
 79        supervised_loss=None,
 80        unsupervised_loss_and_metric=None,
 81        supervised_loss_and_metric=None,
 82        logger=SelfTrainingTensorboardLogger,
 83        momentum=0.999,
 84        reinit_teacher=None,
 85        **kwargs
 86    ):
 87        # Do we have supervised data or not?
 88        if supervised_train_loader is None:
 89            # No. -> We use the unsupervised training logic.
 90            train_loader = unsupervised_train_loader
 91            self._train_epoch_impl = self._train_epoch_unsupervised
 92        else:
 93            # Yes. -> We use the semi-supervised training logic.
 94            assert supervised_loss is not None
 95            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
 96                else unsupervised_train_loader
 97            self._train_epoch_impl = self._train_epoch_semisupervised
 98
 99        self.unsupervised_train_loader = unsupervised_train_loader
100        self.supervised_train_loader = supervised_train_loader
101
102        # Check that we have at least one of supervised / unsupervised val loader.
103        assert sum((
104            supervised_val_loader is not None,
105            unsupervised_val_loader is not None,
106        )) > 0
107        self.supervised_val_loader = supervised_val_loader
108        self.unsupervised_val_loader = unsupervised_val_loader
109
110        if self.unsupervised_val_loader is None:
111            val_loader = self.supervised_val_loader
112        else:
113            val_loader = self.unsupervised_val_loader
114
115        # Check that we have at least one of supervised / unsupervised loss and metric.
116        assert sum((
117            supervised_loss_and_metric is not None,
118            unsupervised_loss_and_metric is not None,
119        )) > 0
120        self.supervised_loss_and_metric = supervised_loss_and_metric
121        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
122
123        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
124        kwargs.pop("train_loader", None)
125        kwargs.pop("val_loader", None)
126        kwargs.pop("metric", None)
127        kwargs.pop("loss", None)
128        super().__init__(
129            model=model, train_loader=train_loader, val_loader=val_loader,
130            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
131        )
132
133        self.unsupervised_loss = unsupervised_loss
134        self.supervised_loss = supervised_loss
135
136        self.pseudo_labeler = pseudo_labeler
137        self.momentum = momentum
138
139        # determine how we initialize the teacher weights (copy or reinitialization)
140        if reinit_teacher is None:
141            # semisupervised training: reinitialize
142            # unsupervised training: copy
143            self.reinit_teacher = supervised_train_loader is not None
144        else:
145            self.reinit_teacher = reinit_teacher
146
147        with torch.no_grad():
148            self.teacher = deepcopy(self.model)
149            if self.reinit_teacher:
150                for layer in self.teacher.children():
151                    if hasattr(layer, "reset_parameters"):
152                        layer.reset_parameters()
153            for param in self.teacher.parameters():
154                param.requires_grad = False
155
156        self._kwargs = kwargs
157
158    def _momentum_update(self):
159        # if we reinit the teacher we perform much faster updates (low momentum) in the first iterations
160        # to avoid a large gap between teacher and student weights, leading to inconsistent predictions
161        # if we don't reinit this is not necessary
162        if self.reinit_teacher:
163            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
164        else:
165            current_momentum = self.momentum
166
167        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
168            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)
169
170    #
171    # functionality for saving checkpoints and initialization
172    #
173
174    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
175        train_loader_kwargs = get_constructor_arguments(self.train_loader)
176        val_loader_kwargs = get_constructor_arguments(self.val_loader)
177        extra_state = {
178            "teacher_state": self.teacher.state_dict(),
179            "init": {
180                "train_loader_kwargs": train_loader_kwargs,
181                "train_dataset": self.train_loader.dataset,
182                "val_loader_kwargs": val_loader_kwargs,
183                "val_dataset": self.val_loader.dataset,
184                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
185                "loss_kwargs": {},
186                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
187                "metric_kwargs": {},
188            },
189        }
190        extra_state.update(**extra_save_dict)
191        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
192
193    def load_checkpoint(self, checkpoint="best"):
194        save_dict = super().load_checkpoint(checkpoint)
195        self.teacher.load_state_dict(save_dict["teacher_state"])
196        self.teacher.to(self.device)
197        return save_dict
198
199    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
200        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
201        self.teacher.to(self.device)
202        return best_metric
203
204    #
205    # training and validation functionality
206    #
207
208    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
209        self.model.train()
210
211        n_iter = 0
212        t_per_iter = time.time()
213
214        # Sample from the unsupervised loader.
215        for xu1, xu2 in self.unsupervised_train_loader:
216            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)
217
218            teacher_input, model_input = xu1, xu2
219
220            with forward_context(), torch.no_grad():
221                # Compute the pseudo labels.
222                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
223
224            self.optimizer.zero_grad()
225            # Perform unsupervised training
226            with forward_context():
227                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
228            backprop(loss)
229
230            if self.logger is not None:
231                with torch.no_grad(), forward_context():
232                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
233                self.logger.log_train_unsupervised(
234                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
235                )
236                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
237                self.logger.log_lr(self._iteration, lr)
238
239            with torch.no_grad():
240                self._momentum_update()
241
242            self._iteration += 1
243            n_iter += 1
244            if self._iteration >= self.max_iteration:
245                break
246            progress.update(1)
247
248        t_per_iter = (time.time() - t_per_iter) / n_iter
249        return t_per_iter
250
251    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
252        self.model.train()
253
254        n_iter = 0
255        t_per_iter = time.time()
256
257        # Sample from both the supervised and unsupervised loader.
258        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
259            xs, ys = xs.to(self.device), ys.to(self.device)
260            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)
261
262            # Perform supervised training.
263            self.optimizer.zero_grad()
264            with forward_context():
265                # We pass the model, the input and the labels to the supervised loss function,
266                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
267                supervised_loss = self.supervised_loss(self.model, xs, ys)
268
269            teacher_input, model_input = xu1, xu2
270
271            with forward_context(), torch.no_grad():
272                # Compute the pseudo labels.
273                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
274
275            # Perform unsupervised training
276            with forward_context():
277                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
278
279            loss = (supervised_loss + unsupervised_loss) / 2
280            backprop(loss)
281
282            if self.logger is not None:
283                with torch.no_grad(), forward_context():
284                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
285                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None
286
287                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
288                self.logger.log_train_unsupervised(
289                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
290                )
291
292                self.logger.log_combined_loss(self._iteration, loss)
293                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
294                self.logger.log_lr(self._iteration, lr)
295
296            with torch.no_grad():
297                self._momentum_update()
298
299            self._iteration += 1
300            n_iter += 1
301            if self._iteration >= self.max_iteration:
302                break
303            progress.update(1)
304
305        t_per_iter = (time.time() - t_per_iter) / n_iter
306        return t_per_iter
307
308    def _validate_supervised(self, forward_context):
309        metric_val = 0.0
310        loss_val = 0.0
311
312        for x, y in self.supervised_val_loader:
313            x, y = x.to(self.device), y.to(self.device)
314            with forward_context():
315                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
316            loss_val += loss.item()
317            metric_val += metric.item()
318
319        metric_val /= len(self.supervised_val_loader)
320        loss_val /= len(self.supervised_val_loader)
321
322        if self.logger is not None:
323            with forward_context():
324                pred = self.model(x)
325            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)
326
327        return metric_val
328
329    def _validate_unsupervised(self, forward_context):
330        metric_val = 0.0
331        loss_val = 0.0
332
333        for x1, x2 in self.unsupervised_val_loader:
334            x1, x2 = x1.to(self.device), x2.to(self.device)
335            teacher_input, model_input = x1, x2
336            with forward_context():
337                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
338                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
339            loss_val += loss.item()
340            metric_val += metric.item()
341
342        metric_val /= len(self.unsupervised_val_loader)
343        loss_val /= len(self.unsupervised_val_loader)
344
345        if self.logger is not None:
346            with forward_context():
347                pred = self.model(model_input)
348            self.logger.log_validation_unsupervised(
349                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
350            )
351
352        return metric_val
353
354    def _validate_impl(self, forward_context):
355        self.model.eval()
356
357        with torch.no_grad():
358
359            if self.supervised_val_loader is None:
360                supervised_metric = None
361            else:
362                supervised_metric = self._validate_supervised(forward_context)
363
364            if self.unsupervised_val_loader is None:
365                unsupervised_metric = None
366            else:
367                unsupervised_metric = self._validate_unsupervised(forward_context)
368
369        if unsupervised_metric is None:
370            metric = supervised_metric
371        elif supervised_metric is None:
372            metric = unsupervised_metric
373        else:
374            metric = (supervised_metric + unsupervised_metric) / 2
375
376        return metric
class Dummy(torch.nn.modules.module.Module):
A parameter-free placeholder torch.nn.Module with empty init_kwargs. MeanTeacherTrainer passes Dummy() as the loss and metric to DefaultTrainer, since the actual losses and metrics are handled by the trainer itself.

class MeanTeacherTrainer(torch_em.trainer.default_trainer.DefaultTrainer):

This trainer implements self-training for semi-supervised learning and domain adaptation, following the 'MeanTeacher' approach of Tarvainen & Valpola (https://arxiv.org/abs/1703.01780). This approach uses a teacher model, derived from the student model via an exponential moving average (EMA) of the weights, to predict pseudo-labels on unlabeled data. We support two training strategies: joint training on labeled and unlabeled data (with a supervised and an unsupervised loss function), and training only on the unlabeled data.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns two augmentations of the same input.
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader
  • supervised_val_loader (optional): Same as supervised_train_loader

At least one of unsupervised_val_loader and supervised_val_loader must be given. A minimal sketch of a loader that returns two augmentations of the same input is shown below.
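
The following is a hypothetical sketch only: the dataset class, the variable names and the noise augmentation are made up, and in practice such a loader would usually be built with torch_em's data loading utilities.

    import torch
    from torch.utils.data import Dataset, DataLoader

    class TwoAugmentationDataset(Dataset):
        """Hypothetical dataset that returns two independently augmented views of the same raw patch."""
        def __init__(self, patches, augmentation):
            self.patches = patches            # e.g. a list of image tensors or arrays
            self.augmentation = augmentation  # callable: tensor -> tensor

        def __len__(self):
            return len(self.patches)

        def __getitem__(self, idx):
            x = torch.as_tensor(self.patches[idx], dtype=torch.float32)
            # Two different augmentations of the same input, as expected by the trainer.
            return self.augmentation(x), self.augmentation(x)

    # Toy data and a toy augmentation (additive gaussian noise).
    patches = [torch.rand(1, 64, 64) for _ in range(8)]
    augment = lambda x: x + 0.1 * torch.randn_like(x)
    unsupervised_train_loader = DataLoader(TwoAugmentationDataset(patches, augment), batch_size=2, shuffle=True)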

And the following elements to customize the pseudo labeling:

  • pseudo_labeler: to compute the pseudo-labels
    • Parameters: teacher, teacher_input
    • Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
  • unsupervised_loss: the loss between model predictions and pseudo labels
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: model, input, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: model, input, labels
    • Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given. Hypothetical implementations of a pseudo_labeler and an unsupervised_loss with these call signatures are sketched below.
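
The following functions are illustrative sketches, not part of the library: the names, the sigmoid activation and the confidence thresholding are assumptions, chosen only to demonstrate the documented call signatures.

    import torch

    def confidence_pseudo_labeler(teacher, teacher_input, threshold=0.9):
        # Hypothetical pseudo labeler: the teacher prediction is the pseudo-label and a
        # confidence mask (predictions close to 0 or 1) serves as the label filter.
        pseudo_labels = torch.sigmoid(teacher(teacher_input))
        label_filter = ((pseudo_labels >= threshold) | (pseudo_labels <= 1.0 - threshold)).float()
        return pseudo_labels, label_filter

    def masked_mse_loss(model, model_input, pseudo_labels, label_filter):
        # Hypothetical unsupervised loss: MSE between the student prediction and the
        # pseudo-labels, restricted to the pixels selected by the label filter.
        prediction = torch.sigmoid(model(model_input))
        squared_error = (prediction - pseudo_labels) ** 2
        if label_filter is not None:
            squared_error = squared_error * label_filter
        return squared_error.mean()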

If the parameter reinit_teacher is set to True, the teacher weights are re-initialized. If it is None, the most appropriate initialization scheme for the training approach is chosen:

  • semi-supervised training -> reinit, because we usually train a model from scratch
  • unsupervised training -> do not reinit, because we usually fine-tune a model (see the sketch of the EMA update below)
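
In both cases the teacher is never updated by gradients. Instead, after every training iteration the trainer applies an exponential moving average of the student weights (the _momentum_update method). The toy sketch below spells out that update rule for a single tensor; the warm-up schedule inside the loop is only used when reinit_teacher is set, otherwise the fixed momentum (default 0.999) is applied throughout.

    import torch

    momentum = 0.999
    student_param = torch.randn(10)
    teacher_param = student_param.clone()  # the teacher starts as a copy of the student

    for iteration in range(5):
        student_param = student_param + 0.01 * torch.randn(10)  # stand-in for a gradient step
        m = min(1.0 - 1.0 / (iteration + 1), momentum)           # warm-up when reinit_teacher is set
        teacher_param = m * teacher_param + (1.0 - m) * student_param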

Note: adjust the batch sizes of the 'unsupervised_train_loader' and the 'supervised_train_loader' to set the ratio between unsupervised and supervised training samples per iteration, as illustrated below.
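
With the semi-supervised strategy one batch is drawn from each loader per iteration, so the per-iteration sample ratio is determined by the two batch sizes. A hypothetical 4:1 unlabeled-to-labeled ratio, reusing the toy dataset sketched above (labeled_dataset is assumed to exist and to yield (input, labels) pairs):

    from torch.utils.data import DataLoader

    supervised_train_loader = DataLoader(labeled_dataset, batch_size=1, shuffle=True)
    unsupervised_train_loader = DataLoader(TwoAugmentationDataset(patches, augment), batch_size=4, shuffle=True)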

Arguments:
  • model [nn.Module] - the student model to train
  • unsupervised_train_loader [torch.DataLoader] - loader that returns two augmentations of the same input
  • unsupervised_loss [callable] - loss between the model predictions and the pseudo-labels
  • pseudo_labeler [callable] - computes the pseudo-labels (and an optional label filter) from the teacher
  • supervised_train_loader [torch.DataLoader] - loader that returns input and labels (default: None)
  • unsupervised_val_loader [torch.DataLoader] - validation loader in the format of unsupervised_train_loader (default: None)
  • supervised_val_loader [torch.DataLoader] - validation loader in the format of supervised_train_loader (default: None)
  • supervised_loss [callable] - the supervised loss function (default: None)
  • unsupervised_loss_and_metric [callable] - unsupervised loss and metric used for validation (default: None)
  • supervised_loss_and_metric [callable] - supervised loss and metric used for validation (default: None)
  • logger [TorchEmLogger] - the logger class (default: SelfTrainingTensorboardLogger)
  • momentum [float] - momentum of the EMA update for the teacher weights (default: 0.999)
  • reinit_teacher [bool] - whether to re-initialize the teacher weights; if None, the scheme is chosen based on the training strategy (default: None)
  • **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer, e.g. name, optimizer and device
MeanTeacherTrainer(model, unsupervised_train_loader, unsupervised_loss, pseudo_labeler, supervised_train_loader=None, unsupervised_val_loader=None, supervised_val_loader=None, supervised_loss=None, unsupervised_loss_and_metric=None, supervised_loss_and_metric=None, logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>, momentum=0.999, reinit_teacher=None, **kwargs)
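
Putting the pieces together, a minimal sketch of unsupervised-only training (e.g. for domain adaptation). It assumes the toy loader and helper callables sketched above plus a validation loader built the same way; the UNet configuration and the DefaultTrainer keyword arguments (name, optimizer, device, mixed_precision) are illustrative choices, not prescribed by the trainer.

    import torch
    from torch.utils.data import DataLoader
    from torch_em.model import UNet2d
    from torch_em.self_training.mean_teacher import MeanTeacherTrainer

    model = UNet2d(in_channels=1, out_channels=1)
    unsupervised_val_loader = DataLoader(TwoAugmentationDataset(patches, augment), batch_size=2)

    trainer = MeanTeacherTrainer(
        name="mean-teacher-example",                          # checkpoint name, forwarded to DefaultTrainer
        model=model,
        unsupervised_train_loader=unsupervised_train_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        unsupervised_loss=masked_mse_loss,
        # Here the loss doubles as the validation metric; a dedicated metric is preferable in practice.
        unsupervised_loss_and_metric=lambda m, x, y, f: (masked_mse_loss(m, x, y, f),) * 2,
        pseudo_labeler=confidence_pseudo_labeler,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        mixed_precision=False,
        reinit_teacher=False,                                 # keep the copied weights, fine-tuning style
    )
    trainer.fit(iterations=10000)
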
Instance attributes: unsupervised_train_loader, supervised_train_loader, supervised_val_loader, unsupervised_val_loader, supervised_loss_and_metric, unsupervised_loss_and_metric, unsupervised_loss, supervised_loss, pseudo_labeler, momentum
def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
Saves a checkpoint via DefaultTrainer.save_checkpoint, additionally storing the teacher weights under 'teacher_state' together with the constructor information (loader kwargs, datasets, Dummy loss/metric classes) needed to re-instantiate the trainer from the checkpoint.
def load_checkpoint(self, checkpoint='best'):
Loads a checkpoint via DefaultTrainer.load_checkpoint, restores the teacher weights from 'teacher_state', moves the teacher to the trainer device and returns the loaded save_dict.
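
A hedged usage sketch, assuming a trainer instance like the one constructed above: restoring the best checkpoint also brings back the EMA teacher, which can then be used for inference or export.

    save_dict = trainer.load_checkpoint("best")  # returns the full checkpoint dictionary
    teacher = trainer.teacher                    # frozen EMA model restored from "teacher_state"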