torch_em.self_training.fix_match
```python
import time
from typing import Callable, List, Optional

import torch
import torch_em
from torch_em.util import get_constructor_arguments

from .logger import SelfTrainingTensorboardLogger
from .mean_teacher import Dummy


class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.

    FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685.
    It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels
    on unlabeled data. We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
    - supervised_val_loader (optional): Same as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The following arguments can be used to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: model, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
    - unsupervised_loss: the loss between model predictions and pseudo labels
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: model, input, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: model, model_input, pseudo_labels, label_filter
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: model, input, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
    to set the ratio between supervised and unsupervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training.
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger.
        source_distribution: The ratio of labels in the source label distribution.
            If given, the predicted distribution of the trained model will be regularized to
            match this source label distribution.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        unsupervised_train_loader: torch.utils.data.DataLoader,
        unsupervised_loss: Callable,
        pseudo_labeler: Callable,
        supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
        unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
        supervised_loss: Optional[Callable] = None,
        unsupervised_loss_and_metric: Optional[Callable] = None,
        supervised_loss_and_metric: Optional[Callable] = None,
        logger=SelfTrainingTensorboardLogger,
        source_distribution: Optional[List[float]] = None,
        **kwargs,
    ):
        # Do we have supervised data or not?
        if supervised_train_loader is None:
            # No. -> We use the unsupervised training logic.
            train_loader = unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_unsupervised
        else:
            # Yes. -> We use the semi-supervised training logic.
            assert supervised_loss is not None
            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                else unsupervised_train_loader
            self._train_epoch_impl = self._train_epoch_semisupervised

        self.unsupervised_train_loader = unsupervised_train_loader
        self.supervised_train_loader = supervised_train_loader

        # Check that we have at least one of supervised / unsupervised val loader.
        assert sum((
            supervised_val_loader is not None,
            unsupervised_val_loader is not None,
        )) > 0
        self.supervised_val_loader = supervised_val_loader
        self.unsupervised_val_loader = unsupervised_val_loader

        if self.unsupervised_val_loader is None:
            val_loader = self.supervised_val_loader
        else:
            val_loader = self.unsupervised_val_loader

        # Check that we have at least one of supervised / unsupervised loss and metric.
        assert sum((
            supervised_loss_and_metric is not None,
            unsupervised_loss_and_metric is not None,
        )) > 0
        self.supervised_loss_and_metric = supervised_loss_and_metric
        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric

        # train_loader, val_loader, loss and metric may be unnecessarily deserialized
        kwargs.pop("train_loader", None)
        kwargs.pop("val_loader", None)
        kwargs.pop("metric", None)
        kwargs.pop("loss", None)
        super().__init__(
            model=model, train_loader=train_loader, val_loader=val_loader,
            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
        )

        self.unsupervised_loss = unsupervised_loss
        self.supervised_loss = supervised_loss

        self.pseudo_labeler = pseudo_labeler

        if source_distribution is None:
            self.source_distribution = None
        else:
            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)

        self._kwargs = kwargs

    #
    # functionality for saving checkpoints and initialization
    #

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        """@private
        """
        train_loader_kwargs = get_constructor_arguments(self.train_loader)
        val_loader_kwargs = get_constructor_arguments(self.val_loader)
        extra_state = {
            "init": {
                "train_loader_kwargs": train_loader_kwargs,
                "train_dataset": self.train_loader.dataset,
                "val_loader_kwargs": val_loader_kwargs,
                "val_dataset": self.val_loader.dataset,
                "loss_class": "torch_em.self_training.mean_teacher.Dummy",
                "loss_kwargs": {},
                "metric_class": "torch_em.self_training.mean_teacher.Dummy",
                "metric_kwargs": {},
            },
        }
        extra_state.update(**extra_save_dict)
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

    # Distribution alignment:
    # Encourages the distribution of the model's generated pseudo labels to match the marginal
    # distribution of pseudo labels from the source domain (key idea: to maximize the mutual information).
    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
        """@private
        """
        if self.source_distribution is not None:
            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
            target_distribution = target_distribution / target_distribution.sum()
            distribution_ratio = self.source_distribution / target_distribution
            pseudo_labels = torch.where(
                pseudo_labels < label_threshold,
                pseudo_labels * distribution_ratio[0],
                pseudo_labels * distribution_ratio[1]
            ).clip(0, 1)

        return pseudo_labels

    #
    # training and validation functionality
    #

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from the unsupervised loader.
        for xu1, xu2 in self.unsupervised_train_loader:
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for the pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            self.optimizer.zero_grad()
            # Perform unsupervised training.
            with forward_context():
                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train_unsupervised(
                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                )
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        # Sample from both the supervised and unsupervised loader.
        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
            xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True)
            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

            # Perform supervised training.
            self.optimizer.zero_grad()
            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
                supervised_loss = self.supervised_loss(self.model, xs, ys)

            teacher_input, model_input = xu1, xu2

            with forward_context(), torch.no_grad():
                # Compute the pseudo labels.
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)

            pseudo_labels = pseudo_labels.detach()
            if label_filter is not None:
                label_filter = label_filter.detach()

            # Perform distribution alignment for the pseudo labels.
            pseudo_labels = self.get_distribution_alignment(pseudo_labels)

            # Perform unsupervised training.
            with forward_context():
                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)

            loss = (supervised_loss + unsupervised_loss) / 2
            backprop(loss)

            if self.logger is not None:
                with torch.no_grad(), forward_context():
                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None

                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                )

                self.logger.log_combined_loss(self._iteration, loss)
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            with forward_context():
                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(x)
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x1, x2 in self.unsupervised_val_loader:
            x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True)
            teacher_input, model_input = x1, x2
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            with forward_context():
                pred = self.model(model_input)
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
            )

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

            if unsupervised_metric is None:
                metric = supervised_metric
            elif supervised_metric is None:
                metric = unsupervised_metric
            else:
                metric = (supervised_metric + unsupervised_metric) / 2

        return metric
```
Trainer for semi-supervised learning and domain adaptation following the FixMatch approach.
FixMatch was introduced by Sohn et al. in https://arxiv.org/abs/2001.07685. It uses a teacher model derived from the student model via weight sharing to predict pseudo-labels on unlabeled data. We support two training strategies:
- Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
- Training only on the unsupervised data.
This class expects the following data loaders:
- unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same as unsupervised_train_loader.
- supervised_val_loader (optional): Same as supervised_train_loader.
At least one of unsupervised_val_loader and supervised_val_loader must be given.
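For illustration, a minimal sketch of an unsupervised loader that yields two augmented views of each sample; the dataset class and the augmentation functions are hypothetical stand-ins, not part of torch_em:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class WeakStrongDataset(Dataset):
    """Hypothetical dataset that returns a weakly and a strongly augmented view per image."""

    def __init__(self, images, weak_augmentation, strong_augmentation):
        self.images = images
        self.weak_augmentation = weak_augmentation
        self.strong_augmentation = strong_augmentation

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        # The trainer unpacks each batch as (xu1, xu2) = (weak view, strong view).
        return self.weak_augmentation(image), self.strong_augmentation(image)


# Random noise as a stand-in for real unlabeled images and real augmentations.
images = [torch.rand(1, 256, 256) for _ in range(80)]
unsupervised_train_loader = DataLoader(
    WeakStrongDataset(
        images,
        weak_augmentation=lambda x: x,
        strong_augmentation=lambda x: x + 0.1 * torch.randn_like(x),
    ),
    batch_size=4, shuffle=True,
)
```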
The following arguments can be used to customize the pseudo labeling:
- pseudo_labeler: to compute the pseudo-labels
- Parameters: model, teacher_input
- Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
- unsupervised_loss: the loss between model predictions and pseudo labels
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss
- supervised_loss (optional): the supervised loss function
- Parameters: model, input, labels
- Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
- Parameters: model, model_input, pseudo_labels, label_filter
- Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric
- Parameters: model, input, labels
- Returns: loss, metric
At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
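As an illustration of these interfaces, here is a minimal sketch of a pseudo labeler and an unsupervised loss for a binary segmentation model; the confidence threshold and the choice of a masked MSE are illustrative assumptions, not the library defaults:

```python
import torch


def pseudo_labeler(model, teacher_input, confidence_threshold=0.9):
    """Hypothetical pseudo labeler: predicts on the weak view and keeps only confident pixels."""
    pseudo_labels = torch.sigmoid(model(teacher_input))
    label_filter = torch.logical_or(
        pseudo_labels >= confidence_threshold, pseudo_labels <= (1.0 - confidence_threshold)
    ).float()
    return pseudo_labels, label_filter


def unsupervised_loss(model, model_input, pseudo_labels, label_filter):
    """Consistency loss between predictions on the strong view and the pseudo-labels."""
    prediction = torch.sigmoid(model(model_input))
    loss = torch.nn.functional.mse_loss(prediction, pseudo_labels, reduction="none")
    if label_filter is not None:
        # Average only over the confident pixels selected by the label filter.
        return (loss * label_filter).sum() / label_filter.sum().clamp(min=1.0)
    return loss.mean()


def unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
    """Used for validation; in this sketch the metric simply reuses the loss value."""
    loss = unsupervised_loss(model, model_input, pseudo_labels, label_filter)
    return loss, loss
```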
Note: adjust the batch sizes of the 'unsupervised_train_loader' and 'supervised_train_loader' to set the ratio between supervised and unsupervised training samples, as in the sketch below.
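Continuing the hypothetical loader sketch from above, a configuration that draws two labeled and eight unlabeled samples per iteration (a 1:4 supervised-to-unsupervised ratio) could look like this:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder labeled data (inputs and binary targets); replace with real datasets.
labeled_dataset = TensorDataset(torch.rand(20, 1, 256, 256), torch.rand(20, 1, 256, 256).round())
unlabeled_dataset = WeakStrongDataset(  # the hypothetical dataset from the sketch above
    [torch.rand(1, 256, 256) for _ in range(80)],
    weak_augmentation=lambda x: x,
    strong_augmentation=lambda x: x + 0.1 * torch.randn_like(x),
)

supervised_train_loader = DataLoader(labeled_dataset, batch_size=2, shuffle=True)
unsupervised_train_loader = DataLoader(unlabeled_dataset, batch_size=8, shuffle=True)
```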
Arguments:
- model: The model to be trained.
- unsupervised_train_loader: The loader for unsupervised training.
- unsupervised_loss: The loss for unsupervised training.
- pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
- supervised_train_loader: The loader for supervised training.
- supervised_loss: The loss for supervised training.
- unsupervised_loss_and_metric: The loss and metric for unsupervised training.
- supervised_loss_and_metric: The loss and metric for supervised training.
- logger: The logger.
- source_distribution: The ratio of labels in the source label distribution. If given, the predicted distribution of the trained model will be regularized to match this source label distribution (see the alignment sketch after this list).
- kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
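The regularization behind source_distribution mirrors the get_distribution_alignment logic in the source listing above. A small worked sketch for a binary setup follows; the numbers are made up for illustration and, like the original code, it assumes both classes occur in the batch:

```python
import torch

source_distribution = torch.tensor([0.8, 0.2])  # e.g. 80% background, 20% foreground in the source domain
pseudo_labels = torch.rand(1, 1, 64, 64)        # stand-in for predicted foreground probabilities
label_threshold = 0.5

# Marginal distribution of the binarized pseudo-labels in the current batch.
binary = (pseudo_labels >= label_threshold).long()
_, counts = torch.unique(binary, return_counts=True)
target_distribution = counts / counts.sum()

# Rescale the pseudo-labels so that their marginal distribution moves towards the source distribution.
ratio = source_distribution / target_distribution
aligned = torch.where(
    pseudo_labels < label_threshold, pseudo_labels * ratio[0], pseudo_labels * ratio[1]
).clip(0, 1)
```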
FixMatchTrainer(
    model: torch.nn.modules.module.Module,
    unsupervised_train_loader: torch.utils.data.dataloader.DataLoader,
    unsupervised_loss: Callable,
    pseudo_labeler: Callable,
    supervised_train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None,
    unsupervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None,
    supervised_val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None,
    supervised_loss: Optional[Callable] = None,
    unsupervised_loss_and_metric: Optional[Callable] = None,
    supervised_loss_and_metric: Optional[Callable] = None,
    logger=<class 'torch_em.self_training.logger.SelfTrainingTensorboardLogger'>,
    source_distribution: Optional[List[float]] = None,
    **kwargs,
)
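A rough end-to-end usage sketch, reusing the hypothetical loaders and callables from the examples above. The model choice, optimizer settings, and the exact DefaultTrainer keyword arguments (e.g. name, device) are assumptions and may need adapting to your setup and torch_em version:

```python
import torch
from torch_em.model import UNet2d
from torch_em.self_training.fix_match import FixMatchTrainer

model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

trainer = FixMatchTrainer(
    name="fixmatch-example",
    model=model,
    optimizer=optimizer,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    unsupervised_train_loader=unsupervised_train_loader,
    unsupervised_val_loader=unsupervised_train_loader,  # stand-in; use a separate validation loader in practice
    unsupervised_loss=unsupervised_loss,
    unsupervised_loss_and_metric=unsupervised_loss_and_metric,
    pseudo_labeler=pseudo_labeler,
    # Passing supervised_train_loader and supervised_loss as well switches to joint semi-supervised training.
)
trainer.fit(iterations=10000)  # the iteration count is an arbitrary example value
```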
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit