torch_em.self_training.uni_match_v2
```python
import time

import torch
from torch_em.self_training.mean_teacher import MeanTeacherTrainerWithInvertibleAugmentations

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


class UniMatchv2Trainer(MeanTeacherTrainerWithInvertibleAugmentations):
    """Trainer for semi-supervised learning and domain adaptation following the UniMatch v2 framework.

    UniMatch v2 was introduced by Yang et al. in https://arxiv.org/abs/2410.10777v2.
    It uses a teacher model derived from the student model via EMA of weights to predict
    pseudo-labels on unlabeled data. Three augmented views are generated per sample - one weak
    (for the teacher) and two strong (for the student) - and the student loss is computed as the
    average over both strong-view predictions against the shared weak-view pseudo-label.
    We support two training strategies:
    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
    - Training only on the unsupervised data.

    This class expects the following data loaders:
    - unsupervised_train_loader: Returns a single (raw) input per sample. The trainer applies
      weak and two strong augmentations internally via the augmenter.
    - supervised_train_loader (optional): Returns input and labels.
    - unsupervised_val_loader (optional): Same format as unsupervised_train_loader.
    - supervised_val_loader (optional): Same format as supervised_train_loader.
    At least one of unsupervised_val_loader and supervised_val_loader must be given.

    The augmenter must be a `UniMatchv2Augmenters` instance providing three invertible transforms:
    `.weak` for the teacher view, `.strong1` and `.strong2` for the two student views. The
    corresponding inverse transforms map predictions and pseudo-labels back into a shared
    reference frame before the loss is computed.

    The following arguments can be used to customize the pseudo labeling:
    - pseudo_labeler: to compute the pseudo-labels
        - Parameters: teacher, teacher_input
        - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
    - unsupervised_loss: the loss between stacked student predictions and pseudo-labels
        - Parameters: prediction (stacked [pred_s1_inv, pred_s2_inv]), pseudo_labels, label_filter, pred_dim
        - Returns: loss
    - supervised_loss (optional): the supervised loss function
        - Parameters: prediction, labels
        - Returns: loss
    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
        - Parameters: prediction (stacked), pseudo_labels, label_filter, pred_dim
        - Returns: loss, metric
    - supervised_loss_and_metric (optional): the supervised loss function and metric
        - Parameters: prediction, labels
        - Returns: loss, metric
    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

    Note: adjust the batch size of the 'unsupervised_train_loader' relative to
    'supervised_train_loader' to control the ratio of supervised to unsupervised training samples.

    Args:
        model: The model to be trained.
        unsupervised_train_loader: The loader for unsupervised training (returns raw inputs only).
        unsupervised_loss: The loss for unsupervised training.
        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
        augmenter: `UniMatchv2Augmenters` instance providing `.weak`, `.strong1`, and `.strong2`
            invertible transforms with corresponding inverse transforms.
        complementary_dropout: If True, applies complementary feature dropout to the encoder
            features before decoding, creating two complementary student views. Requires a
            UNETR-compatible model architecture.
        supervised_train_loader: The loader for supervised training.
        supervised_loss: The loss for supervised training.
        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
        supervised_loss_and_metric: The loss and metric for supervised training.
        logger: The logger. Defaults to `UniMatchv2TensorboardLogger`.
        momentum: The momentum value for the exponential moving weight average of the teacher model.
        reinit_teacher: Whether to reinit the teacher model before starting the training.
        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
    """

    def __init__(self, complementary_dropout, **kwargs):
        super().__init__(**kwargs)
        self.complementary_dropout = complementary_dropout

        self.teacher.eval()

    def unetr_decoder_prediction(self, model, features, input_shape, original_shape):
        z9 = model.deconv1(features)
        z6 = model.deconv2(z9)
        z3 = model.deconv3(z6)
        z0 = model.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = model.base(features)
        x = model.decoder(x, encoder_inputs=updated_from_encoder)
        x = model.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = model.decoder_head(x)

        x = model.out_conv(x)
        if model.final_activation is not None:
            x = model.final_activation(x)

        x = model.postprocess_masks(x, input_shape, original_shape)
        return x

    def predict_with_comp_drop(self, model, input_):
        batch_size = input_.shape[0]
        original_shape = input_.shape[2:]

        x, input_shape = model.preprocess(input_)

        if len(original_shape) == 2:
            encoder_output = model.encoder(x)
            if isinstance(encoder_output[-1], list):
                features, _ = encoder_output
            else:
                features = encoder_output
        if len(original_shape) == 3:
            depth = input_.shape[-3]
            features = torch.stack([model.encoder(x[:, :, i])[0] for i in range(depth)], dim=2)

        features_dim = features.shape[1]

        binom = torch.distributions.binomial.Binomial(probs=0.5)

        dropout_mask1 = binom.sample((int(batch_size / 2), features_dim)).to(input_.device) * 2.0
        if len(original_shape) == 2:
            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1)
        if len(original_shape) == 3:
            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        dropout_mask2 = 2.0 - dropout_mask1
        dropout_mask = torch.cat([dropout_mask1, dropout_mask2])

        # NOTE: in the UniMatch v2 code some samples of the batch stay unchanged!
        # Keep some samples unchanged (code block not tested):
        # dropout_prob = 0.5
        # num_kept = int(batch_size * (1 - dropout_prob))
        # kept_indexes = torch.randperm(batch_size, device=input_.device)[:num_kept]
        # dropout_mask1[kept_indexes, :] = 1.0
        # dropout_mask2[kept_indexes, :] = 1.0

        dropped_features = features * dropout_mask

        pred = self.unetr_decoder_prediction(model, dropped_features, input_shape, original_shape)

        return pred

    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for x_u in self.unsupervised_train_loader:
            self.augmenter.reset_all()
            x_u = x_u.to(self.device, non_blocking=True)

            x_u_w = self.augmenter.weak.transform(x_u)
            x_u_s1, x_u_s2 = self.augmenter.strong1.transform(x_u), self.augmenter.strong2.transform(x_u)

            # Compute the pseudo labels (unsupervised teacher prediction).
            with forward_context(), torch.no_grad():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_u_w)
                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.weak.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # Perform unsupervised training.
            with forward_context():
                if self.complementary_dropout:
                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_u_s1, x_u_s2))).chunk(2)
                else:
                    pred_s1, pred_s2 = self.model(torch.cat((x_u_s1, x_u_s2))).chunk(2)
                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)
                unsupervised_loss = self.unsupervised_loss(
                    torch.stack((pred_s1_inv, pred_s2_inv)),
                    pseudo_labels_inv,
                    label_filter_inv,
                    pred_dim=2,
                )

            backprop(unsupervised_loss)

            if self.logger is not None:
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, x_u,
                    pred_s1_inv, pred_s2_inv, pseudo_labels_inv, label_filter_inv,
                )
                self.logger.log_train_augmentations(
                    self._iteration, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2,
                )

                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()  # EMA update of the teacher.

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
        train_loader = zip(self.supervised_train_loader, self.unsupervised_train_loader)
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for i, ((x_s, y_s), x_u) in enumerate(train_loader):
            self.augmenter.reset_all()

            x_s, y_s = x_s.to(self.device, non_blocking=True), y_s.to(self.device, non_blocking=True)
            x_u = x_u.to(self.device, non_blocking=True)

            x_u_w = self.augmenter.weak.transform(x_u)
            x_u_s1, x_u_s2 = self.augmenter.strong1.transform(x_u), self.augmenter.strong2.transform(x_u)

            self.optimizer.zero_grad()
            # Supervised loss (supervised student prediction).
            pred_s = self.model(x_s)
            supervised_loss = self.supervised_loss(pred_s, y_s)

            backprop(supervised_loss)

            # Compute the pseudo labels (unsupervised teacher prediction).
            with forward_context(), torch.no_grad():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_u_w)
                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.weak.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # Perform unsupervised training.
            self.optimizer.zero_grad()
            with forward_context():
                if self.complementary_dropout:
                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_u_s1, x_u_s2))).chunk(2)
                else:
                    pred_s1, pred_s2 = self.model(torch.cat((x_u_s1, x_u_s2))).chunk(2)
                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)
                unsupervised_loss = self.unsupervised_loss(
                    torch.stack((pred_s1_inv, pred_s2_inv)),
                    pseudo_labels_inv,
                    label_filter_inv,
                    pred_dim=2,
                )

            backprop(unsupervised_loss)

            if self.logger is not None:
                self.logger.log_train_supervised(self._iteration, supervised_loss, x_s, y_s, pred_s)
                self.logger.log_train_unsupervised(
                    self._iteration, unsupervised_loss, x_u,
                    pred_s1_inv, pred_s2_inv, pseudo_labels_inv, label_filter_inv,
                )
                self.logger.log_train_augmentations(
                    self._iteration, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2,
                )

                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()  # EMA update of the teacher.

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_supervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x, y in self.supervised_val_loader:
            x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

            with forward_context():
                pred = self.model(x)
                loss, metric = self.supervised_loss_and_metric(pred, y)
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.supervised_val_loader)
        loss_val /= len(self.supervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)

        return metric_val

    def _validate_unsupervised(self, forward_context):
        metric_val = 0.0
        loss_val = 0.0

        for x in self.unsupervised_val_loader:
            self.augmenter.reset_all()
            x = x.to(self.device, non_blocking=True)

            # Apply the augmentations.
            x_w = self.augmenter.weak.transform(x)
            x_s1, x_s2 = self.augmenter.strong1.transform(x), self.augmenter.strong2.transform(x)

            # Compute the pseudo labels (unsupervised teacher prediction).
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_w)
                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.weak.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

                if self.complementary_dropout:
                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_s1, x_s2))).chunk(2)
                else:
                    pred_s1, pred_s2 = self.model(torch.cat((x_s1, x_s2))).chunk(2)
                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)

                loss, metric = self.unsupervised_loss_and_metric(
                    torch.stack((pred_s1_inv, pred_s2_inv)),
                    pseudo_labels_inv,
                    label_filter_inv,
                    pred_dim=2,
                )
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            self.logger.log_validation_unsupervised(
                self._iteration, metric_val, loss_val, x,
                pred_s1_inv, pred_s2_inv, pseudo_labels_inv, label_filter_inv,
            )
            self.logger.log_validation_augmentations(
                self._iteration, x_w, x_s1, x_s2, pseudo_labels, pred_s1, pred_s2,
            )

        self.pseudo_labeler.step(metric_val, self._epoch)

        return metric_val

    def _validate_impl(self, forward_context):
        self.model.eval()

        with torch.no_grad():

            if self.supervised_val_loader is None:
                supervised_metric = None
            else:
                supervised_metric = self._validate_supervised(forward_context)

            if self.unsupervised_val_loader is None:
                unsupervised_metric = None
            else:
                unsupervised_metric = self._validate_unsupervised(forward_context)

            if unsupervised_metric is None:
                metric = supervised_metric
            elif supervised_metric is None:
                metric = unsupervised_metric
            else:
                metric = (supervised_metric + unsupervised_metric) / 2

        return metric
```
class UniMatchv2Trainer(MeanTeacherTrainerWithInvertibleAugmentations):
Trainer for semi-supervised learning and domain adaptation following the UniMatch v2 framework.
UniMatch v2 was introduced by Yang et al. in https://arxiv.org/abs/2410.10777v2. It uses a teacher model derived from the student model via EMA of weights to predict pseudo-labels on unlabeled data. Three augmented views are generated per sample - one weak (for the teacher) and two strong (for the student) - and the student loss is computed as the average over both strong-view predictions against the shared weak-view pseudo-label. We support two training strategies:
- Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
- Training only on the unsupervised data.
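To make this concrete, the following is a minimal, self-contained sketch of one unsupervised update; `student`, `teacher`, `weak_aug`, `strong_aug`, and the confidence threshold are hypothetical stand-ins for illustration, not the trainer's actual API:

```python
import copy
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for this sketch.
student = torch.nn.Conv2d(1, 2, kernel_size=1)
teacher = copy.deepcopy(student)  # The teacher is an EMA copy of the student.
for p in teacher.parameters():
    p.requires_grad_(False)

def weak_aug(x):
    return x + 0.01 * torch.randn_like(x)

def strong_aug(x):
    return x + 0.1 * torch.randn_like(x)

x_u = torch.rand(4, 1, 32, 32)  # A raw, unlabeled batch.

# Teacher pseudo-labels on the weak view, with a simple confidence filter.
with torch.no_grad():
    probs = torch.softmax(teacher(weak_aug(x_u)), dim=1)
    confidence, pseudo_labels = probs.max(dim=1)
    label_filter = (confidence > 0.9).float()

# Student predictions on two strong views; the loss is averaged over both
# against the shared weak-view pseudo-label.
pred_s1, pred_s2 = student(strong_aug(x_u)), student(strong_aug(x_u))
loss = 0.5 * sum(
    (F.cross_entropy(pred, pseudo_labels, reduction="none") * label_filter).mean()
    for pred in (pred_s1, pred_s2)
)
loss.backward()

# EMA update of the teacher weights.
momentum = 0.999
with torch.no_grad():
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.mul_(momentum).add_(ps, alpha=1.0 - momentum)
```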
This class expects the following data loaders:
- unsupervised_train_loader: Returns a single (raw) input per sample. The trainer applies weak and two strong augmentations internally via the augmenter.
- supervised_train_loader (optional): Returns input and labels.
- unsupervised_val_loader (optional): Same format as unsupervised_train_loader.
- supervised_val_loader (optional): Same format as supervised_train_loader.

At least one of unsupervised_val_loader and supervised_val_loader must be given.
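For illustration, the expected item formats could be provided by toy datasets like these (hypothetical; shown only to pin down the return types). The batch sizes chosen here yield a 2:1 ratio of unsupervised to supervised samples per step, which is how the sample ratio mentioned in the note below is controlled:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RawDataset(Dataset):
    """Returns a single raw input tensor per sample (unsupervised loaders)."""
    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        return self.images[i]  # No label.

class LabeledDataset(Dataset):
    """Returns an (input, labels) pair per sample (supervised loaders)."""
    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        return self.images[i], self.labels[i]

unsupervised_train_loader = DataLoader(RawDataset(torch.rand(16, 1, 64, 64)), batch_size=4)
supervised_train_loader = DataLoader(
    LabeledDataset(torch.rand(8, 1, 64, 64), torch.rand(8, 1, 64, 64)), batch_size=2,
)
```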
The augmenter must be a `UniMatchv2Augmenters` instance providing three invertible transforms: `.weak` for the teacher view, `.strong1` and `.strong2` for the two student views. The corresponding inverse transforms map predictions and pseudo-labels back into a shared reference frame before the loss is computed; the sketch below illustrates this contract.
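A minimal sketch of that contract, using a horizontal flip (which is its own inverse) as a hypothetical stand-in for the real augmentations:

```python
import torch

class FlipTransform:
    """Hypothetical invertible augmentation: a horizontal flip is self-inverse."""
    def transform(self, x):
        return torch.flip(x, dims=(-1,))

    def reverse_transform(self, y):
        # Map predictions (or pseudo-labels) computed on the flipped view
        # back into the un-flipped reference frame.
        return torch.flip(y, dims=(-1,))

aug = FlipTransform()
x = torch.rand(2, 1, 8, 8)
assert torch.allclose(aug.reverse_transform(aug.transform(x)), x)
```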
The following arguments can be used to customize the pseudo labeling:
- pseudo_labeler: to compute the pseudo-labels
    - Parameters: teacher, teacher_input
    - Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
- unsupervised_loss: the loss between stacked student predictions and pseudo-labels
    - Parameters: prediction (stacked [pred_s1_inv, pred_s2_inv]), pseudo_labels, label_filter, pred_dim
    - Returns: loss
- supervised_loss (optional): the supervised loss function
    - Parameters: prediction, labels
    - Returns: loss
- unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    - Parameters: prediction (stacked), pseudo_labels, label_filter, pred_dim
    - Returns: loss, metric
- supervised_loss_and_metric (optional): the supervised loss function and metric
    - Parameters: prediction, labels
    - Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given. A minimal implementation of this interface is sketched below.
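The following sketch assumes a single-channel model with sigmoid outputs; the class name, thresholds, and loss choice are illustrative, not the library's defaults. The `confidence_threshold` attribute and `step` method are included because the trainer reads the former for logging and calls the latter during validation (see the source above). In the trainer's calls, `prediction` stacks the two inverse-transformed strong-view predictions along dim 0, giving shape (2, B, C, ...); `pred_dim` (passed as 2) appears to index the channel axis of that stacked tensor (an assumption here), and this simplified loss does not need it.

```python
import torch
import torch.nn.functional as F

class SimplePseudoLabeler:
    """Hypothetical pseudo labeler matching the documented interface."""
    def __init__(self, confidence_threshold=0.8):
        self.confidence_threshold = confidence_threshold

    def __call__(self, teacher, teacher_input):
        with torch.no_grad():
            probs = torch.sigmoid(teacher(teacher_input))
        pseudo_labels = (probs > 0.5).float()
        # Keep only pixels where the teacher is sufficiently confident.
        label_filter = (
            (probs > self.confidence_threshold) | (probs < 1.0 - self.confidence_threshold)
        ).float()
        return pseudo_labels, label_filter

    def step(self, metric, epoch):
        pass  # Hook for confidence-threshold schedules; a no-op here.

def unsupervised_loss(prediction, pseudo_labels, label_filter, pred_dim):
    # `prediction` stacks the two (inverse-transformed) strong-view predictions
    # along dim 0; the loss is averaged over both views.
    losses = []
    for pred in prediction:
        per_pixel = F.binary_cross_entropy_with_logits(pred, pseudo_labels, reduction="none")
        if label_filter is not None:
            per_pixel = per_pixel * label_filter
        losses.append(per_pixel.mean())
    return sum(losses) / len(losses)
```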
Note: adjust the batch size of the 'unsupervised_train_loader' relative to 'supervised_train_loader' to control the ratio of supervised to unsupervised training samples.
Arguments:
- model: The model to be trained.
- unsupervised_train_loader: The loader for unsupervised training (returns raw inputs only).
- unsupervised_loss: The loss for unsupervised training.
- pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
- augmenter: `UniMatchv2Augmenters` instance providing `.weak`, `.strong1`, and `.strong2` invertible transforms with corresponding inverse transforms.
- complementary_dropout: If True, applies complementary feature dropout to the encoder features before decoding, creating two complementary student views. Requires a UNETR-compatible model architecture.
- supervised_train_loader: The loader for supervised training.
- supervised_loss: The loss for supervised training.
- unsupervised_loss_and_metric: The loss and metric for unsupervised training.
- supervised_loss_and_metric: The loss and metric for supervised training.
- logger: The logger. Defaults to `UniMatchv2TensorboardLogger`.
- momentum: The momentum value for the exponential moving weight average of the teacher model.
- reinit_teacher: Whether to reinit the teacher model before starting the training.
- sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
- kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
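Putting this together, a hypothetical construction sketch: the keyword names follow the argument list above, all objects are assumed to be defined elsewhere (e.g. as in the earlier snippets), and the concrete values are placeholders rather than recommendations.

```python
trainer = UniMatchv2Trainer(
    name="unimatch-v2-example",  # Forwarded to torch_em.trainer.DefaultTrainer.
    model=model,
    unsupervised_train_loader=unsupervised_train_loader,
    unsupervised_loss=unsupervised_loss,
    pseudo_labeler=SimplePseudoLabeler(confidence_threshold=0.8),
    augmenter=augmenter,  # A UniMatchv2Augmenters instance.
    complementary_dropout=False,
    supervised_train_loader=supervised_train_loader,
    supervised_loss=supervised_loss,
    supervised_loss_and_metric=supervised_loss_and_metric,
    momentum=0.999,
    reinit_teacher=True,
    # Plus any further torch_em.trainer.DefaultTrainer kwargs
    # (e.g. optimizer, lr_scheduler, device, mixed_precision).
)
trainer.fit(10_000)  # Number of training iterations (assumed DefaultTrainer.fit usage).
```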
```python
def unetr_decoder_prediction(self, model, features, input_shape, original_shape):
    # Upsample the encoder features at several scales; these replace the
    # skip connections that would normally come from the encoder.
    z9 = model.deconv1(features)
    z6 = model.deconv2(z9)
    z3 = model.deconv3(z6)
    z0 = model.deconv4(z3)

    updated_from_encoder = [z9, z6, z3]

    # Run the decoder on the (possibly dropout-masked) features, using the
    # upsampled features as skip inputs.
    x = model.base(features)
    x = model.decoder(x, encoder_inputs=updated_from_encoder)
    x = model.deconv_out(x)

    x = torch.cat([x, z0], dim=1)
    x = model.decoder_head(x)

    x = model.out_conv(x)
    if model.final_activation is not None:
        x = model.final_activation(x)

    # Resize the prediction back to the original input shape.
    x = model.postprocess_masks(x, input_shape, original_shape)
    return x
```
```python
def predict_with_comp_drop(self, model, input_):
    batch_size = input_.shape[0]
    original_shape = input_.shape[2:]

    x, input_shape = model.preprocess(input_)

    # Compute the encoder features; for 3D inputs the encoder is applied per slice.
    if len(original_shape) == 2:
        encoder_output = model.encoder(x)
        if isinstance(encoder_output[-1], list):
            features, _ = encoder_output
        else:
            features = encoder_output
    if len(original_shape) == 3:
        depth = input_.shape[-3]
        features = torch.stack([model.encoder(x[:, :, i])[0] for i in range(depth)], dim=2)

    features_dim = features.shape[1]

    # Sample a binary channel mask for the first half of the batch (the first
    # strong view) and use its complement for the second half; the factor of 2
    # preserves the expected feature magnitude.
    binom = torch.distributions.binomial.Binomial(probs=0.5)

    dropout_mask1 = binom.sample((int(batch_size / 2), features_dim)).to(input_.device) * 2.0
    if len(original_shape) == 2:
        dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1)
    if len(original_shape) == 3:
        dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    dropout_mask2 = 2.0 - dropout_mask1
    dropout_mask = torch.cat([dropout_mask1, dropout_mask2])

    # NOTE: in the UniMatch v2 code some samples of the batch stay unchanged!
    # Keep some samples unchanged (code block not tested):
    # dropout_prob = 0.5
    # num_kept = int(batch_size * (1 - dropout_prob))
    # kept_indexes = torch.randperm(batch_size, device=input_.device)[:num_kept]
    # dropout_mask1[kept_indexes, :] = 1.0
    # dropout_mask2[kept_indexes, :] = 1.0

    dropped_features = features * dropout_mask

    pred = self.unetr_decoder_prediction(model, dropped_features, input_shape, original_shape)

    return pred
```
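The masks built in `predict_with_comp_drop` are exact complements scaled by two: each feature channel is zeroed in one strong view and doubled in the other, so averaging the two masked views recovers the original features. A quick standalone check of that property:

```python
import torch

half_batch, features_dim = 2, 8
binom = torch.distributions.binomial.Binomial(probs=0.5)

mask1 = binom.sample((half_batch, features_dim)) * 2.0  # Entries are 0.0 or 2.0.
mask2 = 2.0 - mask1                                     # The complementary channels.

# Each channel is dropped in exactly one of the two views and doubled in the other.
assert torch.all((mask1 == 0.0) | (mask1 == 2.0))
assert torch.allclose(mask1 + mask2, torch.full_like(mask1, 2.0))

# Averaging a feature tensor masked with mask1 and mask2 recovers the original.
features = torch.rand(half_batch, features_dim)
assert torch.allclose(0.5 * (features * mask1 + features * mask2), features)
```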
Inherited Members
- torch_em.self_training.mean_teacher.MeanTeacherTrainer
- sampler
- unsupervised_train_loader
- supervised_train_loader
- supervised_val_loader
- unsupervised_val_loader
- supervised_loss_and_metric
- unsupervised_loss_and_metric
- unsupervised_loss
- supervised_loss
- pseudo_labeler
- momentum
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit