torch_em.self_training.fix_match

import time
from typing import Callable, List, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from .mean_teacher import Dummy
from ..transform.invertible_augmentations import InvertibleAugmenter


class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
 14    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.
 15
 16    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685).
 17    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
 18    on unlabeled data. We support two training strategies:
 19    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
 20    - Taining only on the unsupervised data.
 21
 22    This class expects the following data loaders:
 23    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
 24    - supervised_train_loader (optional): Returns input and labels.
 25    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
 26    - supervised_val_loader (optional): Same as supervised_train_loader
 27    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 28
 29    The following arguments can be used to customize the pseudo labeling:
 30    - pseudo_labeler: to compute the psuedo-labels
 31        - Parameters: model, teacher_input
 32        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
 33    - unsupervised_loss: the loss between model predictions and pseudo labels
 34        - Parameters: model, model_input, pseudo_labels, label_filter
 35        - Returns: loss
 36    - supervised_loss (optional): the supervised loss function
 37        - Parameters: model, input, labels
 38        - Returns: loss
 39    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 40        - Parameters: model, model_input, pseudo_labels, label_filter
 41        - Returns: loss, metric
 42    - supervised_loss_and_metric (optional): the supervised loss function and metric
 43        - Parameters: model, input, labels
 44        - Returns: loss, metric
 45    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 46
 47    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
 48    for setting the ratio between supervised and unsupervised training samples.
 49
 50    Args:
 51        model: The model to be trained.
 52        unsupervised_train_loader: The loader for unsupervised training.
 53        unsupervised_loss: The loss for unsupervised training.
 54        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
 55        supervised_train_loader: The loader for supervised training.
 56        supervised_loss: The loss for supervised training.
 57        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
 58        supervised_loss_and_metric: The loss and metrhic for supervised training.
 59        logger: The logger.
 60        source_distribution: The ratio of labels in the source label distribution.
 61            If given, the predicted distribution of the trained model will be regularized to
 62            match this source label distribution.
 63        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
 64    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution: Optional[List[float]] = None,
        **kwargs,
    ):
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = (
                supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)
                else unsupervised_train_loader
            )
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may still be present in kwargs when
        # deserializing from a checkpoint; remove them, since they are set explicitly here.
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler

        if source_distribution is None:
            self.source_distribution = None
        else:
            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)

        self._kwargs = kwargs

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    # Distribution alignment:
    # Encourages the distribution of the model's predicted pseudo labels to match the marginal
    # label distribution of the source domain (key idea: maximize the mutual information).
    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
        """@private
        """
        if self.source_distribution is not None:
            # Note: this assumes that both classes occur in the binarized pseudo labels,
            # so that `target_distribution` has two entries.
            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
            target_distribution = target_distribution / target_distribution.sum()
            distribution_ratio = self.source_distribution / target_distribution
            pseudo_labels = torch.where(
                pseudo_labels < label_threshold,
                pseudo_labels * distribution_ratio[0],
                pseudo_labels * distribution_ratio[1]
            ).clip(0, 1)

        return pseudo_labels
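
    # A small worked example of the alignment above (illustrative values):
    # with source_distribution = [0.9, 0.1] and pseudo labels [0.2, 0.3, 0.8, 0.9],
    # thresholding at 0.5 yields a target distribution of [0.5, 0.5], so
    # distribution_ratio = [1.8, 0.2]; background scores are scaled up by 1.8,
    # foreground scores are scaled down by 0.2, and the result is clipped to [0, 1].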

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            self.optimizer.zero_grad()
            # Perform unsupervised training
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            # Perform unsupervised training
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric
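
# ---------------------------------------------------------------------------
# Example (illustrative sketch, not part of torch_em's API): a minimal pseudo
# labeler and unsupervised loss that follow the contracts documented in the
# FixMatchTrainer docstring. The names below are hypothetical; see the other
# modules in torch_em.self_training for the implementations used in practice.

class ThresholdPseudoLabeler:
    """Hypothetical pseudo labeler for binary segmentation.

    Follows the documented contract: __call__(model, teacher_input) returns
    (pseudo_labels, label_filter). It also exposes `confidence_threshold`,
    which the trainer reads for logging.
    """
    def __init__(self, confidence_threshold=0.9):
        self.confidence_threshold = confidence_threshold

    def __call__(self, model, teacher_input):
        predictions = torch.sigmoid(model(teacher_input))
        pseudo_labels = (predictions >= 0.5).float()
        # Mask that keeps only confidently classified pixels.
        label_filter = (
            (predictions >= self.confidence_threshold)
            | (predictions <= 1.0 - self.confidence_threshold)
        ).float()
        return pseudo_labels, label_filter


def masked_unsupervised_loss(model, model_input, pseudo_labels, label_filter):
    """Hypothetical unsupervised loss following the documented contract
    (model, model_input, pseudo_labels, label_filter) -> loss."""
    predictions = torch.sigmoid(model(model_input))
    if label_filter is None:
        return torch.nn.functional.mse_loss(predictions, pseudo_labels)
    # Average the squared error only over pixels that pass the filter.
    loss = (label_filter * (predictions - pseudo_labels) ** 2).sum()
    return loss / label_filter.sum().clamp(min=1.0)


def masked_unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    """Hypothetical loss-and-metric; here the metric is simply the loss value."""
    loss = masked_unsupervised_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss

# A corresponding trainer could then be set up via (assuming `model` and the
# loaders are defined):
#
#     trainer = FixMatchTrainer(
#         name="fixmatch-example", model=model,
#         unsupervised_train_loader=unsupervised_train_loader,
#         unsupervised_val_loader=unsupervised_val_loader,
#         unsupervised_loss=masked_unsupervised_loss,
#         unsupervised_loss_and_metric=masked_unsupervised_loss_and_metric,
#         pseudo_labeler=ThresholdPseudoLabeler(confidence_threshold=0.9),
#     )
#     trainer.fit(iterations=10_000)
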


class FixMatchTrainerWithInvertibleAugmentations(FixMatchTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach,
    using invertible augmentations to generate the teacher and student inputs.

    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685.
    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
    on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and an unsupervised loss function).
    - Training only on the unlabeled data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns the raw input; the weak and strong augmentations
      are generated by the augmenter.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The augmenter defines separate invertible transforms for teacher and student inputs.
    Teacher and student views are generated independently, and the corresponding inverse
    transforms map predictions and pseudo-labels back into a shared reference frame.

    The following arguments can be used to customize the pseudo labeling. Note that, unlike
    in `FixMatchTrainer`, the loss functions here receive the model predictions (already
    mapped back into the reference frame) instead of the model and its input:
    - pseudo_labeler: computes the pseudo-labels.
        - Parameters: model, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels.
        - Parameters: model_predictions, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function.
        - Parameters: predictions, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric.
        - Parameters: model_predictions, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric.
        - Parameters: predictions, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        augmenter: Invertible augmenter providing separate teacher and student transforms,
            including corresponding inverse transforms to align predictions in a common frame.
        supervised_train_loader: The loader for supervised training.
        unsupervised_val_loader: The loader for unsupervised validation.
        supervised_val_loader: The loader for supervised validation.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        source_distribution: The ratio of labels in the source label distribution.
            If given, the predicted distribution of the trained model will be regularized to
            match this source label distribution.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        augmenter: InvertibleAugmenter,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(
            model=model,
            unsupervised_train_loader=unsupervised_train_loader,
            unsupervised_loss=unsupervised_loss,
            pseudo_labeler=pseudo_labeler,
            supervised_train_loader=supervised_train_loader,
            unsupervised_val_loader=unsupervised_val_loader,
            supervised_val_loader=supervised_val_loader,
            supervised_loss=supervised_loss,
            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
            supervised_loss_and_metric=supervised_loss_and_metric,
            logger=logger,
            source_distribution=source_distribution,
            **kwargs,
        )

        self.augmenter = augmenter

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu in self.unsupervised_train_loader:
            self.augmenter.reset_all()
            xu = xu.to(self.device, non_blocking=True)

            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels, then invert into the reference frame.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
            pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
            label_filter_inv = (
                self.augmenter.teacher.reverse_transform(label_filter)
                if label_filter is not None else None
            )

            self.optimizer.zero_grad()
            # Perform unsupervised training
            with forward_context():
                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss = self.unsupervised_loss(pred_inv, pseudo_labels_inv, label_filter_inv)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = pred if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu, xu, pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, pred,
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), xu in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            self.augmenter.reset_all()
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu = xu.to(self.device, non_blocking=True)

            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # Unlike in FixMatchTrainer, the supervised loss here receives the
                # predictions and the labels instead of the model and its input.
                supervised_pred = self.model(xs)
                supervised_loss = self.supervised_loss(supervised_pred, ys)

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels, then invert into the reference frame.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
            pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
            label_filter_inv = (
                self.augmenter.teacher.reverse_transform(label_filter)
                if label_filter is not None else None
            )

            # Perform unsupervised training
            with forward_context():
                unsup_pred = self.model(model_input)
                unsup_pred_inv = self.augmenter.student.reverse_transform(unsup_pred)
                unsupervised_loss = self.unsupervised_loss(unsup_pred_inv, pseudo_labels_inv, label_filter_inv)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = unsup_pred if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = supervised_pred if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu, xu, unsup_pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, unsup_pred,
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                pred = self.model(x)
                loss, metric = self.supervised_loss_and_metric(pred, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x in self.unsupervised_val_loader:
            self.augmenter.reset_all()
            x = x.to(self.device, non_blocking=True)

            x1, x2 = self.augmenter.teacher.transform(x), self.augmenter.student.transform(x)
            teacher_input, model_input = x1, x2

            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss, metric = self.unsupervised_loss_and_metric(pred_inv, pseudo_labels_inv, label_filter_inv)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv
            )
            self.logger.log_validation_augmentations(
                self._iteration, x1, x2, pseudo_labels, pred,
            )

        self.pseudo_labeler.step(metric_val, self._epoch)

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric
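

# ---------------------------------------------------------------------------
# Example (illustrative sketch, not part of torch_em's API): a minimal invertible
# augmenter exposing the interface used by the trainer above: `reset_all()`, plus
# `teacher` and `student` members with `transform` and `reverse_transform`. The
# actual implementation is `InvertibleAugmenter` from
# `torch_em.transform.invertible_augmentations`; the names below are hypothetical.
# Note that a pseudo labeler used with this trainer additionally needs a
# `step(metric, epoch)` method, which is called after unsupervised validation.

class _RandomFlip:
    """Flips a batch along the last axis with probability 0.5. The flip decision
    is drawn in `reset` and reused by `reverse_transform`, so that predictions
    can be mapped back into the un-augmented reference frame."""
    def __init__(self):
        self.apply_flip = False

    def reset(self):
        self.apply_flip = torch.rand(1).item() < 0.5

    def transform(self, x):
        return torch.flip(x, dims=(-1,)) if self.apply_flip else x

    def reverse_transform(self, x):
        # A flip is its own inverse.
        return torch.flip(x, dims=(-1,)) if self.apply_flip else x


class _FlipAugmenter:
    """Pairs independently sampled teacher and student transforms."""
    def __init__(self):
        self.teacher = _RandomFlip()
        self.student = _RandomFlip()

    def reset_all(self):
        # Draw fresh augmentation parameters for both views; the trainer calls
        # this once per batch.
        self.teacher.reset()
        self.student.reset()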
class FixMatchTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
 14class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
 15    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.
 16
 17    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685).
 18    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
 19    on unlabeled data. We support two training strategies:
 20    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
 21    - Taining only on the unsupervised data.
 22
 23    This class expects the following data loaders:
 24    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
 25    - supervised_train_loader (optional): Returns input and labels.
 26    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
 27    - supervised_val_loader (optional): Same as supervised_train_loader
 28    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 29
 30    The following arguments can be used to customize the pseudo labeling:
 31    - pseudo_labeler: to compute the psuedo-labels
 32        - Parameters: model, teacher_input
 33        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
 34    - unsupervised_loss: the loss between model predictions and pseudo labels
 35        - Parameters: model, model_input, pseudo_labels, label_filter
 36        - Returns: loss
 37    - supervised_loss (optional): the supervised loss function
 38        - Parameters: model, input, labels
 39        - Returns: loss
 40    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 41        - Parameters: model, model_input, pseudo_labels, label_filter
 42        - Returns: loss, metric
 43    - supervised_loss_and_metric (optional): the supervised loss function and metric
 44        - Parameters: model, input, labels
 45        - Returns: loss, metric
 46    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 47
 48    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
 49    for setting the ratio between supervised and unsupervised training samples.
 50
 51    Args:
 52        model: The model to be trained.
 53        unsupervised_train_loader: The loader for unsupervised training.
 54        unsupervised_loss: The loss for unsupervised training.
 55        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
 56        supervised_train_loader: The loader for supervised training.
 57        supervised_loss: The loss for supervised training.
 58        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
 59        supervised_loss_and_metric: The loss and metrhic for supervised training.
 60        logger: The logger.
 61        source_distribution: The ratio of labels in the source label distribution.
 62            If given, the predicted distribution of the trained model will be regularized to
 63            match this source label distribution.
 64        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
 65    """
 66
 67    def __init__(
 68        self,
 69        model: torch.nn.Module,
 70        unsupervised_train_loader: torch.utils.data.DataLoader,
 71        unsupervised_loss: torch.utils.data.DataLoader,
 72        pseudo_labeler: Callable,
 73        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
 74        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 75        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 76        supervised_loss: Optional[Callable] = None,
 77        unsupervised_loss_and_metric: Optional[Callable] = None,
 78        supervised_loss_and_metric: Optional[Callable] = None,
 79        logger=SelfTrainingTensorboardLogger,
 80        source_distribution: List[float] = None,
 81        **kwargs,
 82    ):
 83        # Do we have supervised data or not?
 84        if supervised_train_loader is None:
 85            # No. -> We use the unsupervised training logic.
 86            train_loader = unsupervised_train_loader
 87            self._train_epoch_impl = self._train_epoch_unsupervised
 88        else:
 89            # Yes. -> We use the semi-supervised training logic.
 90            assert supervised_loss is not None
 91            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
 92                else unsupervised_train_loader
 93            self._train_epoch_impl = self._train_epoch_semisupervised
 94
 95        self.unsupervised_train_loader = unsupervised_train_loader
 96        self.supervised_train_loader = supervised_train_loader
 97
 98        # Check that we have at least one of supvervised / unsupervised val loader.
 99        assert sum((
100            supervised_val_loader is not None,
101            unsupervised_val_loader is not None,
102        )) > 0
103        self.supervised_val_loader = supervised_val_loader
104        self.unsupervised_val_loader = unsupervised_val_loader
105
106        if self.unsupervised_val_loader is None:
107            val_loader = self.supervised_val_loader
108        else:
109            val_loader = self.unsupervised_val_loader
110
111        # Check that we have at least one of supvervised / unsupervised loss and metric.
112        assert sum((
113            supervised_loss_and_metric is not None,
114            unsupervised_loss_and_metric is not None,
115        )) > 0
116        self.supervised_loss_and_metric = supervised_loss_and_metric
117        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
118
119        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
120        kwargs.pop("train_loader", None)
121        kwargs.pop("val_loader", None)
122        kwargs.pop("metric", None)
123        kwargs.pop("loss", None)
124        super().__init__(
125            model=model, train_loader=train_loader, val_loader=val_loader,
126            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
127        )
128
129        self.unsupervised_loss = unsupervised_loss
130        self.supervised_loss = supervised_loss
131
132        self.pseudo_labeler = pseudo_labeler
133
134        if source_distribution is None:
135            self.source_distribution = None
136        else:
137            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)
138
139        self._kwargs = kwargs
140
141    #
142    # functionality for saving checkpoints and initialization
143    #
144
145    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
146        """@private
147        """
148        train_loader_kwargs = get_constructor_arguments(self.train_loader)
149        val_loader_kwargs = get_constructor_arguments(self.val_loader)
150        extra_state = {
151            "init": {
152                "train_loader_kwargs": train_loader_kwargs,
153                "train_dataset": self.train_loader.dataset,
154                "val_loader_kwargs": val_loader_kwargs,
155                "val_dataset": self.val_loader.dataset,
156                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
157                "loss_kwargs": {},
158                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
159                "metric_kwargs": {},
160            },
161        }
162        extra_state.update(**extra_save_dict)
163        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
164
165    # Distribution alignment:
166    # Encourages the distribution of the model's generated pseudo labels to match the marginal
167    # distribution of pseudo labels from the source transfer (key idea: to maximize the mutual information).
168    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
169        """@private
170        """
171        if self.source_distribution is not None:
172            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
173            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
174            target_distribution = target_distribution / target_distribution.sum()
175            distribution_ratio = self.source_distribution / target_distribution
176            pseudo_labels = torch.where(
177                pseudo_labels < label_threshold,
178                pseudo_labels * distribution_ratio[0],
179                pseudo_labels * distribution_ratio[1]
180            ).clip(0, 1)
181
182        return pseudo_labels
183
184    #
185    # training and validation functionality
186    #
187
188    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
189        self.model.train()
190
191        n_iter = 0
192        t_per_iter = time.time()
193
194        # Sample from both the supervised and unsupervised loader.
195        for xu1, xu2 in self.unsupervised_train_loader:
196            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
197
198            teacher_input, model_input = xu1, xu2
199
200            with forward_context(), torch.no_grad():
201                # Compute the pseudo labels.
202                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
203
204            pseudo_labels = pseudo_labels.detach()
205            if label_filter is not None:
206                label_filter = label_filter.detach()
207
208            # Perform distribution alignment for pseudo labels
209            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
210
211            self.optimizer.zero_grad()
212            # Perform unsupervised training
213            with forward_context():
214                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
215
216            backprop(loss)
217
218            if self.logger is not None:
219                with torch.no_grad(), forward_context():
220                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
221                self.logger.log_train_unsupervised(
222                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
223                )
224                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
225                self.logger.log_lr(self._iteration, lr)
226                if self.pseudo_labeler.confidence_threshold is not None:
227                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
228
229            self._iteration += 1
230            n_iter += 1
231            if self._iteration >= self.max_iteration:
232                break
233            progress.update(1)
234
235        t_per_iter = (time.time() - t_per_iter) / n_iter
236        return t_per_iter
237
238    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
239        self.model.train()
240
241        n_iter = 0
242        t_per_iter = time.time()
243
244        # Sample from both the supervised and unsupervised loader.
245        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
246            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
247            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
248
249            # Perform supervised training.
250            self.optimizer.zero_grad()
251            with forward_context():
252                # We pass the model, the input and the labels to the supervised loss function,
253                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
254                supervised_loss = self.supervised_loss(self.model, xs, ys)
255
256            teacher_input, model_input = xu1, xu2
257
258            with forward_context(), torch.no_grad():
259                # Compute the pseudo labels.
260                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
261
262            pseudo_labels = pseudo_labels.detach()
263            if label_filter is not None:
264                label_filter = label_filter.detach()
265
266            # Perform distribution alignment for pseudo labels
267            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
268
269            # Perform unsupervised training
270            with forward_context():
271                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
272
273            loss = (supervised_loss + unsupervised_loss) / 2
274            backprop(loss)
275
276            if self.logger is not None:
277                with torch.no_grad(), forward_context():
278                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
279                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None
280
281                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
282                self.logger.log_train_unsupervised(
283                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
284                )
285
286                self.logger.log_combined_loss(self._iteration, loss)
287                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
288                self.logger.log_lr(self._iteration, lr)
289                if self.pseudo_labeler.confidence_threshold is not None:
290                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
291
292            self._iteration += 1
293            n_iter += 1
294            if self._iteration >= self.max_iteration:
295                break
296            progress.update(1)
297
298        t_per_iter = (time.time() - t_per_iter) / n_iter
299        return t_per_iter
300
301    def _validate_supervised(self, forward_context):
302        metric_val = 0.0
303        loss_val = 0.0
304
305        for x, y in self.supervised_val_loader:
306            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
307            with forward_context():
308                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
309            loss_val += loss.item()
310            metric_val += metric.item()
311
312        metric_val /= len(self.supervised_val_loader)
313        loss_val /= len(self.supervised_val_loader)
314
315        if self.logger is not None:
316            with forward_context():
317                pred = self.model(x)
318            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)
319
320        return metric_val
321
322    def _validate_unsupervised(self, forward_context):
323        metric_val = 0.0
324        loss_val = 0.0
325
326        for x1, x2 in self.unsupervised_val_loader:
327            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
328            teacher_input, model_input = x1, x2
329            with forward_context():
330                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
331                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
332            loss_val += loss.item()
333            metric_val += metric.item()
334
335        metric_val /= len(self.unsupervised_val_loader)
336        loss_val /= len(self.unsupervised_val_loader)
337
338        if self.logger is not None:
339            with forward_context():
340                pred = self.model(model_input)
341            self.logger.log_validation_unsupervised(
342                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
343            )
344
345        return metric_val
346
347    def _validate_impl(self, forward_context):
348        self.model.eval()
349
350        with torch.no_grad():
351
352            if self.supervised_val_loader is None:
353                supervised_metric = None
354            else:
355                supervised_metric = self._validate_supervised(forward_context)
356
357            if self.unsupervised_val_loader is None:
358                unsupervised_metric = None
359            else:
360                unsupervised_metric = self._validate_unsupervised(forward_context)
361
362        if unsupervised_metric is None:
363            metric = supervised_metric
364        elif supervised_metric is None:
365            metric = unsupervised_metric
366        else:
367            metric = (supervised_metric + unsupervised_metric) / 2
368
369        return metric

Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.

FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685). It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels on unlabeled data. We support two training strategies:

  • Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
  • Taining only on the unsupervised data.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader
  • supervised_val_loader (optional): Same as supervised_train_loader At least one of unsupervised_val_loader and supervised_val_loader must be given.

The following arguments can be used to customize the pseudo labeling:

  • pseudo_labeler: to compute the psuedo-labels
    • Parameters: model, teacher_input
    • Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
  • unsupervised_loss: the loss between model predictions and pseudo labels
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: model, input, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: model, input, labels
    • Returns: loss, metric At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' for setting the ratio between supervised and unsupervised training samples.

Arguments:
  • model: The model to be trained.
  • unsupervised_train_loader: The loader for unsupervised training.
  • unsupervised_loss: The loss for unsupervised training.
  • pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
  • supervised_train_loader: The loader for supervised training.
  • supervised_loss: The loss for supervised training.
  • unsupervised_loss_and_metric: The loss and metric for unsupervised training.
  • supervised_loss_and_metric: The loss and metrhic for supervised training.
  • logger: The logger.
  • source_distribution: The ratio of labels in the source label distribution. If given, the predicted distribution of the trained model will be regularized to match this source label distribution.
  • kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
FixMatchTrainer( model: torch.nn.modules.module.Module, unsupervised_train_loader: torch.utils.data.dataloader.DataLoader, unsupervised_loss: torch.utils.data.dataloader.DataLoader, pseudo_labeler: Callable, supervised_train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, unsupervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_loss: Optional[Callable] = None, unsupervised_loss_and_metric: Optional[Callable] = None, supervised_loss_and_metric: Optional[Callable] = None, logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>, source_distribution: List[float] = None, **kwargs)
 67    def __init__(
 68        self,
 69        model: torch.nn.Module,
 70        unsupervised_train_loader: torch.utils.data.DataLoader,
 71        unsupervised_loss: torch.utils.data.DataLoader,
 72        pseudo_labeler: Callable,
 73        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
 74        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 75        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
 76        supervised_loss: Optional[Callable] = None,
 77        unsupervised_loss_and_metric: Optional[Callable] = None,
 78        supervised_loss_and_metric: Optional[Callable] = None,
 79        logger=SelfTrainingTensorboardLogger,
 80        source_distribution: List[float] = None,
 81        **kwargs,
 82    ):
 83        # Do we have supervised data or not?
 84        if supervised_train_loader is None:
 85            # No. -> We use the unsupervised training logic.
 86            train_loader = unsupervised_train_loader
 87            self._train_epoch_impl = self._train_epoch_unsupervised
 88        else:
 89            # Yes. -> We use the semi-supervised training logic.
 90            assert supervised_loss is not None
 91            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
 92                else unsupervised_train_loader
 93            self._train_epoch_impl = self._train_epoch_semisupervised
 94
 95        self.unsupervised_train_loader = unsupervised_train_loader
 96        self.supervised_train_loader = supervised_train_loader
 97
 98        # Check that we have at least one of supvervised / unsupervised val loader.
 99        assert sum((
100            supervised_val_loader is not None,
101            unsupervised_val_loader is not None,
102        )) > 0
103        self.supervised_val_loader = supervised_val_loader
104        self.unsupervised_val_loader = unsupervised_val_loader
105
106        if self.unsupervised_val_loader is None:
107            val_loader = self.supervised_val_loader
108        else:
109            val_loader = self.unsupervised_val_loader
110
111        # Check that we have at least one of supvervised / unsupervised loss and metric.
112        assert sum((
113            supervised_loss_and_metric is not None,
114            unsupervised_loss_and_metric is not None,
115        )) > 0
116        self.supervised_loss_and_metric = supervised_loss_and_metric
117        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
118
119        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
120        kwargs.pop("train_loader", None)
121        kwargs.pop("val_loader", None)
122        kwargs.pop("metric", None)
123        kwargs.pop("loss", None)
124        super().__init__(
125            model=model, train_loader=train_loader, val_loader=val_loader,
126            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
127        )
128
129        self.unsupervised_loss = unsupervised_loss
130        self.supervised_loss = supervised_loss
131
132        self.pseudo_labeler = pseudo_labeler
133
134        if source_distribution is None:
135            self.source_distribution = None
136        else:
137            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)
138
139        self._kwargs = kwargs
unsupervised_train_loader
supervised_train_loader
supervised_val_loader
unsupervised_val_loader
supervised_loss_and_metric
unsupervised_loss_and_metric
unsupervised_loss
supervised_loss
pseudo_labeler
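
As a usage illustration, here is a minimal sketch of a pseudo labeler and loss callables that follow the contracts documented above, together with the trainer construction. Everything in it is illustrative rather than part of the library API: the loader variables are placeholders for DataLoaders set up as described (yielding a weak and a strong augmentation of the same input), and name and optimizer are regular torch_em.trainer.DefaultTrainer arguments that are forwarded via **kwargs.

import torch
from torch_em.model import UNet2d


class ConfidencePseudoLabeler:
    """Hypothetical pseudo labeler following the documented contract:
    (model, teacher_input) -> (pseudo_labels, label_filter)."""

    def __init__(self, confidence_threshold=0.9):
        # The trainer logs this attribute if it is not None and calls
        # step(metric, epoch) after validation.
        self.confidence_threshold = confidence_threshold

    def step(self, metric, epoch):
        pass  # A fixed threshold needs no update after validation.

    def __call__(self, model, teacher_input):
        probs = torch.sigmoid(model(teacher_input))
        pseudo_labels = (probs > 0.5).float()
        # Keep only the pixels that are predicted with high confidence.
        label_filter = torch.logical_or(
            probs > self.confidence_threshold, probs < 1 - self.confidence_threshold
        ).float()
        return pseudo_labels, label_filter


def unsupervised_loss(model, model_input, pseudo_labels, label_filter):
    # Documented contract: (model, model_input, pseudo_labels, label_filter) -> loss.
    pred = torch.sigmoid(model(model_input))
    loss = torch.nn.functional.mse_loss(pred, pseudo_labels, reduction="none")
    if label_filter is not None:
        loss = loss * label_filter
    return loss.mean()


def unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    # Same parameters, returns (loss, metric); the loss doubles as the metric here.
    loss = unsupervised_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss


model = UNet2d(in_channels=1, out_channels=1)
trainer = FixMatchTrainer(
    name="fixmatch-example",                       # forwarded to DefaultTrainer
    model=model,
    unsupervised_train_loader=unsup_train_loader,  # placeholder loader
    unsupervised_loss=unsupervised_loss,
    pseudo_labeler=ConfidencePseudoLabeler(confidence_threshold=0.9),
    unsupervised_val_loader=unsup_val_loader,      # placeholder loader
    unsupervised_loss_and_metric=unsupervised_loss_and_metric,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),  # forwarded
)

Without a supervised_train_loader this construction selects the purely unsupervised training logic; passing one (together with supervised_loss) switches to the semi-supervised epoch implementation, as the constructor above shows.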
class FixMatchTrainerWithInvertibleAugmentations(FixMatchTrainer):
372class FixMatchTrainerWithInvertibleAugmentations(FixMatchTrainer):
373    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.
374
375    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685.
376    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
377    on unlabeled data. We support two training strategies:
378    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
379    - Training only on the unsupervised data.
380
381    This class expects the following data loaders:
382    - unsupervised_train_loader: Returns the raw input, from which the augmenter generates teacher and student views.
383    - supervised_train_loader (optional): Returns input and labels.
384    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
385    - supervised_val_loader (optional): Same as supervised_train_loader.
386    At least one of unsupervised_val_loader and supervised_val_loader must be given.
387
388    The augmenter defines separate invertible transforms for teacher and student inputs.
389    Teacher and student views are generated independently, and the corresponding inverse
390    transforms map predictions and pseudo-labels back into a shared reference frame.
391
392    The following arguments can be used to customize the pseudo labeling:
393    - pseudo_labeler: to compute the pseudo-labels
394        - Parameters: model, teacher_input
395        - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
396    - unsupervised_loss: the loss between model predictions and pseudo labels
397        - Parameters: model, model_input, pseudo_labels, label_filter
398        - Returns: loss
399    - supervised_loss (optional): the supervised loss function
400        - Parameters: model, input, labels
401        - Returns: loss
402    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
403        - Parameters: model, model_input, pseudo_labels, label_filter
404        - Returns: loss, metric
405    - supervised_loss_and_metric (optional): the supervised loss function and metric
406        - Parameters: model, input, labels
407        - Returns: loss, metric
408    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
409
410    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
411    to set the ratio between supervised and unsupervised training samples.
412
413    Args:
414        model: The model to be trained.
415        unsupervised_train_loader: The loader for unsupervised training.
416        unsupervised_loss: The loss for unsupervised training.
417        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
418        augmenter: Invertible augmenter providing separate teacher and student transforms,
419            including corresponding inverse transforms to align predictions in a common frame.
420        supervised_train_loader: The loader for supervised training.
421        supervised_loss: The loss for supervised training.
422        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
423        supervised_loss_and_metric: The loss and metric for supervised training.
424        logger: The logger.
425        source_distribution: The ratio of labels in the source label distribution.
426            If given, the predicted distribution of the trained model will be regularized to
427            match this source label distribution.
428        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
429    """
430
431    def __init__(
432        self,
433        model: torch.nn.Module,
434        unsupervised_train_loader: torch.utils.data.DataLoader,
435        unsupervised_loss: Callable,
436        pseudo_labeler: Callable,
437        augmenter: InvertibleAugmenter,
438        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
439        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
440        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
441        supervised_loss: Optional[Callable] = None,
442        unsupervised_loss_and_metric: Optional[Callable] = None,
443        supervised_loss_and_metric: Optional[Callable] = None,
444        logger=SelfTrainingTensorboardLogger,
445        source_distribution: Optional[List[float]] = None,
446        **kwargs,
447    ):
448        super().__init__(
449            model=model,
450            unsupervised_train_loader=unsupervised_train_loader,
451            unsupervised_loss=unsupervised_loss,
452            pseudo_labeler=pseudo_labeler,
453            supervised_train_loader=supervised_train_loader,
454            unsupervised_val_loader=unsupervised_val_loader,
455            supervised_val_loader=supervised_val_loader,
456            supervised_loss=supervised_loss,
457            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
458            supervised_loss_and_metric=supervised_loss_and_metric,
459            logger=logger,
460            source_distribution=source_distribution,
461            **kwargs,
462        )
463
464        self.augmenter = augmenter
465
466    #
467    # training and validation functionality
468    #
469
470    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
471        self.model.train()
472
473        n_iter = 0
474        t_per_iter = time.time()
475
476        # Sample from the unsupervised loader.
477        for xu in self.unsupervised_train_loader:
478            self.augmenter.reset_all()
479            xu = xu.to(self.device, non_blocking=True)
480
481            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
482            teacher_input, model_input = xu1, xu2
483
484            with forward_context(), torch.no_grad():
485                # Compute the pseudo labels.
486                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
487
488            pseudo_labels = pseudo_labels.detach()
489            if label_filter is not None:
490                label_filter = label_filter.detach()
491
492            # Perform distribution alignment for pseudo labels, then invert into the reference frame.
493            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
494            pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
495            label_filter_inv = (
496                self.augmenter.teacher.reverse_transform(label_filter)
497                if label_filter is not None else None
498            )
499
500            self.optimizer.zero_grad()
501            # Perform unsupervised training.
502            with forward_context():
503                pred = self.model(model_input)
504                pred_inv = self.augmenter.student.reverse_transform(pred)
505                loss = self.unsupervised_loss(pred_inv, pseudo_labels_inv, label_filter_inv)
506
507            backprop(loss)
508
509            if self.logger is not None:
510                with torch.no_grad(), forward_context():
511                    pred = pred if self._iteration % self.log_image_interval == 0 else None
512                self.logger.log_train_unsupervised(
513                    self._iteration, loss, xu, xu, pred_inv, pseudo_labels_inv, label_filter_inv
514                )
515                self.logger.log_train_augmentations(
516                    self._iteration, xu1, xu2, pseudo_labels, pred,
517                )
518                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
519                self.logger.log_lr(self._iteration, lr)
520                if self.pseudo_labeler.confidence_threshold is not None:
521                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
522
523            self._iteration += 1
524            n_iter += 1
525            if self._iteration >= self.max_iteration:
526                break
527            progress.update(1)
528
529        t_per_iter = (time.time() - t_per_iter) / n_iter
530        return t_per_iter
531
532    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
533        self.model.train()
534
535        n_iter = 0
536        t_per_iter = time.time()
537
538        # Sample from both the supervised and unsupervised loader.
539        for (xs, ys), xu in zip(self.supervised_train_loader, self.unsupervised_train_loader):
540            self.augmenter.reset_all()
541            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
542            xu = xu.to(self.device, non_blocking=True)
543
544            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
545            teacher_input, model_input = xu1, xu2
546
547            # Perform supervised training.
548            self.optimizer.zero_grad()
549            with forward_context():
550                # We pass the model, the input and the labels to the supervised loss function,
551                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
552                supervised_pred = self.model(xs)
553                supervised_loss = self.supervised_loss(supervised_pred, ys)
554
555            with forward_context(), torch.no_grad():
556                # Compute the pseudo labels.
557                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
558
559            pseudo_labels = pseudo_labels.detach()
560            if label_filter is not None:
561                label_filter = label_filter.detach()
562
563            # Perform distribution alignment for pseudo labels, then invert into the reference frame.
564            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
565            pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
566            label_filter_inv = (
567                self.augmenter.teacher.reverse_transform(label_filter)
568                if label_filter is not None else None
569            )
570
571            # Perform unsupervised training.
572            with forward_context():
573                unsup_pred = self.model(model_input)
574                unsup_pred_inv = self.augmenter.student.reverse_transform(unsup_pred)
575                unsupervised_loss = self.unsupervised_loss(unsup_pred_inv, pseudo_labels_inv, label_filter_inv)
576
577            loss = (supervised_loss + unsupervised_loss) / 2
578            backprop(loss)
579
580            if self.logger is not None:
581                with torch.no_grad(), forward_context():
582                    unsup_pred = unsup_pred if self._iteration % self.log_image_interval == 0 else None
583                    supervised_pred = supervised_pred if self._iteration % self.log_image_interval == 0 else None
584
585                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
586                self.logger.log_train_unsupervised(
587                    self._iteration, unsupervised_loss, xu, xu, unsup_pred_inv, pseudo_labels_inv, label_filter_inv
588                )
589                self.logger.log_train_augmentations(
590                    self._iteration, xu1, xu2, pseudo_labels, unsup_pred,
591                )
592
593                self.logger.log_combined_loss(self._iteration, loss)
594                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
595                self.logger.log_lr(self._iteration, lr)
596                if self.pseudo_labeler.confidence_threshold is not None:
597                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
598
599            self._iteration += 1
600            n_iter += 1
601            if self._iteration >= self.max_iteration:
602                break
603            progress.update(1)
604
605        t_per_iter = (time.time() - t_per_iter) / n_iter
606        return t_per_iter
607
608    def _validate_supervised(self, forward_context):
609        metric_val = 0.0
610        loss_val = 0.0
611
612        for x, y in self.supervised_val_loader:
613            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
614            with forward_context():
615                pred = self.model(x)
616                loss, metric = self.supervised_loss_and_metric(pred, y)
617            loss_val += loss.item()
618            metric_val += metric.item()
619
620        metric_val /= len(self.supervised_val_loader)
621        loss_val /= len(self.supervised_val_loader)
622
623        if self.logger is not None:
624            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)
625
626        return metric_val
627
628    def _validate_unsupervised(self, forward_context):
629        metric_val = 0.0
630        loss_val = 0.0
631
632        for x in self.unsupervised_val_loader:
633            self.augmenter.reset_all()
634            x = x.to(self.device, non_blocking=True)
635
636            x1, x2 = self.augmenter.teacher.transform(x), self.augmenter.student.transform(x)
637            teacher_input, model_input = x1, x2
638
639            with forward_context():
640                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
641                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
642                label_filter_inv = (
643                    self.augmenter.teacher.reverse_transform(label_filter)
644                    if label_filter is not None else None
645                )
646
647                pred = self.model(model_input)
648                pred_inv = self.augmenter.student.reverse_transform(pred)
649                loss, metric = self.unsupervised_loss_and_metric(pred_inv, pseudo_labels_inv, label_filter_inv)
650            loss_val += loss.item()
651            metric_val += metric.item()
652
653        metric_val /= len(self.unsupervised_val_loader)
654        loss_val /= len(self.unsupervised_val_loader)
655
656        if self.logger is not None:
657            self.logger.log_validation_unsupervised(
658                self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv
659            )
660            self.logger.log_validation_augmentations(
661                self._iteration, x1, x2, pseudo_labels, pred,
662            )
663
664        self.pseudo_labeler.step(metric_val, self._epoch)
665
666        return metric_val
667
668    def _validate_impl(self, forward_context):
669        self.model.eval()
670
671        with torch.no_grad():
672
673            if self.supervised_val_loader is None:
674                supervised_metric = None
675            else:
676                supervised_metric = self._validate_supervised(forward_context)
677
678            if self.unsupervised_val_loader is None:
679                unsupervised_metric = None
680            else:
681                unsupervised_metric = self._validate_unsupervised(forward_context)
682
683        if unsupervised_metric is None:
684            metric = supervised_metric
685        elif supervised_metric is None:
686            metric = unsupervised_metric
687        else:
688            metric = (supervised_metric + unsupervised_metric) / 2
689
690        return metric
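
The training and validation methods above rely on a small interface of the augmenter only: reset_all() once per batch, and transform / reverse_transform on the independent teacher and student branches. The actual implementation is provided by torch_em.transform.invertible_augmentations.InvertibleAugmenter; the following is merely a hypothetical sketch of that interface, using horizontal flips, which are their own inverse.

import torch


class _RandomFlipBranch:
    """One augmentation branch: samples whether to flip on reset; since a
    flip is self-inverse, reverse_transform applies the same flip again."""

    def reset(self):
        self.apply_flip = torch.rand(1).item() < 0.5

    def transform(self, x):
        return torch.flip(x, dims=(-1,)) if self.apply_flip else x

    def reverse_transform(self, y):
        # Map predictions or pseudo-labels back into the un-augmented frame.
        return torch.flip(y, dims=(-1,)) if self.apply_flip else y


class MinimalFlipAugmenter:
    """Sketch of the augmenter interface used by the trainer above; this is
    not the real InvertibleAugmenter."""

    def __init__(self):
        self.teacher = _RandomFlipBranch()
        self.student = _RandomFlipBranch()

    def reset_all(self):
        self.teacher.reset()
        self.student.reset()

Because each branch inverts its own transform, teacher pseudo-labels and student predictions end up in a shared reference frame before the unsupervised loss is computed, exactly as in _train_epoch_unsupervised above.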

Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.

FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685. It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels on unlabeled data. We support two training strategies:

  • Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
  • Training only on the unsupervised data.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns the raw input, from which the augmenter generates teacher and student views.
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader.
  • supervised_val_loader (optional): Same as supervised_train_loader.

At least one of unsupervised_val_loader and supervised_val_loader must be given.

The augmenter defines separate invertible transforms for teacher and student inputs. Teacher and student views are generated independently, and the corresponding inverse transforms map predictions and pseudo-labels back into a shared reference frame.

The following arguments can be used to customize the pseudo labeling:

  • pseudo_labeler: to compute the pseudo-labels
    • Parameters: model, teacher_input
    • Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
  • unsupervised_loss: the loss between model predictions and pseudo labels
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: model, input, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: model, input, labels
    • Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between supervised and unsupervised training samples.

Arguments:
  • model: The model to be trained.
  • unsupervised_train_loader: The loader for unsupervised training.
  • unsupervised_loss: The loss for unsupervised training.
  • pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
  • augmenter: Invertible augmenter providing separate teacher and student transforms, including corresponding inverse transforms to align predictions in a common frame.
  • supervised_train_loader: The loader for supervised training.
  • supervised_loss: The loss for supervised training.
  • unsupervised_loss_and_metric: The loss and metric for unsupervised training.
  • supervised_loss_and_metric: The loss and metric for supervised training.
  • logger: The logger.
  • source_distribution: The ratio of labels in the source label distribution. If given, the predicted distribution of the trained model will be regularized to match this source label distribution.
  • kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
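
One caveat when providing the loss callables for this subclass: as the source above shows (_train_epoch_unsupervised, _validate_unsupervised), the trainer applies the inverse transforms itself and then calls the losses on the already inverted predictions, i.e. as loss(pred, pseudo_labels, label_filter), without the model argument listed in the inherited contract. A minimal masked-loss sketch under that assumption:

import torch


def masked_unsupervised_loss(pred, pseudo_labels, label_filter=None):
    # pred and pseudo_labels arrive already mapped back into the shared
    # reference frame via reverse_transform.
    loss = torch.nn.functional.mse_loss(torch.sigmoid(pred), pseudo_labels, reduction="none")
    if label_filter is not None:
        loss = loss * label_filter
    return loss.mean()


def masked_unsupervised_loss_and_metric(pred, pseudo_labels, label_filter=None):
    # Returns (loss, metric); the loss doubles as the validation metric here.
    loss = masked_unsupervised_loss(pred, pseudo_labels, label_filter)
    return loss, loss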
FixMatchTrainerWithInvertibleAugmentations( model: torch.nn.modules.module.Module, unsupervised_train_loader: torch.utils.data.dataloader.DataLoader, unsupervised_loss: Callable, pseudo_labeler: Callable, augmenter: torch_em.transform.invertible_augmentations.InvertibleAugmenter, supervised_train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, unsupervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, supervised_loss: Optional[Callable] = None, unsupervised_loss_and_metric: Optional[Callable] = None, supervised_loss_and_metric: Optional[Callable] = None, logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>, source_distribution: Optional[List[float]] = None, **kwargs)
431    def __init__(
432        self,
433        model: torch.nn.Module,
434        unsupervised_train_loader: torch.utils.data.DataLoader,
435        unsupervised_loss: Callable,
436        pseudo_labeler: Callable,
437        augmenter: InvertibleAugmenter,
438        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
439        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
440        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
441        supervised_loss: Optional[Callable] = None,
442        unsupervised_loss_and_metric: Optional[Callable] = None,
443        supervised_loss_and_metric: Optional[Callable] = None,
444        logger=SelfTrainingTensorboardLogger,
445        source_distribution: Optional[List[float]] = None,
446        **kwargs,
447    ):
448        super().__init__(
449            model=model,
450            unsupervised_train_loader=unsupervised_train_loader,
451            unsupervised_loss=unsupervised_loss,
452            pseudo_labeler=pseudo_labeler,
453            supervised_train_loader=supervised_train_loader,
454            unsupervised_val_loader=unsupervised_val_loader,
455            supervised_val_loader=supervised_val_loader,
456            supervised_loss=supervised_loss,
457            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
458            supervised_loss_and_metric=supervised_loss_and_metric,
459            logger=logger,
460            source_distribution=source_distribution,
461            **kwargs,
462        )
463
464        self.augmenter = augmenter
augmenter
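
Putting the pieces together, a construction sketch that reuses the callables from the sketches above. Note that, as the source shows, the unsupervised loaders of this subclass yield a single un-augmented input per batch, which the augmenter then turns into teacher and student views; the loader variables and the augmenter instance are placeholders.

import torch
from torch_em.model import UNet2d

model = UNet2d(in_channels=1, out_channels=1)
trainer = FixMatchTrainerWithInvertibleAugmentations(
    name="fixmatch-invertible",                    # forwarded to DefaultTrainer
    model=model,
    unsupervised_train_loader=unsup_train_loader,  # yields raw inputs
    unsupervised_loss=masked_unsupervised_loss,
    pseudo_labeler=ConfidencePseudoLabeler(confidence_threshold=0.9),
    augmenter=augmenter,  # an InvertibleAugmenter instance (construction not shown)
    unsupervised_val_loader=unsup_val_loader,
    unsupervised_loss_and_metric=masked_unsupervised_loss_and_metric,
    source_distribution=[0.9, 0.1],  # optional: regularize towards the source label ratio
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),  # forwarded
)
trainer.fit(iterations=10_000)  # standard DefaultTrainer training loop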