torch_em.self_training.fix_match
import time

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from .mean_teacher import Dummy


class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
    """This trainer implements self-training for semi-supervised learning and domain adaptation following the
    'FixMatch' approach of Sohn et al. (https://arxiv.org/abs/2001.07685). This approach uses a (teacher) model
    that shares its weights with the student model to predict pseudo-labels on unlabeled data.
    We support two training strategies: joint training on labeled and unlabeled data
    (with a supervised and an unsupervised loss function), or training only on the unlabeled data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
    - supervised_val_loader (optional): Same as supervised_train_loader
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    And the following elements to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: model, teacher_input
        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and the 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Parameters:
        model [nn.Module] -
        unsupervised_train_loader [torch.DataLoader] -
        unsupervised_loss [callable] -
        pseudo_labeler [callable] -
        supervised_train_loader [torch.DataLoader] - (default: None)
        unsupervised_val_loader [torch.DataLoader] - (default: None)
        supervised_val_loader [torch.DataLoader] - (default: None)
        supervised_loss [callable] - (default: None)
        unsupervised_loss_and_metric [callable] - (default: None)
        supervised_loss_and_metric [callable] - (default: None)
        logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger)
        source_distribution [list] - (default: None)
        **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer
    """

    def __init__(
        self,
        model,
        unsupervised_train_loader,
        unsupervised_loss,
        pseudo_labeler,
        supervised_train_loader=None,
        unsupervised_val_loader=None,
        supervised_val_loader=None,
        supervised_loss=None,
        unsupervised_loss_and_metric=None,
        supervised_loss_and_metric=None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution=None,
        **kwargs
    ):
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler

        if source_distribution is None:
            self.source_distribution = None
        else:
            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)

        self._kwargs = kwargs

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    # Distribution alignment: encourages the distribution of the model's generated pseudo labels to match
    # the marginal label distribution of the source domain
    # (key idea: maximize the mutual information).
    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
        if self.source_distribution is not None:
            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
            target_distribution = target_distribution / target_distribution.sum()
            distribution_ratio = self.source_distribution / target_distribution
            pseudo_labels = torch.where(
                pseudo_labels < label_threshold,
                pseudo_labels * distribution_ratio[0],
                pseudo_labels * distribution_ratio[1]
            ).clip(0, 1)

        return pseudo_labels

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for the pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device), ys.to(self.device)
            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for the pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device), y.to(self.device)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device), x2.to(self.device)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

            if unsupervised_metric is None:
                metric = supervised_metric
            elif supervised_metric is None:
                metric = unsupervised_metric
            else:
                metric = (supervised_metric + unsupervised_metric) / 2

        return metric
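The get_distribution_alignment method above rescales pseudo-labels so that their foreground/background balance matches source_distribution. A minimal worked example of the arithmetic, with made-up distribution and label values (all numbers here are illustrative only):

import torch

# Illustrative source distribution: 80% background / 20% foreground.
source_distribution = torch.FloatTensor([0.8, 0.2])

# Suppose the pseudo-labels are over-confident foreground predictions.
pseudo_labels = torch.tensor([0.1, 0.2, 0.7, 0.9, 0.8, 0.6])
label_threshold = 0.5

# Empirical distribution of the binarized pseudo-labels: here 2/6 vs. 4/6.
pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
_, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
target_distribution = target_distribution / target_distribution.sum()  # tensor([0.3333, 0.6667])

# Ratio > 1 boosts the under-represented class, < 1 damps the over-represented one.
distribution_ratio = source_distribution / target_distribution  # tensor([2.4, 0.3])

aligned = torch.where(
    pseudo_labels < label_threshold,
    pseudo_labels * distribution_ratio[0],  # background scores scaled up by 2.4
    pseudo_labels * distribution_ratio[1],  # foreground scores scaled down to 0.3x
).clip(0, 1)
print(aligned)  # tensor([0.2400, 0.4800, 0.2100, 0.2700, 0.2400, 0.1800])

Note that torch.unique only returns counts for values that actually occur, so the ratio computation assumes both classes are present in the binarized batch.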
This trainer implements self-training for semi-supervised learning and domain adaptation following the 'FixMatch' approach of Sohn et al. (https://arxiv.org/abs/2001.07685). This approach uses a (teacher) model that shares its weights with the student model to predict pseudo-labels on unlabeled data. We support two training strategies: joint training on labeled and unlabeled data (with a supervised and an unsupervised loss function), or training only on the unlabeled data.
This class expects the following data loaders:
- unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same as unsupervised_train_loader
- supervised_val_loader (optional): Same as supervised_train_loader
At least one of unsupervised_val_loader and supervised_val_loader must be given; a sketch of the expected loader interface follows below.
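torch_em provides its own datasets and loaders for this purpose; the plain PyTorch sketch below only illustrates the (weak, strong) pair interface the trainer expects. The dataset class, augmentation functions, and random data are hypothetical placeholders, not part of torch_em:

import torch
from torch.utils.data import Dataset, DataLoader


class UnsupervisedPairDataset(Dataset):
    """Wraps an unlabeled dataset so each item is a (weak, strong) augmentation pair."""

    def __init__(self, images, weak_augment, strong_augment):
        self.images = images
        self.weak_augment = weak_augment
        self.strong_augment = strong_augment

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        x = self.images[idx]
        # The weak view serves as teacher_input (pseudo-label source),
        # the strong view as model_input (student prediction target).
        return self.weak_augment(x), self.strong_augment(x)


# Hypothetical usage with random data and simple tensor augmentations.
def weak(x):
    return x + 0.01 * torch.randn_like(x)

def strong(x):
    return (x + 0.2 * torch.randn_like(x)).clamp(0, 1)

images = [torch.rand(1, 256, 256) for _ in range(100)]
unsupervised_train_loader = DataLoader(
    UnsupervisedPairDataset(images, weak, strong), batch_size=4, shuffle=True
)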
And the following elements to customize the pseudo labeling:
- pseudo_labeler: to compute the pseudo-labels
- Parameters: model, teacher_input
- Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
- unsupervised_loss: the loss between model predictions and pseudo labels
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss
- supervised_loss (optional): the supervised loss function
- Parameters: model, input, labels
- Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric
- Parameters: model, input, labels
- Returns: loss, metric
At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given (see the sketch after this list).
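The callables only need to follow the signatures listed above. A minimal sketch of a confidence-thresholding pseudo-labeler and a masked MSE loss; these are illustrative implementations assuming a model that outputs logits, not the ones shipped with torch_em:

import torch


class ThresholdPseudoLabeler:
    """Predicts pseudo-labels and masks out low-confidence pixels."""

    def __init__(self, confidence_threshold=0.9):
        self.confidence_threshold = confidence_threshold

    def __call__(self, model, teacher_input):
        pseudo_labels = torch.sigmoid(model(teacher_input))
        # Keep only pixels that are confidently foreground or background.
        label_filter = torch.logical_or(
            pseudo_labels >= self.confidence_threshold,
            pseudo_labels <= (1.0 - self.confidence_threshold),
        ).float()
        return pseudo_labels, label_filter


def masked_mse_loss(model, model_input, pseudo_labels, label_filter):
    """Unsupervised loss: MSE to the pseudo-labels, restricted to confident pixels."""
    prediction = torch.sigmoid(model(model_input))
    squared_error = (prediction - pseudo_labels) ** 2
    if label_filter is None:
        return squared_error.mean()
    return (squared_error * label_filter).sum() / label_filter.sum().clamp(min=1)


def masked_mse_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    loss = masked_mse_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss  # reuse the loss value as the validation metric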
Note: adjust the batch size ratio between the 'unsupervised_train_loader' and the 'supervised_train_loader' to set the ratio between supervised and unsupervised training samples (e.g., a supervised batch size of 1 and an unsupervised batch size of 4 trains on four unlabeled samples per labeled one).
Arguments:
- model [nn.Module] -
- unsupervised_train_loader [torch.DataLoader] -
- unsupervised_loss [callable] -
- pseudo_labeler [callable] -
- supervised_train_loader [torch.DataLoader] - (default: None)
- unsupervised_val_loader [torch.DataLoader] - (default: None)
- supervised_val_loader [torch.DataLoader] - (default: None)
- supervised_loss [callable] - (default: None)
- unsupervised_loss_and_metric [callable] - (default: None)
- supervised_loss_and_metric [callable] - (default: None)
- logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger)
- source_distribution [list] - (default: None)
- **kwargs - additional keyword arguments for torch_em.trainer.DefaultTrainer
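Putting it together, a hedged sketch of instantiating and running the trainer in the unsupervised-only mode. The loader, pseudo-labeler, and loss callables are the hypothetical placeholders from the sketches above; the extra keyword arguments (name, optimizer) are forwarded to torch_em.trainer.DefaultTrainer, and the fit call assumes DefaultTrainer's iteration-count interface:

import torch
from torch_em.model import UNet2d
from torch_em.self_training.fix_match import FixMatchTrainer

model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

trainer = FixMatchTrainer(
    name="fixmatch-example",  # checkpoint name, handled by DefaultTrainer
    model=model,
    optimizer=optimizer,      # forwarded to DefaultTrainer via **kwargs
    unsupervised_train_loader=unsupervised_train_loader,  # (weak, strong) pairs
    unsupervised_val_loader=unsupervised_val_loader,      # same interface
    pseudo_labeler=ThresholdPseudoLabeler(confidence_threshold=0.9),
    unsupervised_loss=masked_mse_loss,
    unsupervised_loss_and_metric=masked_mse_loss_and_metric,  # needed for validation
)
trainer.fit(iterations=10000)

Since supervised_train_loader is None here, the trainer falls back to the purely unsupervised epoch implementation; passing a supervised loader plus supervised_loss switches it to the semi-supervised logic.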
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- mixed_precision
- early_stopping
- train_time
- scaler
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- from_checkpoint
- Serializer
- load_checkpoint
- fit