torch_em.self_training.mean_teacher
import time
from copy import deepcopy

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger


class Dummy(torch.nn.Module):
    init_kwargs = {}


class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
    """This trainer implements self-training for semi-supervised learning and domain adaptation, following the
    'MeanTeacher' approach of Tarvainen & Valpola (https://arxiv.org/abs/1703.01780). This approach uses a teacher
    model, derived from the student model via EMA of weights, to predict pseudo-labels on unlabeled data.
    We support two training strategies: joint training on labeled and unlabeled data
    (with a supervised and an unsupervised loss function), and training on the unlabeled data only.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    And the following elements to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: teacher, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight or None)
    - unsupervised_loss: the loss between model predictions and pseudo-labels
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized;
    if it is False, they are copied from the student model.
    If it is None, the most appropriate initialization scheme for the training approach is chosen:
    - semi-supervised training -> reinit, because we usually train a model from scratch
    - unsupervised training -> do not reinit, because we usually fine-tune a model

    Note: adjust the batch sizes of the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between unsupervised and supervised training samples.

    Parameters:
        model [nn.Module] - the student model to be trained.
        unsupervised_train_loader [torch.DataLoader] - the loader returning two augmentations of the same input.
        unsupervised_loss [callable] - the loss between model predictions and pseudo-labels.
        pseudo_labeler [callable] - the callable that computes pseudo-labels from teacher predictions.
        supervised_train_loader [torch.DataLoader] - the loader returning input and labels. (default: None)
        unsupervised_val_loader [torch.DataLoader] - the validation loader with unsupervised data. (default: None)
        supervised_val_loader [torch.DataLoader] - the validation loader with supervised data. (default: None)
        supervised_loss [callable] - the supervised loss function. (default: None)
        unsupervised_loss_and_metric [callable] - the unsupervised loss function and metric. (default: None)
        supervised_loss_and_metric [callable] - the supervised loss function and metric. (default: None)
        logger [TorchEmLogger] - the logger. (default: SelfTrainingTensorboardLogger)
        momentum [float] - the momentum for the EMA update of the teacher weights. (default: 0.999)
        reinit_teacher [bool] - whether to re-initialize the teacher weights. (default: None)
        **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer.
    """

    def __init__(
        self,
        model,
        unsupervised_train_loader,
        unsupervised_loss,
        pseudo_labeler,
        supervised_train_loader=None,
        unsupervised_val_loader=None,
        supervised_val_loader=None,
        supervised_loss=None,
        unsupervised_loss_and_metric=None,
        supervised_loss_and_metric=None,
        logger=SelfTrainingTensorboardLogger,
        momentum=0.999,
        reinit_teacher=None,
        **kwargs
    ):
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler
        self.momentum = momentum

        # determine how we initialize the teacher weights (copy or reinitialization)
        if reinit_teacher is None:
            # semi-supervised training: reinitialize
            # unsupervised training: copy
            self.reinit_teacher = supervised_train_loader is not None
        else:
            self.reinit_teacher = reinit_teacher

        with torch.no_grad():
            self.teacher = deepcopy(self.model)
            if self.reinit_teacher:
                for layer in self.teacher.children():
                    if hasattr(layer, "reset_parameters"):
                        layer.reset_parameters()
            for param in self.teacher.parameters():
                param.requires_grad = False

        self._kwargs = kwargs

    def _momentum_update(self):
        # if we reinit the teacher we perform much faster updates (low momentum) in the first iterations
        # to avoid a large gap between teacher and student weights, leading to inconsistent predictions
        # if we don't reinit this is not necessary
        if self.reinit_teacher:
            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
        else:
            current_momentum = self.momentum

        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "teacher_state": self.teacher.state_dict(),
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    def load_checkpoint(self, checkpoint="best"):
        save_dict = super().load_checkpoint(checkpoint)
        self.teacher.load_state_dict(save_dict["teacher_state"])
        self.teacher.to(self.device)
        return save_dict

    def _initialize(self, iterations, load_from_checkpoint, epochs=None):
        best_metric = super()._initialize(iterations, load_from_checkpoint, epochs)
        self.teacher.to(self.device)
        return best_metric

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            self.optimizer.zero_grad()
            # Perform unsupervised training
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device), ys.to(self.device)
            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

            # Perform unsupervised training
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            with torch.no_grad():
                self._momentum_update()

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device), y.to(self.device)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device), x2.to(self.device)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

            if unsupervised_metric is None:
                metric = supervised_metric
            elif supervised_metric is None:
                metric = unsupervised_metric
            else:
                metric = (supervised_metric + unsupervised_metric) / 2

        return metric
class Dummy(torch.nn.Module):
Placeholder module that the trainer passes as dummy loss and metric to the underlying DefaultTrainer.
class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
This trainer implements self-training for semi-supervised learning and domain adaptation, following the 'MeanTeacher' approach of Tarvainen & Valpola (https://arxiv.org/abs/1703.01780). The approach uses a teacher model, derived from the student model via an exponential moving average (EMA) of its weights, to predict pseudo-labels on unlabeled data. Two training strategies are supported: joint training on labeled and unlabeled data (with a supervised and an unsupervised loss function), and training on the unlabeled data only.
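The core of the approach is the EMA weight update: after every training iteration the teacher's parameters are moved towards the student's. A minimal sketch of the rule implemented by _momentum_update (m is the momentum, 0.999 by default):

import torch

@torch.no_grad()
def ema_update(student, teacher, m=0.999):
    # teacher <- m * teacher + (1 - m) * student, parameter by parameter
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_(p_s.data, alpha=1 - m)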
This class expects the following data loaders (a sketch of a compatible unsupervised dataset follows the list):
- unsupervised_train_loader: Returns two augmentations of the same input.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same as unsupervised_train_loader.
- supervised_val_loader (optional): Same as supervised_train_loader.
At least one of unsupervised_val_loader and supervised_val_loader must be given.
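For illustration, a minimal, hypothetical dataset that fulfills the unsupervised loader contract; weak_aug and strong_aug are placeholder augmentation callables, not torch_em APIs:

import torch
from torch.utils.data import Dataset

class TwoAugmentationDataset(Dataset):
    """Yields (aug1(x), aug2(x)) pairs for the unsupervised loaders."""
    def __init__(self, raw_images, weak_aug, strong_aug):
        self.raw_images = raw_images
        self.weak_aug = weak_aug
        self.strong_aug = strong_aug

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, idx):
        x = self.raw_images[idx]
        # The first augmentation is fed to the teacher, the second to the student.
        return self.weak_aug(x), self.strong_aug(x)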
And the following elements to customize the pseudo labeling (minimal sketches of these callables follow the list):
- pseudo_labeler: to compute the pseudo-labels
- Parameters: teacher, teacher_input
- Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
- unsupervised_loss: the loss between model predictions and pseudo labels
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss
- supervised_loss (optional): the supervised loss function
- Parameters: model, input, labels
- Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric
- Parameters: model, input, labels
- Returns: loss, metric
At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
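As purely illustrative sketches of these interfaces (not library code), a thresholded binary pseudo-labeler with a confidence mask and a matching masked loss might look like this:

import torch

def pseudo_labeler(teacher, teacher_input):
    # Threshold the teacher's sigmoid output to get hard pseudo-labels and
    # keep only confident pixels via the label filter (illustrative choice).
    with torch.no_grad():
        probs = torch.sigmoid(teacher(teacher_input))
    pseudo_labels = (probs > 0.5).float()
    label_filter = torch.logical_or(probs < 0.1, probs > 0.9).float()
    return pseudo_labels, label_filter

def unsupervised_loss(model, model_input, pseudo_labels, label_filter):
    # Masked binary cross-entropy between student predictions and pseudo-labels.
    pred = model(model_input)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, pseudo_labels, reduction="none")
    if label_filter is not None:
        loss = loss * label_filter
    return loss.mean()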
If the parameter reinit_teacher is set to True, the teacher weights are re-initialized; if it is False, they are copied from the student model. If it is None, the most appropriate initialization scheme for the training approach is chosen:
- semi-supervised training -> reinit, because we usually train a model from scratch
- unsupervised training -> do not reinit, because we usually fine-tune a model
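When the teacher is re-initialized, the implementation additionally warms up the EMA momentum over the first iterations, so that the freshly initialized teacher quickly catches up with the student (taken from _momentum_update in the source above):

def warmup_momentum(iteration, momentum=0.999):
    # 0.0 at iteration 0 (the teacher is fully overwritten by the student),
    # 0.9 at iteration 9, 0.99 at iteration 99, capped at `momentum` from iteration 999 on.
    return min(1 - 1 / (iteration + 1), momentum)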
Note: adjust the batch sizes of the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between unsupervised and supervised training samples, as in the sketch below.
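For example, under the toy-data assumptions below (reusing the hypothetical TwoAugmentationDataset sketched earlier), each semi-supervised iteration draws one labeled and four unlabeled samples:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: a 1:4 supervised-to-unsupervised sample ratio per iteration.
labeled = TensorDataset(
    torch.randn(8, 1, 64, 64),                    # inputs
    torch.randint(0, 2, (8, 1, 64, 64)).float(),  # binary labels
)
unlabeled = TwoAugmentationDataset(
    torch.randn(32, 1, 64, 64),
    weak_aug=lambda x: x,
    strong_aug=lambda x: x + 0.1 * torch.randn_like(x),
)
supervised_train_loader = DataLoader(labeled, batch_size=1, shuffle=True)
unsupervised_train_loader = DataLoader(unlabeled, batch_size=4, shuffle=True)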
Parameters:
- model [nn.Module] - the student model to be trained.
- unsupervised_train_loader [torch.DataLoader] - the loader returning two augmentations of the same input.
- unsupervised_loss [callable] - the loss between model predictions and pseudo-labels.
- pseudo_labeler [callable] - the callable that computes pseudo-labels from teacher predictions.
- supervised_train_loader [torch.DataLoader] - the loader returning input and labels. (default: None)
- unsupervised_val_loader [torch.DataLoader] - the validation loader with unsupervised data. (default: None)
- supervised_val_loader [torch.DataLoader] - the validation loader with supervised data. (default: None)
- supervised_loss [callable] - the supervised loss function. (default: None)
- unsupervised_loss_and_metric [callable] - the unsupervised loss function and metric. (default: None)
- supervised_loss_and_metric [callable] - the supervised loss function and metric. (default: None)
- logger [TorchEmLogger] - the logger. (default: SelfTrainingTensorboardLogger)
- momentum [float] - the momentum for the EMA update of the teacher weights. (default: 0.999)
- reinit_teacher [bool] - whether to re-initialize the teacher weights. (default: None)
- **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer.
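Putting it together, a minimal end-to-end sketch for unsupervised fine-tuning, reusing the hypothetical objects from the sketches above (the experiment name is illustrative; the extra keyword arguments are forwarded to torch_em.trainer.DefaultTrainer, and fit is inherited from it):

import torch
from torch.utils.data import DataLoader
from torch_em.self_training.mean_teacher import MeanTeacherTrainer

# Stand-in student model; in practice this would e.g. be a U-Net.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 1, 3, padding=1),
)

# Illustrative validation loss-and-metric that simply reuses the unsupervised loss.
def unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    loss = unsupervised_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss

unsupervised_val_loader = DataLoader(unlabeled, batch_size=4)

trainer = MeanTeacherTrainer(
    name="mean-teacher-example",  # hypothetical experiment name
    model=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    unsupervised_train_loader=unsupervised_train_loader,
    unsupervised_val_loader=unsupervised_val_loader,
    unsupervised_loss=unsupervised_loss,
    unsupervised_loss_and_metric=unsupervised_loss_and_metric,
    pseudo_labeler=pseudo_labeler,
    reinit_teacher=False,  # fine-tuning: start the teacher from the student weights
)
trainer.fit(iterations=10000)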