torch_em.self_training.fix_match

import time
from typing import Callable, List, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from .mean_teacher import Dummy


class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.

    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685.
    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
    on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unlabeled data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The following arguments can be used to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: model, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        source_distribution: The ratio of labels in the source label distribution.
            If given, the predicted distribution of the trained model will be regularized to
            match this source label distribution.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution: Optional[List[float]] = None,
        **kwargs,
    ):
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler

        if source_distribution is None:
            self.source_distribution = None
        else:
            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)

        self._kwargs = kwargs

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    # Distribution alignment:
    # Encourages the distribution of the model's generated pseudo labels to match the marginal
    # label distribution of the source domain (key idea: maximize the mutual information between
    # model inputs and predictions).
    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
        """@private
        """
        if self.source_distribution is not None:
            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
            target_distribution = target_distribution / target_distribution.sum()
            distribution_ratio = self.source_distribution / target_distribution
            pseudo_labels = torch.where(
                pseudo_labels < label_threshold,
                pseudo_labels * distribution_ratio[0],
                pseudo_labels * distribution_ratio[1]
            ).clip(0, 1)

        return pseudo_labels

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for the pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for the pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric
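
The pseudo_labeler and unsupervised_loss callables only need to follow the interfaces described in the class docstring. The sketch below is illustrative, not part of torch_em: it assumes a binary segmentation model with sigmoid outputs, and the names ThresholdPseudoLabeler and MaskedMSELoss are hypothetical.

import torch


class ThresholdPseudoLabeler:
    """Illustrative pseudo labeler: (model, teacher_input) -> (pseudo_labels, label_filter).

    Predicts with the current model and keeps only pixels that are confidently
    foreground or confidently background.
    """
    def __init__(self, confidence_threshold=0.9):
        self.confidence_threshold = confidence_threshold

    def __call__(self, model, teacher_input):
        pseudo_labels = model(teacher_input)
        label_filter = torch.logical_or(
            pseudo_labels >= self.confidence_threshold,
            pseudo_labels <= (1 - self.confidence_threshold),
        ).float()
        return pseudo_labels, label_filter


class MaskedMSELoss:
    """Illustrative unsupervised loss: (model, model_input, pseudo_labels, label_filter) -> loss.

    Mean squared error between predictions and pseudo labels, restricted to the label filter.
    """
    def __call__(self, model, model_input, pseudo_labels, label_filter=None):
        prediction = model(model_input)
        loss = (prediction - pseudo_labels) ** 2
        if label_filter is None:
            return loss.mean()
        return (loss * label_filter).sum() / label_filter.sum().clamp(min=1)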
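
With callables like these in place, the trainer can be wired up as sketched below. The model and data loaders are placeholders for your own setup, and the remaining keyword arguments (name, optimizer, device) are assumed to be forwarded to torch_em.trainer.DefaultTrainer.

import torch
from torch_em.self_training.fix_match import FixMatchTrainer

# Placeholders: supply your own model and loaders. The unsupervised loaders must
# yield (weakly_augmented, strongly_augmented) pairs of the same input.
model = ...
unsupervised_train_loader, unsupervised_val_loader = ..., ...

trainer = FixMatchTrainer(
    name="fixmatch-example",
    model=model,
    unsupervised_train_loader=unsupervised_train_loader,
    unsupervised_val_loader=unsupervised_val_loader,
    unsupervised_loss=MaskedMSELoss(),
    unsupervised_loss_and_metric=...,  # callable returning (loss, metric), used for validation
    pseudo_labeler=ThresholdPseudoLabeler(confidence_threshold=0.9),
    # Optional: also pass supervised_train_loader and supervised_loss to enable semi-supervised training.
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    device="cuda" if torch.cuda.is_available() else "cpu",
)
trainer.fit(iterations=10000)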

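The effect of the source_distribution argument can be seen in isolation with a small numeric sketch. The snippet below re-implements the rescaling from get_distribution_alignment on a toy tensor; the numbers are made up for illustration.

import torch

# Assume 80% background / 20% foreground in the source labels.
source_distribution = torch.tensor([0.8, 0.2])
pseudo_labels = torch.tensor([0.1, 0.3, 0.6, 0.9])

binary = torch.where(pseudo_labels >= 0.5, 1, 0)
_, counts = torch.unique(binary, return_counts=True)
target_distribution = counts / counts.sum()         # here: [0.5, 0.5]
ratio = source_distribution / target_distribution   # here: [1.6, 0.4]

aligned = torch.where(
    pseudo_labels < 0.5, pseudo_labels * ratio[0], pseudo_labels * ratio[1]
).clip(0, 1)
print(aligned)  # tensor([0.1600, 0.4800, 0.2400, 0.3600])

Because foreground is over-represented in the pseudo labels (50%) compared to the source distribution (20%), values above the threshold are scaled down and values below it are scaled up, pushing the pseudo-label distribution toward the source label distribution before the unsupervised loss is computed.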