torch_em.self_training.mean_teacher
import time
from copy import deepcopy
from typing import Callable, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from ..transform.invertible_augmentations import InvertibleAugmenter


class Dummy(torch.nn.Module):
    init_kwargs = {}


class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the Mean Teacher approach.

    Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780.
    It uses a teacher model, derived from the student model via an EMA of the weights,
    to predict pseudo-labels on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    And the following elements to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels.
        - Parameters: teacher, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
    - unsupervised_loss: the loss between model predictions and pseudo-labels.
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function.
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric.
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric.
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
    If it is None, the most appropriate initialization scheme for the training approach is chosen:
    - semi-supervised training -> reinit, because we usually train a model from scratch
    - unsupervised training -> do not reinit, because we usually fine-tune a model

    Note: Adjust the ratio of batch sizes between the unsupervised_train_loader and the
    supervised_train_loader to set the ratio between unsupervised and supervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        momentum: The momentum value for the exponential moving weight average of the teacher model.
        reinit_teacher: Whether to reinit the teacher model before starting the training.
        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        momentum: float = 0.999,
        reinit_teacher: Optional[bool] = None,
        sampler: Optional[Callable] = None,
        **kwargs,
    ):
        self.sampler = sampler
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            # The shorter of the two loaders determines the epoch length.
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler
        self.momentum = momentum

        # Determine how we initialize the teacher weights (copy or reinitialization).
        if reinit_teacher is None:
            # Semi-supervised training: reinitialize; unsupervised training: copy.
            self.reinit_teacher = supervised_train_loader is not None
        else:
            self.reinit_teacher = reinit_teacher

        with torch.no_grad():
            self.teacher = deepcopy(self.model)
            if self.reinit_teacher:
                for layer in self.teacher.children():
                    if hasattr(layer, "reset_parameters"):
                        layer.reset_parameters()
            # The teacher is only updated via the momentum average, never via gradients.
            for param in self.teacher.parameters():
                param.requires_grad = False

        self._kwargs = kwargs

    def _momentum_update(self):
        # If we reinit the teacher we perform much faster updates (low momentum) in the first iterations,
        # to avoid a large gap between teacher and student weights, which would lead to inconsistent predictions.
        # If we don't reinit, this is not necessary.
        if self.reinit_teacher:
            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
        else:
            current_momentum = self.momentum

        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "teacher_state": self.teacher.state_dict(),
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        save_dict = super().load_checkpoint(checkpoint)
        self.teacher.load_state_dict(save_dict["teacher_state"])
        self.teacher.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.teacher.to(self.device)
        return best_metric

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
            if self.sampler is not None:
                keep_batch = self.sampler(pseudo_labels, label_filter)
                if not keep_batch:
                    continue

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and the unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        self.pseudo_labeler.step(metric_val, self._epoch)  # NOTE: scheduler added in validation step

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

            if unsupervised_metric is None:
                metric = supervised_metric
            elif supervised_metric is None:
                metric = unsupervised_metric
            else:
                metric = (supervised_metric + unsupervised_metric) / 2

        return metric


class MeanTeacherTrainerWithInvertibleAugmentations(MeanTeacherTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the Mean Teacher approach,
    using invertible augmentations to generate the teacher and student views.

    Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780.
    It uses a teacher model, derived from the student model via an EMA of the weights,
    to predict pseudo-labels on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns a single unaugmented input per sample; the two views
      are generated by the augmenter.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The augmenter defines separate invertible transforms for teacher and student inputs.
    Teacher and student views are generated independently, and the corresponding inverse
    transforms map predictions and pseudo-labels back into a shared reference frame.

    And the following elements to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels.
        - Parameters: teacher, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can for example be a mask, a weight, or None)
    - unsupervised_loss: the loss between the back-transformed model predictions and pseudo-labels.
        - Parameters: pred, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function.
        - Parameters: pred, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric.
        - Parameters: pred, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric.
        - Parameters: pred, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
    Note that, unlike in MeanTeacherTrainer, the loss functions receive the predictions directly
    instead of the model and its input (see the training implementations below).

    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
    If it is None, the most appropriate initialization scheme for the training approach is chosen:
    - semi-supervised training -> reinit, because we usually train a model from scratch
    - unsupervised training -> do not reinit, because we usually fine-tune a model

    Note: Adjust the ratio of batch sizes between the unsupervised_train_loader and the
    supervised_train_loader to set the ratio between unsupervised and supervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        augmenter: Invertible augmenter providing separate teacher and student transforms,
            including corresponding inverse transforms to align predictions in a common frame.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        momentum: The momentum value for the exponential moving weight average of the teacher model.
        reinit_teacher: Whether to reinit the teacher model before starting the training.
        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        augmenter: InvertibleAugmenter,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        momentum: float = 0.999,
        reinit_teacher: Optional[bool] = None,
        sampler: Optional[Callable] = None,
        **kwargs,
    ):
        super().__init__(
            model=model,
            unsupervised_train_loader=unsupervised_train_loader,
            unsupervised_loss=unsupervised_loss,
            pseudo_labeler=pseudo_labeler,
            supervised_train_loader=supervised_train_loader,
            unsupervised_val_loader=unsupervised_val_loader,
            supervised_val_loader=supervised_val_loader,
            supervised_loss=supervised_loss,
            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
            supervised_loss_and_metric=supervised_loss_and_metric,
            logger=logger,
            momentum=momentum,
            reinit_teacher=reinit_teacher,
            sampler=sampler,
            **kwargs,
        )
        self.augmenter = augmenter

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu in self.unsupervised_train_loader:
            self.augmenter.reset_all()
            xu = xu.to(self.device, non_blocking=True)

            # Generate the teacher and student views with independent invertible augmentations.
            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels and map them back into the common reference frame.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
            if self.sampler is not None:
                keep_batch = self.sampler(pseudo_labels, label_filter)
                if not keep_batch:
                    continue

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss = self.unsupervised_loss(pred_inv, pseudo_labels_inv, label_filter_inv)
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = pred if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu, xu, pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, pred,
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and the unsupervised loader.
        for (xs, ys), xu in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            self.augmenter.reset_all()
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu = xu.to(self.device, non_blocking=True)

            xu1, xu2 = self.augmenter.teacher.transform(xu), self.augmenter.student.transform(xu)
            teacher_input, model_input = xu1, xu2

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # Note: unlike in the base class, the loss here receives the prediction
                # instead of the model and its input.
                supervised_pred = self.model(xs)
                supervised_loss = self.supervised_loss(supervised_pred, ys)

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels and map them back into the common reference frame.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # Perform unsupervised training.
            with forward_context():
                unsup_pred = self.model(model_input)
                unsup_pred_inv = self.augmenter.student.reverse_transform(unsup_pred)
                unsupervised_loss = self.unsupervised_loss(unsup_pred_inv, pseudo_labels_inv, label_filter_inv)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = unsup_pred if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = supervised_pred if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu, xu, unsup_pred_inv, pseudo_labels_inv, label_filter_inv
                )
                self.logger.log_train_augmentations(
                    self._iteration, xu1, xu2, pseudo_labels, unsup_pred,
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                pred = self.model(x)
                loss, metric = self.supervised_loss_and_metric(pred, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x in self.unsupervised_val_loader:
            self.augmenter.reset_all()
            x = x.to(self.device, non_blocking=True)
            x1, x2 = self.augmenter.teacher.transform(x), self.augmenter.student.transform(x)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                pseudo_labels_inv = self.augmenter.teacher.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.teacher.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

                pred = self.model(model_input)
                pred_inv = self.augmenter.student.reverse_transform(pred)
                loss, metric = self.unsupervised_loss_and_metric(pred_inv, pseudo_labels_inv, label_filter_inv)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv
            )
            self.logger.log_validation_augmentations(
                self._iteration, x1, x2, pseudo_labels, pred,
            )

        self.pseudo_labeler.step(metric_val, self._epoch)

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

            if unsupervised_metric is None:
                metric = supervised_metric
            elif supervised_metric is None:
                metric = unsupervised_metric
            else:
                metric = (supervised_metric + unsupervised_metric) / 2

        return metric
class Dummy(torch.nn.Module)

A placeholder torch.nn.Module without functionality of its own. The trainers below pass Dummy instances as the loss and metric required by torch_em.trainer.DefaultTrainer, since they handle loss and metric computation themselves.
class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer)

Trainer for semi-supervised learning and domain adaptation following the Mean Teacher approach.
Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780. It uses a teacher model, derived from the student model via an EMA of the weights, to predict pseudo-labels on unlabeled data. We support two training strategies:
- Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
- Training only on the unsupervised data.
This class expects the following data loaders:
- unsupervised_train_loader: Returns two augmentations of the same input.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same as unsupervised_train_loader
- supervised_val_loader (optional): Same as supervised_train_loader
At least one of unsupervised_val_loader and supervised_val_loader must be given. A minimal sketch of a dataset that yields the paired augmented views expected by the unsupervised loaders follows below.
And the following elements to customize the pseudo labeling:
- pseudo_labeler: to compute the pseudo-labels (see the sketch after this list)
- Parameters: teacher, teacher_input
- Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
- unsupervised_loss: the loss between model predictions and pseudo labels
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss
- supervised_loss (optional): the supervised loss function
- Parameters: model, input, labels
- Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric
- Parameters: model, input, labels
- Returns: loss, metric
At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
If the parameter reinit_teacher is set to True, the teacher weights are re-initialized. If it is None, the most appropriate initialization scheme for the training approach is chosen:
- semi-supervised training -> reinit, because we usually train a model from scratch
- unsupervised training -> do not reinit, because we usually fine-tune a model
Note: Adjust the ratio of batch sizes between the unsupervised_train_loader and the supervised_train_loader to set the ratio between unsupervised and supervised training samples.
Arguments:
- model: The model to be trained.
- unsupervised_train_loader: The loader for unsupervised training.
- unsupervised_loss: The loss for unsupervised training.
- pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
- supervised_train_loader: The loader for supervised training.
- supervised_loss: The loss for supervised training.
- unsupervised_loss_and_metric: The loss and metric for unsupervised training.
- supervised_loss_and_metric: The loss and metric for supervised training.
- logger: The logger.
- momentum: The momentum value for the exponential moving weight average of the teacher model.
- reinit_teacher: Whether to reinit the teacher model before starting the training.
- sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
- kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit
class MeanTeacherTrainerWithInvertibleAugmentations(MeanTeacherTrainer)

Mean Teacher trainer that generates the teacher and student views with invertible augmentations.

The augmenter defines separate invertible transforms for teacher and student inputs. Teacher and student views are generated independently, and the corresponding inverse transforms map predictions and pseudo-labels back into a shared reference frame. Consequently, the unsupervised loaders return a single unaugmented input per sample, from which both views are derived, and the loss functions receive the back-transformed predictions directly: supervised_loss is called as supervised_loss(pred, labels) and unsupervised_loss as unsupervised_loss(pred, pseudo_labels, label_filter), unlike in MeanTeacherTrainer. All other loaders, loss functions, and parameters behave as documented for MeanTeacherTrainer above.

Arguments (in addition to those of MeanTeacherTrainer):
- augmenter: Invertible augmenter providing separate teacher and student transforms, including the corresponding inverse transforms to align predictions in a common frame. A minimal sketch of the interface the trainer relies on follows below.
self.logger is not None: 691 self.logger.log_validation_unsupervised( 692 self._iteration, metric_val, loss_val, x, x, pred_inv, pseudo_labels_inv, label_filter_inv 693 ) 694 self.logger.log_validation_augmentations( 695 self._iteration, x1, x2, pseudo_labels, pred, 696 ) 697 698 self.pseudo_labeler.step(metric_val, self._epoch) 699 700 return metric_val 701 702 def _validate_impl(self, forward_context): 703 self.model.eval() 704 705 with torch.no_grad(): 706 707 if self.supervised_val_loader is None: 708 supervised_metric = None 709 else: 710 supervised_metric = self._validate_supervised(forward_context) 711 712 if self.unsupervised_val_loader is None: 713 unsupervised_metric = None 714 else: 715 unsupervised_metric = self._validate_unsupervised(forward_context) 716 717 if unsupervised_metric is None: 718 metric = supervised_metric 719 elif supervised_metric is None: 720 metric = unsupervised_metric 721 else: 722 metric = (supervised_metric + unsupervised_metric) / 2 723 724 return metric
Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.
Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780. It uses a teacher model derived from the student model via an EMA of its weights (sketched below) to predict pseudo-labels on unlabeled data. We support two training strategies:
- Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
- Training only on the unsupervised data.
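As a sketch, the momentum update that derives the teacher from the student amounts to the following (schematic only; `_momentum_update` is defined on the parent trainer and this is not its verbatim implementation, and `teacher` / `model` here denote the teacher and student networks):

    with torch.no_grad():
        for teacher_param, student_param in zip(teacher.parameters(), model.parameters()):
            # theta_teacher <- momentum * theta_teacher + (1 - momentum) * theta_student
            teacher_param.data.mul_(momentum).add_(student_param.data, alpha=1 - momentum)

With the default momentum of 0.999, the teacher changes slowly and effectively averages the student over roughly the last thousand iterations.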
This class expects the following data loaders:
- unsupervised_train_loader: Returns the raw input; the augmenter generates the teacher and student views from it.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same as unsupervised_train_loader.
- supervised_val_loader (optional): Same as supervised_train_loader.
At least one of unsupervised_val_loader and supervised_val_loader must be given. A minimal example of an unlabeled loader follows below.
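For illustration, an unlabeled dataset satisfying this contract could look like the following (class name and shapes are hypothetical; the trainer only requires that the loader yields a single input tensor per batch):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class UnlabeledImages(Dataset):
        """Yields single input tensors; the augmenter creates the two views later."""
        def __init__(self, images):
            self.images = images  # e.g. a list of (C, H, W) float tensors

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            return self.images[idx]

    images = [torch.rand(1, 256, 256) for _ in range(100)]
    unsupervised_train_loader = DataLoader(UnlabeledImages(images), batch_size=8, shuffle=True)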
The augmenter defines separate invertible transforms for teacher and student inputs. Teacher and student views are generated independently, and the corresponding inverse transforms map predictions and pseudo-labels back into a shared reference frame.
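Concretely, the training loop relies on the following round-trip contract (API usage exactly as in the source above; `augmenter` is an `InvertibleAugmenter` and `x` a raw input batch; the exact-inversion assert assumes lossless geometric transforms such as flips and 90-degree rotations):

    augmenter.reset_all()                          # draw fresh random transform parameters
    teacher_view = augmenter.teacher.transform(x)  # independent teacher view
    student_view = augmenter.student.transform(x)  # independent student view

    # Predictions made in a view can be mapped back to the frame of the raw input,
    # so teacher pseudo-labels and student predictions become directly comparable.
    roundtrip = augmenter.teacher.reverse_transform(teacher_view)
    assert torch.allclose(roundtrip, x)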
It also expects the following elements to customize the pseudo labeling:
- pseudo_labeler: computes the pseudo-labels.
  - Parameters: teacher, teacher_input
  - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
- unsupervised_loss: the loss between model predictions and pseudo-labels.
  - Parameters: prediction, pseudo_labels, label_filter (prediction and pseudo-labels are mapped back into the shared reference frame before the loss is computed)
  - Returns: loss
- supervised_loss (optional): the supervised loss function.
  - Parameters: prediction, labels
  - Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric.
  - Parameters: prediction, pseudo_labels, label_filter
  - Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric.
  - Parameters: prediction, labels
  - Returns: loss, metric
At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
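A minimal sketch of these callables for a binary segmentation setting (class name, threshold value, and loss choice are illustrative; the trainer only relies on the call signatures plus the `confidence_threshold` attribute and the `step` hook accessed in the loops above):

    import torch

    class SimplePseudoLabeler:
        """Hypothetical pseudo labeler: thresholded teacher predictions plus a confidence mask."""
        def __init__(self, confidence_threshold=0.9):
            self.confidence_threshold = confidence_threshold

        def __call__(self, teacher, teacher_input):
            probs = torch.sigmoid(teacher(teacher_input))
            pseudo_labels = (probs > 0.5).float()
            # Only pixels where the teacher is confident contribute to the loss.
            label_filter = None
            if self.confidence_threshold is not None:
                confident = (probs > self.confidence_threshold) | (probs < 1.0 - self.confidence_threshold)
                label_filter = confident.float()
            return pseudo_labels, label_filter

        def step(self, metric, epoch):
            pass  # hook called after unsupervised validation, e.g. to anneal the threshold

    def unsupervised_loss(prediction, pseudo_labels, label_filter):
        # The prediction is the inverse-aligned student output (see the training loop).
        loss = torch.nn.functional.mse_loss(torch.sigmoid(prediction), pseudo_labels, reduction="none")
        if label_filter is not None:
            loss = loss * label_filter
        return loss.mean()

    def unsupervised_loss_and_metric(prediction, pseudo_labels, label_filter):
        loss = unsupervised_loss(prediction, pseudo_labels, label_filter)
        return loss, loss  # reuse the loss as metric for simplicity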
If the parameter reinit_teacher is set to True, the teacher weights are re-initialized. If it is None, the most appropriate initialization scheme for the training approach is chosen:
- semi-supervised training -> reinit, because we usually train a model from scratch
- unsupervised training -> do not reinit, because we usually fine-tune a model
Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between supervised and unsupervised training samples, for example:
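    from torch.utils.data import DataLoader

    # Two unsupervised samples per supervised sample (hypothetical datasets and sizes):
    supervised_train_loader = DataLoader(labeled_ds, batch_size=4, shuffle=True)
    unsupervised_train_loader = DataLoader(unlabeled_ds, batch_size=8, shuffle=True)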
Arguments:
- model: The model to be trained.
- unsupervised_train_loader: The loader for unsupervised training.
- unsupervised_loss: The loss for unsupervised training.
- pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
- augmenter: Invertible augmenter providing separate teacher and student transforms, including corresponding inverse transforms to align predictions in a common frame.
- supervised_train_loader: The loader for supervised training.
- supervised_loss: The loss for supervised training.
- unsupervised_loss_and_metric: The loss and metric for unsupervised training.
- supervised_loss_and_metric: The loss and metric for supervised training.
- logger: The logger.
- momentum: The momentum value for the exponential moving weight average of the teacher model.
- reinit_teacher: Whether to reinit the teacher model before starting the training.
- sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
- kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
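Putting it together, a hypothetical unsupervised domain-adaptation setup might look as follows (a sketch, not a prescribed recipe: `name` and `optimizer` are forwarded via kwargs to torch_em.trainer.DefaultTrainer, from which `fit` is inherited; the model is a stand-in, the augmenter construction is elided, and further DefaultTrainer arguments such as the device may be required in practice; `SimplePseudoLabeler` and the loss callables are the sketches above):

    import torch

    model = torch.nn.Conv2d(1, 1, 3, padding=1)  # stand-in for a real segmentation network
    augmenter = ...  # an InvertibleAugmenter instance, configured elsewhere

    trainer = MeanTeacherTrainerWithInvertibleAugmentations(
        name="mean-teacher-adaptation",
        model=model,
        unsupervised_train_loader=unsupervised_train_loader,
        unsupervised_loss=unsupervised_loss,
        pseudo_labeler=SimplePseudoLabeler(confidence_threshold=0.9),
        augmenter=augmenter,
        unsupervised_val_loader=unsupervised_val_loader,
        unsupervised_loss_and_metric=unsupervised_loss_and_metric,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    )
    trainer.fit(iterations=10_000)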