torch_em.self_training.fix_match
import time
from typing import Callable, List, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from .mean_teacher import Dummy
from ..transform.invertible_augmentations import InvertibleAugmenter


class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.

    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685.
    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
    on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The following arguments can be used to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: model, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch sizes of 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between unsupervised and supervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        source_distribution: The ratio of labels in the source label distribution.
            If given, the predicted distribution of the trained model will be regularized to
            match this source label distribution.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution: Optional[List[float]] = None,
        **kwargs,
    ):
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler

        if source_distribution is None:
            self.source_distribution = None
        else:
            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)

        self._kwargs = kwargs

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    # Distribution alignment:
    # Encourages the distribution of the model's generated pseudo labels to match the marginal
    # label distribution of the source domain (key idea: maximize the mutual information).
    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
        """@private
        """
        if self.source_distribution is not None:
            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
            target_distribution = target_distribution / target_distribution.sum()
            distribution_ratio = self.source_distribution / target_distribution
            pseudo_labels = torch.where(
                pseudo_labels < label_threshold,
                pseudo_labels * distribution_ratio[0],
                pseudo_labels * distribution_ratio[1]
            ).clip(0, 1)

        return pseudo_labels

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric


class FixMatchTrainerWithInvertibleAugmentations(FixMatchTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach,
    using invertible augmentations to generate the teacher and student views.

    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685.
    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
    on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns the raw input; the weak (teacher) and strong (student)
      views are generated by the invertible augmenter.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The augmenter defines separate invertible transforms for teacher and student inputs.
    Teacher and student views are generated independently, and the corresponding inverse
    transforms map predictions and pseudo-labels back into a shared reference frame.

    The following arguments can be used to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: model, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels; in this trainer
      both are mapped back into the reference frame before the loss is computed
        - Parameters: prediction, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: prediction, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: prediction, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: prediction, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch sizes of 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between unsupervised and supervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        augmenter: Invertible augmenter providing separate teacher and student transforms,
            including corresponding inverse transforms to align predictions in a common frame.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        source_distribution: The ratio of labels in the source label distribution.
            If given, the predicted distribution of the trained model will be regularized to
            match this source label distribution.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        augmenter: InvertibleAugmenter,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(
            model=model,
            unsupervised_train_loader=unsupervised_train_loader,
            unsupervised_loss=unsupervised_loss,
            pseudo_labeler=pseudo_labeler,
            supervised_train_loader=supervised_train_loader,
            unsupervised_val_loader=unsupervised_val_loader,
            supervised_val_loader=supervised_val_loader,
            supervised_loss=supervised_loss,
            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
            supervised_loss_and_metric=supervised_loss_and_metric,
            logger=logger,
            source_distribution=source_distribution,
            **kwargs,
        )

        self.augmenter = augmenter

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu in self.unsupervised_train_loader:
            self.augmenter.reset_all()
            xu = xu.to(self.device, non_blocking=True)

            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels, then invert into the reference frame.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
            pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
            label_filter_inv = (
                self.augmenter.teacher.reverse_transform(label_filter)
                if label_filter is not None else None
            )

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss = self.unsupervised_loss(pred_inv, pseudo_labels_inv, label_filter_inv)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = pred if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu, xu, pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, pred,
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), xu in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            self.augmenter.reset_all()
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu = xu.to(self.device, non_blocking=True)

            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # Here the prediction is computed first and passed to the supervised loss,
                # since it is also needed for logging.
                supervised_pred = self.model(xs)
                supervised_loss = self.supervised_loss(supervised_pred, ys)

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for pseudo labels, then invert into the reference frame.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
            pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
            label_filter_inv = (
                self.augmenter.teacher.reverse_transform(label_filter)
                if label_filter is not None else None
            )

            # Perform unsupervised training.
            with forward_context():
                unsup_pred = self.model(model_input)
                unsup_pred_inv = self.augmenter.student.reverse_transform(unsup_pred)
                unsupervised_loss = self.unsupervised_loss(unsup_pred_inv, pseudo_labels_inv, label_filter_inv)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = unsup_pred if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = supervised_pred if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu, xu, unsup_pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, unsup_pred,
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                pred = self.model(x)
                loss, metric = self.supervised_loss_and_metric(pred, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x in self.unsupervised_val_loader:
            self.augmenter.reset_all()
            x = x.to(self.device, non_blocking=True)

            x1, x2 = self.augmenter.teacher.transform(x), self.augmenter.student.transform(x)
            teacher_input, model_input = x1, x2

            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss, metric = self.unsupervised_loss_and_metric(pred_inv, pseudo_labels_inv, label_filter_inv)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv
            )
            self.logger.log_validation_augmentations(
                self._iteration, x1, x2, pseudo_labels, pred,
            )

        self.pseudo_labeler.step(metric_val, self._epoch)

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric
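To make the callable interfaces documented above concrete, here is a minimal sketch of a pseudo labeler and matching unsupervised loss functions for FixMatchTrainer. ThresholdPseudoLabeler, masked_unsupervised_loss and unsupervised_loss_and_metric are illustrative names, not part of torch_em; they only follow the signatures from the class docstring. The trainers additionally read pseudo_labeler.confidence_threshold for logging, and the invertible-augmentation variant calls pseudo_labeler.step after unsupervised validation.

    import torch

    class ThresholdPseudoLabeler:
        """Illustrative pseudo labeler following the documented interface."""

        def __init__(self, confidence_threshold=0.9):
            # The trainers read this attribute for logging; set it to None to disable.
            self.confidence_threshold = confidence_threshold

        def __call__(self, model, teacher_input):
            # Predict on the (weakly augmented) teacher input.
            pseudo_labels = torch.sigmoid(model(teacher_input))
            if self.confidence_threshold is None:
                return pseudo_labels, None
            # Keep only confident pixels; the label filter acts as a binary mask.
            confident = (pseudo_labels >= self.confidence_threshold) | \
                (pseudo_labels <= 1.0 - self.confidence_threshold)
            return pseudo_labels, confident.float()

        def step(self, metric, epoch):
            # Hook called after unsupervised validation in the invertible-augmentation
            # trainer; a no-op here, but it could implement a threshold schedule.
            pass

    def masked_unsupervised_loss(model, model_input, pseudo_labels, label_filter):
        # Consistency loss between the student prediction on the strongly augmented
        # input and the pseudo labels, restricted to the confident region.
        pred = torch.sigmoid(model(model_input))
        loss = torch.nn.functional.mse_loss(pred, pseudo_labels, reduction="none")
        if label_filter is not None:
            loss = loss * label_filter
        return loss.mean()

    def unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
        # Reuse the training loss and report it as the validation metric as well.
        loss = masked_unsupervised_loss(model, model_input, pseudo_labels, label_filter)
        return loss, loss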
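Putting these pieces together, a hedged usage sketch for purely unsupervised domain adaptation with FixMatchTrainer; the model and the loaders are placeholders, and the remaining keyword arguments (name, optimizer, ...) are forwarded to torch_em.trainer.DefaultTrainer as usual.

    # model: a segmentation network, e.g. a 2D U-Net; the loaders are placeholders.
    # unsupervised_train_loader / unsupervised_val_loader must yield (weak, strong)
    # augmentation pairs of the same input, as described in the class docstring.
    trainer = FixMatchTrainer(
        name="fixmatch-domain-adaptation",
        model=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
        unsupervised_train_loader=unsupervised_train_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        unsupervised_loss=masked_unsupervised_loss,
        unsupervised_loss_and_metric=unsupervised_loss_and_metric,
        pseudo_labeler=ThresholdPseudoLabeler(confidence_threshold=0.9),
    )
    trainer.fit(iterations=10_000)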
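The distribution alignment in get_distribution_alignment rescales pseudo labels so that the predicted foreground/background ratio moves towards the source distribution. Below is a small numeric check of that computation with values chosen for illustration; note that the in-trainer version derives target_distribution via torch.unique(..., return_counts=True), which assumes both classes actually occur in the batch.

    import torch

    # Suppose 20% of the source pixels are foreground, but the current pseudo
    # labels are foreground (>= 0.5) for 40% of the pixels.
    source_distribution = torch.tensor([0.8, 0.2])  # [background, foreground]
    target_distribution = torch.tensor([0.6, 0.4])
    distribution_ratio = source_distribution / target_distribution  # tensor([1.3333, 0.5000])

    pseudo_labels = torch.tensor([0.3, 0.7, 0.9, 0.4])
    aligned = torch.where(
        pseudo_labels < 0.5,
        pseudo_labels * distribution_ratio[0],
        pseudo_labels * distribution_ratio[1],
    ).clip(0, 1)
    # aligned -> tensor([0.4000, 0.3500, 0.4500, 0.5333]): over-represented
    # foreground predictions are damped towards the source ratio.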
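One detail worth spelling out: in the semi-supervised mode, _train_epoch_semisupervised zips both train loaders, so an epoch covers only as many batches as the shorter loader provides. This is why the constructor picks the shorter of the two loaders as train_loader for epoch bookkeeping. A tiny illustration of the zip behavior:

    supervised_batches = ["s0", "s1", "s2"]
    unsupervised_batches = ["u0", "u1", "u2", "u3", "u4"]
    # zip stops at the shorter iterable -> 3 combined batches per epoch.
    print(len(list(zip(supervised_batches, unsupervised_batches))))  # 3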
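For FixMatchTrainerWithInvertibleAugmentations, the trainer only relies on a small interface of the augmenter: augmenter.teacher and augmenter.student, each with transform and reverse_transform, plus augmenter.reset_all() once per batch. Below is a minimal stand-in with that interface using random 90-degree rotations; this is a sketch, not the actual InvertibleAugmenter from torch_em.transform.invertible_augmentations.

    import torch

    class RandomRotationView:
        """One invertible view: a randomly drawn multiple of a 90-degree rotation."""

        def __init__(self):
            self.k = 0

        def reset(self):
            # Draw new augmentation parameters; called via reset_all once per batch.
            self.k = int(torch.randint(0, 4, (1,)))

        def transform(self, x):
            return torch.rot90(x, self.k, dims=(-2, -1))

        def reverse_transform(self, x):
            # Exact inverse, so predictions map back into the reference frame.
            return torch.rot90(x, -self.k, dims=(-2, -1))

    class SimpleInvertibleAugmenter:
        """Sketch of the interface the trainer expects from `augmenter`."""

        def __init__(self):
            self.teacher = RandomRotationView()
            self.student = RandomRotationView()

        def reset_all(self):
            self.teacher.reset()
            self.student.reset()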
372class FixMatchTrainerWithInvertibleAugmentations(FixMatchTrainer): 373 """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach. 374 375 FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685). 376 It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels 377 on unlabeled data. We support two training strategies: 378 - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function). 379 - Taining only on the unsupervised data. 380 381 This class expects the following data loaders: 382 - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input. 383 - supervised_train_loader (optional): Returns input and labels. 384 - unsupervised_val_loader (optional): Same as unsupervised_train_loader 385 - supervised_val_loader (optional): Same as supervised_train_loader 386 At least one of unsupervised_val_loader and supervised_val_loader must be given. 387 388 The augmenter defines separate invertible transforms for teacher and student inputs. 389 Teacher and student views are generated independently, and the corresponding inverse 390 transforms map predictions and pseudo-labels back into a shared reference frame. 391 392 The following arguments can be used to customize the pseudo labeling: 393 - pseudo_labeler: to compute the psuedo-labels 394 - Parameters: model, teacher_input 395 - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None) 396 - unsupervised_loss: the loss between model predictions and pseudo labels 397 - Parameters: model, model_input, pseudo_labels, label_filter 398 - Returns: loss 399 - supervised_loss (optional): the supervised loss function 400 - Parameters: model, input, labels 401 - Returns: loss 402 - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric 403 - Parameters: model, model_input, pseudo_labels, label_filter 404 - Returns: loss, metric 405 - supervised_loss_and_metric (optional): the supervised loss function and metric 406 - Parameters: model, input, labels 407 - Returns: loss, metric 408 At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given. 409 410 Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' 411 for setting the ratio between supervised and unsupervised training samples. 412 413 Args: 414 model: The model to be trained. 415 unsupervised_train_loader: The loader for unsupervised training. 416 unsupervised_loss: The loss for unsupervised training. 417 pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training. 418 augmenter: Invertible augmenter providing separate teacher and student transforms, 419 including corresponding inverse transforms to align predictions in a common frame. 420 supervised_train_loader: The loader for supervised training. 421 supervised_loss: The loss for supervised training. 422 unsupervised_loss_and_metric: The loss and metric for unsupervised training. 423 supervised_loss_and_metric: The loss and metrhic for supervised training. 424 logger: The logger. 425 source_distribution: The ratio of labels in the source label distribution. 426 If given, the predicted distribution of the trained model will be regularized to 427 match this source label distribution. 428 kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`. 
429 """ 430 431 def __init__( 432 self, 433 model: torch.nn.Module, 434 unsupervised_train_loader: torch.utils.data.DataLoader, 435 unsupervised_loss: torch.utils.data.DataLoader, 436 pseudo_labeler: Callable, 437 augmenter: InvertibleAugmenter, 438 supervised_train_loader: Optional[torch.utils.data.DataLoader] = None, 439 unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None, 440 supervised_val_loader: Optional[torch.utils.data.DataLoader] = None, 441 supervised_loss: Optional[Callable] = None, 442 unsupervised_loss_and_metric: Optional[Callable] = None, 443 supervised_loss_and_metric: Optional[Callable] = None, 444 logger=SelfTrainingTensorboardLogger, 445 source_distribution: List[float] = None, 446 **kwargs, 447 ): 448 super().__init__( 449 model=model, 450 unsupervised_train_loader=unsupervised_train_loader, 451 unsupervised_loss=unsupervised_loss, 452 pseudo_labeler=pseudo_labeler, 453 supervised_train_loader=supervised_train_loader, 454 unsupervised_val_loader=unsupervised_val_loader, 455 supervised_val_loader=supervised_val_loader, 456 supervised_loss=supervised_loss, 457 unsupervised_loss_and_metric=unsupervised_loss_and_metric, 458 supervised_loss_and_metric=supervised_loss_and_metric, 459 logger=logger, 460 source_distribution=source_distribution, 461 **kwargs, 462 ) 463 464 self.augmenter = augmenter 465 466 # 467 # training and validation functionality 468 # 469 470 def _train_epoch_unsupervised(self, progress, forward_context, backprop): 471 self.model.train() 472 473 n_iter = 0 474 t_per_iter = time.time() 475 476 # Sample from both the supervised and unsupervised loader. 477 for xu in self.unsupervised_train_loader: 478 self.augmenter.reset_all() 479 xu = xu.to(self.device, non_blocking=True) 480 481 xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu) 482 teacher_input, model_input = xu1, xu2 483 484 with forward_context(), torch.no_grad(): 485 # Compute the pseudo labels. 486 pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input) 487 488 pseudo_labels = pseudo_labels.detach() 489 if label_filter is not None: 490 label_filter = label_filter.detach() 491 492 # Perform distribution alignment for pseudo labels, then invert into the reference frame. 
493 pseudo_labels = self.get_distribution_alignment(pseudo_labels) 494 pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels) 495 label_filter_inv = ( 496 self.augmenter.teacher.reverse_transform(label_filter) 497 if label_filter is not None else None 498 ) 499 500 self.optimizer.zero_grad() 501 # Perform unsupervised training 502 with forward_context(): 503 pred = self.model(model_input) 504 pred_inv = self.augmenter.student.reverse_transform(pred) 505 loss = self.unsupervised_loss(pred_inv, pseudo_labels_inv, label_filter_inv) 506 507 backprop(loss) 508 509 if self.logger is not None: 510 with torch.no_grad(), forward_context(): 511 pred = pred if self._iteration % self.log_image_interval == 0 else None 512 self.logger.log_train_unsupervised( 513 self._iteration, loss, xu, xu, pred_inv, pseudo_labels_inv, label_filter_inv 514 ) 515 self.logger.log_train_augmentations( 516 self._iteration, xu1, xu2, pseudo_labels, pred, 517 ) 518 lr = [pm["lr"] for pm in self.optimizer.param_groups][0] 519 self.logger.log_lr(self._iteration, lr) 520 if self.pseudo_labeler.confidence_threshold is not None: 521 self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold) 522 523 self._iteration += 1 524 n_iter += 1 525 if self._iteration >= self.max_iteration: 526 break 527 progress.update(1) 528 529 t_per_iter = (time.time() - t_per_iter) / n_iter 530 return t_per_iter 531 532 def _train_epoch_semisupervised(self, progress, forward_context, backprop): 533 self.model.train() 534 535 n_iter = 0 536 t_per_iter = time.time() 537 538 # Sample from both the supervised and unsupervised loader. 539 for (xs, ys), xu in zip(self.supervised_train_loader, self.unsupervised_train_loader): 540 self.augmenter.reset_all() 541 xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True) 542 xu = xu.to(self.device, non_blocking=True) 543 544 xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu) 545 teacher_input, model_input = xu1, xu2 546 547 # Perform supervised training. 548 self.optimizer.zero_grad() 549 with forward_context(): 550 # We pass the model, the input and the labels to the supervised loss function, 551 # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet. 552 supervised_pred = self.model(xs) 553 supervised_loss = self.supervised_loss(supervised_pred, ys) 554 555 with forward_context(), torch.no_grad(): 556 # Compute the pseudo labels. 557 pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input) 558 559 pseudo_labels = pseudo_labels.detach() 560 if label_filter is not None: 561 label_filter = label_filter.detach() 562 563 # Perform distribution alignment for pseudo labels, then invert into the reference frame. 
                pseudo_labels = self.get_distribution_alignment(pseudo_labels)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # Perform unsupervised training
            with forward_context():
                unsup_pred = self.model(model_input)
                unsup_pred_inv = self.augmenter.student.reverse_transform(unsup_pred)
                unsupervised_loss = self.unsupervised_loss(unsup_pred_inv, pseudo_labels_inv, label_filter_inv)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = unsup_pred if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = supervised_pred if self._iteration % self.log_image_interval == 0 else None

                    self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                    self.logger.log_train_unsupervised(
                        self._iteration, unsupervised_loss, xu, xu, unsup_pred_inv, pseudo_labels_inv, label_filter_inv
                    )
                    self.logger.log_train_augmentations(
                        self._iteration, xu1, xu2, pseudo_labels, unsup_pred,
                    )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                pred = self.model(x)
                loss, metric = self.supervised_loss_and_metric(pred, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x in self.unsupervised_val_loader:
            self.augmenter.reset_all()
            x = x.to(self.device, non_blocking=True)

            x1, x2 = self.augmenter.teacher.transform(x), self.augmenter.student.transform(x)
            teacher_input, model_input = x1, x2

            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss, metric = self.unsupervised_loss_and_metric(pred_inv, pseudo_labels_inv, label_filter_inv)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv
            )
            self.logger.log_validation_augmentations(
                self._iteration, x1, x2, pseudo_labels, pred,
            )

        self.pseudo_labeler.step(metric_val, self._epoch)

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():
            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

        if unsupervised_metric is None:
            metric = supervised_metric
        elif supervised_metric is None:
            metric = unsupervised_metric
        else:
            metric = (supervised_metric + unsupervised_metric) / 2

        return metric
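The get_distribution_alignment call at the top of this snippet regularizes the predicted pseudo-label distribution toward the source label distribution configured via source_distribution. As a rough illustration of such an alignment step (the helper below is hypothetical and assumes per-class probabilities in a (batch, classes, ...) tensor; it is not the library's implementation):

import torch

def align_distribution(probabilities: torch.Tensor, source_distribution: torch.Tensor, eps: float = 1e-6):
    """Hypothetical sketch: reweight class probabilities so that their batch
    marginal matches the given source distribution, then renormalize."""
    # Marginal class distribution predicted for this batch.
    reduce_dims = (0,) + tuple(range(2, probabilities.ndim))
    predicted = probabilities.mean(dim=reduce_dims)
    # Scale each class by the ratio of target to predicted frequency.
    weights = source_distribution / (predicted + eps)
    aligned = probabilities * weights.view(1, -1, *([1] * (probabilities.ndim - 2)))
    return aligned / (aligned.sum(dim=1, keepdim=True) + eps)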
Trainer for semi-supervised learning and domain adaptation following the FixMatch approach, using invertible augmentations to align teacher and student views.
FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685. It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels on unlabeled data. We support two training strategies:
- Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
- Training only on the unsupervised data.
This class expects the following data loaders:
- unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same as unsupervised_train_loader.
- supervised_val_loader (optional): Same as supervised_train_loader.
At least one of unsupervised_val_loader and supervised_val_loader must be given.
The augmenter defines separate invertible transforms for teacher and student inputs. Teacher and student views are generated independently, and the corresponding inverse transforms map predictions and pseudo-labels back into a shared reference frame.
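To make the inverse mapping concrete, below is a minimal sketch of a single invertible augmentation with the transform/reverse_transform interface the trainer relies on. This is a simplified stand-in, not the actual InvertibleAugmenter, which bundles several such transforms into separate teacher and student pipelines and resets their random state per batch via reset_all:

import torch

class InvertibleFlip:
    """Random horizontal flip that remembers its sampled state so the
    transformation can be undone on predictions."""

    def __init__(self):
        self.apply_flip = False

    def reset(self):
        # Sample a fresh random state; called once per batch.
        self.apply_flip = torch.rand(1).item() < 0.5

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flip(x, dims=(-1,)) if self.apply_flip else x

    def reverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        # A flip is its own inverse, so reversing applies the same flip again.
        return torch.flip(x, dims=(-1,)) if self.apply_flip else x

Because the sampled state is stored, a prediction made on the transformed view can be mapped back into the reference frame of the untransformed input, which is what allows teacher pseudo-labels and student predictions to be compared pixel-wise.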
The following arguments can be used to customize the pseudo labeling:
- pseudo_labeler: computes the pseudo-labels (see the sketch after this list)
- Parameters: model, teacher_input
- Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
- unsupervised_loss: the loss between model predictions and pseudo labels
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss
- supervised_loss (optional): the supervised loss function
- Parameters: model, input, labels
- Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric
- Parameters: model, input, labels
- Returns: loss, metric
At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
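For illustration, a confidence-thresholded pseudo-labeler that matches the (model, teacher_input) -> (pseudo_labels, label_filter) contract could look as follows. This is a hedged sketch for a binary segmentation setting, not the library's implementation; the confidence_threshold attribute and the step hook are included because the trainer reads the former for logging and calls the latter after unsupervised validation:

import torch

class ThresholdPseudoLabeler:
    """Sketch: binarize teacher predictions and mask low-confidence pixels."""

    def __init__(self, confidence_threshold: float = 0.9):
        self.confidence_threshold = confidence_threshold

    def __call__(self, model, teacher_input):
        with torch.no_grad():
            predictions = torch.sigmoid(model(teacher_input))
        pseudo_labels = (predictions > 0.5).float()
        # Keep only pixels where the prediction is confidently fore- or background.
        confidence = torch.maximum(predictions, 1.0 - predictions)
        label_filter = (confidence >= self.confidence_threshold).float()
        return pseudo_labels, label_filter

    def step(self, metric, epoch):
        # Hook called after unsupervised validation; could be used to anneal
        # the confidence threshold over training. Here it is a no-op.
        pass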
Note: adjust the batch size ratio between the 'unsupervised_train_loader' and the 'supervised_train_loader' to set the ratio between unsupervised and supervised training samples, as shown in the example below.
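For example, in joint training one batch is drawn from each loader per iteration, so the loader batch sizes determine the sample ratio (the dataset objects below are placeholders):

from torch.utils.data import DataLoader

# A 1:4 batch size ratio yields four unsupervised samples per supervised one.
supervised_train_loader = DataLoader(labeled_dataset, batch_size=1, shuffle=True)
unsupervised_train_loader = DataLoader(unlabeled_dataset, batch_size=4, shuffle=True)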
Arguments:
- model: The model to be trained.
- unsupervised_train_loader: The loader for unsupervised training.
- unsupervised_loss: The loss for unsupervised training.
- pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
- augmenter: Invertible augmenter providing separate teacher and student transforms, including corresponding inverse transforms to align predictions in a common frame.
- supervised_train_loader: The loader for supervised training.
- supervised_loss: The loss for supervised training.
- unsupervised_loss_and_metric: The loss and metric for unsupervised training.
- supervised_loss_and_metric: The loss and metric for supervised training.
- logger: The logger.
- source_distribution: The ratio of labels in the source label distribution. If given, the predicted distribution of the trained model will be regularized to match this source label distribution.
- kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        augmenter: InvertibleAugmenter,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(
            model=model,
            unsupervised_train_loader=unsupervised_train_loader,
            unsupervised_loss=unsupervised_loss,
            pseudo_labeler=pseudo_labeler,
            supervised_train_loader=supervised_train_loader,
            unsupervised_val_loader=unsupervised_val_loader,
            supervised_val_loader=supervised_val_loader,
            supervised_loss=supervised_loss,
            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
            supervised_loss_and_metric=supervised_loss_and_metric,
            logger=logger,
            source_distribution=source_distribution,
            **kwargs,
        )

        self.augmenter = augmenter
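A hedged end-to-end usage sketch: FixMatchTrainer below stands in for this augmenter-based subclass (whose name is not shown here), and all objects passed in (model, loaders, losses, pseudo-labeler, augmenter, optimizer) are placeholders constructed elsewhere. Remaining DefaultTrainer arguments, such as the optimizer and device, are forwarded via kwargs, and fit is assumed to accept an iteration count:

import torch

trainer = FixMatchTrainer(  # stands in for the augmenter-based subclass
    name="fixmatch-unsupervised",
    model=model,
    unsupervised_train_loader=unsupervised_train_loader,
    unsupervised_loss=unsupervised_loss,
    pseudo_labeler=pseudo_labeler,
    augmenter=augmenter,
    unsupervised_val_loader=unsupervised_val_loader,
    unsupervised_loss_and_metric=unsupervised_loss_and_metric,
    optimizer=optimizer,
    device=torch.device("cuda"),
)
trainer.fit(iterations=10_000)  # fit is inherited from DefaultTrainer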
Inherited Members
- FixMatchTrainer
- unsupervised_train_loader
- supervised_train_loader
- supervised_val_loader
- unsupervised_val_loader
- supervised_loss_and_metric
- unsupervised_loss_and_metric
- unsupervised_loss
- supervised_loss
- pseudo_labeler
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit