torch_em.self_training.mean_teacher
import time
from copy import deepcopy
from typing import Callable, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger


class Dummy(torch.nn.Module):
    init_kwargs = {}


class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.

    Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780.
    It uses a teacher model derived from the student model via EMA of weights
    to predict pseudo-labels on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
    - supervised_val_loader (optional): Same as supervised_train_loader
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    And the following elements to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
    - Parameters: teacher, teacher_input
    - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels
    - Parameters: model, model_input, pseudo_labels, label_filter
    - Returns: loss
    - supervised_loss (optional): the supervised loss function
    - Parameters: model, input, labels
    - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    - Parameters: model, model_input, pseudo_labels, label_filter
    - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
    - Parameters: model, input, labels
    - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
    If it is None, the most appropriate initialization scheme for the training approach is chosen:
    - semi-supervised training -> reinit, because we usually train a model from scratch
    - unsupervised training -> do not reinit, because we usually fine-tune a model

    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        momentum: The momentum value for the exponential moving weight average of the teacher model.
        reinit_teacher: Whether to reinit the teacher model before starting the training.
        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        momentum: float = 0.999,
        reinit_teacher: Optional[bool] = None,
        sampler: Optional[Callable] = None,
        **kwargs,
    ):
        self.sampler = sampler
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_train_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler
        self.momentum = momentum

        # Determine how we initialize the teacher weights (copy or reinitialization).
        if reinit_teacher is None:
            # semi-supervised training: reinitialize
            # unsupervised training: copy
            self.reinit_teacher = supervised_train_loader is not None
        else:
            self.reinit_teacher = reinit_teacher

        with torch.no_grad():
            self.teacher = deepcopy(self.model)
            if self.reinit_teacher:
                for layer in self.teacher.children():
                    if hasattr(layer, "reset_parameters"):
                        layer.reset_parameters()
            for param in self.teacher.parameters():
                param.requires_grad = False

        self._kwargs = kwargs

    def _momentum_update(self):
        # If we reinit the teacher we perform much faster updates (low momentum) in the first iterations
        # to avoid a large gap between teacher and student weights, leading to inconsistent predictions.
        # If we don't reinit this is not necessary.
        if self.reinit_teacher:
            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
        else:
            current_momentum = self.momentum

        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "teacher_state": self.teacher.state_dict(),
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    def load_checkpoint(self, checkpoint="best"):
        """@private
        """
        save_dict = super().load_checkpoint(checkpoint)
        self.teacher.load_state_dict(save_dict["teacher_state"])
        self.teacher.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.teacher.to(self.device)
        return best_metric

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # If we have a sampler then check if the current batch matches the condition for inclusion in training.
            if self.sampler is not None:
                keep_batch = self.sampler(pseudo_labels, label_filter)
                if not keep_batch:
                    continue

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

            if unsupervised_metric is None:
                metric = supervised_metric
            elif supervised_metric is None:
                metric = unsupervised_metric
            else:
                metric = (supervised_metric + unsupervised_metric) / 2

        return metric
class Dummy(torch.nn.Module)

Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
Variables: training (bool) – whether this module is in training or evaluation mode.
class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer)
Trainer for semi-supervised learning and domain adaptation following the MeanTeacher approach.
Mean Teacher was introduced by Tarvainen & Valpola in https://arxiv.org/abs/1703.01780. It uses a teacher model, derived from the student model via an EMA of its weights, to predict pseudo-labels on unlabeled data. We support two training strategies:
- Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
- Training only on the unsupervised data.
This class expects the following data loaders (a minimal sketch of the unsupervised loader follows this list):
- unsupervised_train_loader: Returns two augmentations of the same input.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same as unsupervised_train_loader.
- supervised_val_loader (optional): Same as supervised_train_loader.
At least one of unsupervised_val_loader and supervised_val_loader must be given.
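For illustration, here is a minimal sketch (plain PyTorch, not a torch_em helper) of a dataset whose loader yields two independently augmented views of the same unlabeled sample, as expected by unsupervised_train_loader; the dataset class, sample container, and augmentation callable are placeholders:

    from torch.utils.data import Dataset

    class TwoAugmentationDataset(Dataset):
        """Wraps unlabeled inputs and returns two random augmentations of each sample."""

        def __init__(self, raw_samples, augmentation):
            self.raw_samples = raw_samples      # sequence of unlabeled input tensors
            self.augmentation = augmentation    # callable applying a random augmentation

        def __len__(self):
            return len(self.raw_samples)

        def __getitem__(self, idx):
            x = self.raw_samples[idx]
            # Two independent augmentations of the same sample.
            return self.augmentation(x), self.augmentation(x)

    # unsupervised_train_loader = torch.utils.data.DataLoader(
    #     TwoAugmentationDataset(samples, my_augmentation), batch_size=1, shuffle=True
    # )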
And the following elements to customize the pseudo labeling (a minimal sketch of such callables follows this list):
- pseudo_labeler: to compute the pseudo-labels
  - Parameters: teacher, teacher_input
  - Returns: pseudo_labels, label_filter (<- the label filter can for example be a mask, a weight, or None)
- unsupervised_loss: the loss between model predictions and pseudo labels
  - Parameters: model, model_input, pseudo_labels, label_filter
  - Returns: loss
- supervised_loss (optional): the supervised loss function
  - Parameters: model, input, labels
  - Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
  - Parameters: model, model_input, pseudo_labels, label_filter
  - Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric
  - Parameters: model, input, labels
  - Returns: loss, metric
At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
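The sketch below shows callables with the interfaces listed above. It is only an illustration and assumes a binary segmentation model whose outputs are passed through a sigmoid and a fixed confidence threshold of 0.9; the ready-made components in torch_em.self_training (e.g. DefaultPseudoLabeler) follow the same call signatures:

    import torch

    def pseudo_labeler(teacher, teacher_input):
        # Predict pseudo-labels with the teacher and build a filter mask of confident pixels.
        with torch.no_grad():
            pred = torch.sigmoid(teacher(teacher_input))
        pseudo_labels = (pred > 0.5).float()
        # Keep only predictions that are confidently foreground or background.
        label_filter = torch.logical_or(pred > 0.9, pred < 0.1).float()
        return pseudo_labels, label_filter

    def unsupervised_loss(model, model_input, pseudo_labels, label_filter):
        # Masked binary cross-entropy between the student prediction and the pseudo-labels.
        pred = torch.sigmoid(model(model_input))
        loss = torch.nn.functional.binary_cross_entropy(pred, pseudo_labels, reduction="none")
        if label_filter is not None:
            loss = loss * label_filter
        return loss.mean()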
If the parameter reinit_teacher is set to True, the teacher weights are re-initialized (the resulting teacher update is sketched after this list). If it is None, the most appropriate initialization scheme for the training approach is chosen:
- semi-supervised training -> reinit, because we usually train a model from scratch
- unsupervised training -> do not reinit, because we usually fine-tune a model
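For reference, the teacher is updated as an exponential moving average of the student weights. The following sketch mirrors _momentum_update from the source listing above, including the momentum warm-up that is applied when the teacher is re-initialized:

    def ema_momentum(iteration, momentum=0.999, reinit_teacher=True):
        # With a re-initialized teacher, start with fast updates and anneal towards `momentum`.
        return min(1 - 1 / (iteration + 1), momentum) if reinit_teacher else momentum

    # For every parameter pair:
    #   teacher_param = m * teacher_param + (1 - m) * student_param, with m = ema_momentum(iteration)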
Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between supervised and unsupervised training samples.
Arguments:
- model: The model to be trained.
- unsupervised_train_loader: The loader for unsupervised training.
- unsupervised_loss: The loss for unsupervised training.
- pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
- supervised_train_loader: The loader for supervised training.
- supervised_loss: The loss for supervised training.
- unsupervised_loss_and_metric: The loss and metric for unsupervised training.
- supervised_loss_and_metric: The loss and metric for supervised training.
- logger: The logger.
- momentum: The momentum value for the exponential moving weight average of the teacher model.
- reinit_teacher: Whether to reinit the teacher model before starting the training.
- sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
- kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
MeanTeacherTrainer(
    model: torch.nn.Module,
    unsupervised_train_loader: torch.utils.data.DataLoader,
    unsupervised_loss: Callable,
    pseudo_labeler: Callable,
    supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
    unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
    supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
    supervised_loss: Optional[Callable] = None,
    unsupervised_loss_and_metric: Optional[Callable] = None,
    supervised_loss_and_metric: Optional[Callable] = None,
    logger=SelfTrainingTensorboardLogger,
    momentum: float = 0.999,
    reinit_teacher: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    **kwargs,
)
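A usage sketch for unsupervised domain adaptation with the constructor documented above: the model, the loaders, and the callables (see the sketches earlier on this page) are placeholders, and the checkpoint name is hypothetical; only the keyword arguments follow the interface documented here and in torch_em.trainer.DefaultTrainer:

    import torch
    from torch_em.self_training import MeanTeacherTrainer

    trainer = MeanTeacherTrainer(
        name="mean-teacher-adaptation",                       # hypothetical checkpoint name (DefaultTrainer kwarg)
        model=model,                                          # placeholder: your torch.nn.Module
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
        unsupervised_train_loader=unsupervised_train_loader,  # yields (augmentation_1, augmentation_2)
        unsupervised_val_loader=unsupervised_val_loader,
        unsupervised_loss=unsupervised_loss,                  # see the sketch above
        unsupervised_loss_and_metric=unsupervised_loss_and_metric,  # placeholder loss-and-metric callable
        pseudo_labeler=pseudo_labeler,                        # see the sketch above
        momentum=0.999,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    trainer.fit(iterations=10000)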
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit