torch_em.self_training.fix_match

import time

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from .mean_teacher import Dummy


class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
    """This trainer implements self-training for semi-supervised learning and domain adaptation following the
    'FixMatch' approach of Sohn et al. (https://arxiv.org/abs/2001.07685). This approach uses a teacher model that
    shares its weights with the student model to predict pseudo-labels on unlabeled data.
    We support two training strategies: joint training on labeled and unlabeled data
    (with a supervised and an unsupervised loss function), and training on the unlabeled data only.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    And the following elements to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: model, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Parameters:
        model [nn.Module] -
        unsupervised_train_loader [torch.DataLoader] -
        unsupervised_loss [callable] -
        pseudo_labeler [callable] -
        supervised_train_loader [torch.DataLoader] - (default: None)
        unsupervised_val_loader [torch.DataLoader] - (default: None)
        supervised_val_loader [torch.DataLoader] - (default: None)
        supervised_loss [callable] - (default: None)
        unsupervised_loss_and_metric [callable] - (default: None)
        supervised_loss_and_metric [callable] - (default: None)
        logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger)
        source_distribution [list] - (default: None)
        **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer
    """

    def __init__(
        self,
        model,
        unsupervised_train_loader,
        unsupervised_loss,
        pseudo_labeler,
        supervised_train_loader=None,
        unsupervised_val_loader=None,
        supervised_val_loader=None,
        supervised_loss=None,
        unsupervised_loss_and_metric=None,
        supervised_loss_and_metric=None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution=None,
        **kwargs
    ):
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler

        if source_distribution is None:
            self.source_distribution = None
        else:
            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)

        self._kwargs = kwargs

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    # distribution alignment - encourages the distribution of the model's generated pseudo labels to match the
    #                          marginal distribution of pseudo labels from the source domain
    #                          (key idea: to maximize the mutual information)
    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
        if self.source_distribution is not None:
            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
            target_distribution = target_distribution / target_distribution.sum()
            distribution_ratio = self.source_distribution / target_distribution
            pseudo_labels = torch.where(
                pseudo_labels < label_threshold,
                pseudo_labels * distribution_ratio[0],
                pseudo_labels * distribution_ratio[1]
            ).clip(0, 1)

        return pseudo_labels

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            self.optimizer.zero_grad()
            # Perform unsupervised training
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device), ys.to(self.device)
            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            # Perform unsupervised training
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device), y.to(self.device)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device), x2.to(self.device)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric

class FixMatchTrainer(torch_em.trainer.default_trainer.DefaultTrainer):

This trainer implements self-training for semi-supervised learning and domain adaptation following the 'FixMatch' approach of Sohn et al. (https://arxiv.org/abs/2001.07685). The approach uses a teacher model that shares its weights with the student model to predict pseudo-labels on unlabeled data. We support two training strategies: joint training on labeled and unlabeled data (with a supervised and an unsupervised loss function), and training on the unlabeled data only.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input; a minimal sketch of such a loader is shown below.
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same as unsupervised_train_loader.
  • supervised_val_loader (optional): Same as supervised_train_loader.

At least one of unsupervised_val_loader and supervised_val_loader must be given.
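For illustration, here is a minimal sketch of an unsupervised loader that yields such weak/strong pairs. The WeakStrongDataset wrapper and the two augmentation functions are hypothetical placeholders (torch_em ships its own datasets, loaders and augmentations); only the structure of the returned pair matters for the trainer:

import torch
from torch.utils.data import Dataset, DataLoader


class WeakStrongDataset(Dataset):
    """Wraps raw images so that each item is a (weakly augmented, strongly augmented) pair."""

    def __init__(self, raw_images, weak_augmentation, strong_augmentation):
        self.raw_images = raw_images
        self.weak_augmentation = weak_augmentation
        self.strong_augmentation = strong_augmentation

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, index):
        x = self.raw_images[index]
        # The weak view is used to compute the pseudo-labels,
        # the strong view is the input for the model prediction.
        return self.weak_augmentation(x), self.strong_augmentation(x)


def weak_augmentation(x):
    return x + 0.01 * torch.randn_like(x)


def strong_augmentation(x):
    return (x + 0.2 * torch.randn_like(x)).clamp(0, 1)


raw_images = [torch.rand(1, 256, 256) for _ in range(8)]
unsupervised_train_loader = DataLoader(
    WeakStrongDataset(raw_images, weak_augmentation, strong_augmentation), batch_size=2, shuffle=True
)

Any iterable that yields batches of (weakly augmented, strongly augmented) views of the same underlying images works here.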

And the following elements to customize the pseudo labeling:

  • pseudo_labeler: to compute the pseudo-labels (a minimal sketch of this callable and of the unsupervised loss follows after this list)
    • Parameters: model, teacher_input
    • Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
  • unsupervised_loss: the loss between model predictions and pseudo-labels
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: model, input, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: model, model_input, pseudo_labels, label_filter
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: model, input, labels
    • Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
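For example, a confidence-thresholding pseudo labeler and a masked binary cross-entropy loss that follow the call signatures above could look like the following sketch. These are illustrative stand-ins, not the library's built-in implementations, and they assume a single-channel model output trained with a sigmoid:

import torch


class ThresholdPseudoLabeler:
    """Pseudo labeler sketch: sigmoid predictions on the (weakly augmented) teacher input
    become the pseudo-labels; the label filter masks out low-confidence pixels."""

    def __init__(self, confidence_threshold=0.9):
        self.confidence_threshold = confidence_threshold

    def __call__(self, model, teacher_input):
        pseudo_labels = torch.sigmoid(model(teacher_input))
        # A pixel counts as confident if its score is close to either 0 or 1.
        label_filter = torch.logical_or(
            pseudo_labels >= self.confidence_threshold,
            pseudo_labels <= 1.0 - self.confidence_threshold,
        ).float()
        return pseudo_labels, label_filter


def masked_bce_loss(model, model_input, pseudo_labels, label_filter):
    """Unsupervised loss sketch: BCE between the prediction on the strongly augmented
    input and the pseudo-labels, restricted to the confident pixels."""
    prediction = torch.sigmoid(model(model_input))
    loss = torch.nn.functional.binary_cross_entropy(prediction, pseudo_labels, reduction="none")
    if label_filter is None:
        return loss.mean()
    return (loss * label_filter).sum() / label_filter.sum().clamp(min=1)

A matching unsupervised_loss_and_metric (or supervised_loss_and_metric) follows the same pattern and simply returns a metric value alongside the loss.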

Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between supervised and unsupervised training samples.

Arguments:
  • model [nn.Module] -
  • unsupervised_train_loader [torch.DataLoader] -
  • unsupervised_loss [callable] -
  • pseudo_labeler [callable] -
  • supervised_train_loader [torch.DataLoader] - (default: None)
  • unsupervised_val_loader [torch.DataLoader] - (default: None)
  • supervised_val_loader [torch.DataLoader] - (default: None)
  • supervised_loss [callable] - (default: None)
  • unsupervised_loss_and_metric [callable] - (default: None)
  • supervised_loss_and_metric [callable] - (default: None)
  • logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger)
  • source_distribution [list] - (default: None)
  • **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer
FixMatchTrainer(model, unsupervised_train_loader, unsupervised_loss, pseudo_labeler, supervised_train_loader=None, unsupervised_val_loader=None, supervised_val_loader=None, supervised_loss=None, unsupervised_loss_and_metric=None, supervised_loss_and_metric=None, logger=SelfTrainingTensorboardLogger, source_distribution=None, **kwargs)
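Putting the pieces together, a hedged end-to-end sketch that reuses the loader and callables sketched above and trains on unlabeled data only. The UNet2d model and the name, optimizer and device keyword arguments are assumptions about what the underlying DefaultTrainer accepts via **kwargs; adapt them to your setup:

import torch
from torch_em.model import UNet2d
from torch_em.self_training.fix_match import FixMatchTrainer
from torch_em.self_training.logger import SelfTrainingTensorboardLogger


def masked_bce_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    # Reuses the loss sketched above; for simplicity the loss value also serves as the metric.
    loss = masked_bce_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss


model = UNet2d(in_channels=1, out_channels=1)

trainer = FixMatchTrainer(
    name="fixmatch-example",                              # assumed DefaultTrainer kwarg
    model=model,
    unsupervised_train_loader=unsupervised_train_loader,  # weak/strong pair loader from the sketch above
    unsupervised_val_loader=unsupervised_val_loader,      # same format, built from held-out images
    unsupervised_loss=masked_bce_loss,
    unsupervised_loss_and_metric=masked_bce_loss_and_metric,
    pseudo_labeler=ThresholdPseudoLabeler(confidence_threshold=0.9),
    logger=SelfTrainingTensorboardLogger,
    optimizer=torch.optim.Adam(model.parameters(), lr=1.0e-4),  # assumed DefaultTrainer kwarg
    device="cuda" if torch.cuda.is_available() else "cpu",      # assumed DefaultTrainer kwarg
)
trainer.fit(iterations=10000)  # run for a fixed number of iterations (assumed DefaultTrainer API)

To use the semi-supervised strategy instead, also pass supervised_train_loader and supervised_loss; the batch sizes of the two train loaders then control the supervised/unsupervised sample ratio mentioned in the note above.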
Instance attributes:
  • unsupervised_train_loader
  • supervised_train_loader
  • supervised_val_loader
  • unsupervised_val_loader
  • supervised_loss_and_metric
  • unsupervised_loss_and_metric
  • unsupervised_loss
  • supervised_loss
  • pseudo_labeler
def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
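When a source_distribution is passed to the trainer, this method rescales the pseudo-labels so that their binary class frequencies (obtained by thresholding at label_threshold) move toward the class frequencies of the labeled source data, and clips the result to [0, 1]. A purely illustrative numeric sketch of that computation:

import torch

# Assume the labeled source data is 90% background / 10% foreground,
# while the current pseudo-labels are 95% / 5% after thresholding at 0.5.
source_distribution = torch.tensor([0.9, 0.1])
target_distribution = torch.tensor([0.95, 0.05])
distribution_ratio = source_distribution / target_distribution  # ~[0.947, 2.0]

pseudo_labels = torch.tensor([0.2, 0.4, 0.6, 0.8])
aligned = torch.where(
    pseudo_labels < 0.5,
    pseudo_labels * distribution_ratio[0],  # background scores are damped slightly
    pseudo_labels * distribution_ratio[1],  # foreground scores are boosted
).clip(0, 1)
# aligned ~= [0.19, 0.38, 1.0, 1.0]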