# torch_em.self_training.uni_match_v2

  1import time
  2
  3import torch
  4from torch_em.self_training.mean_teacher import MeanTeacherTrainerWithInvertibleAugmentations
  5
  6# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
  7
  8
  9class UniMatchv2Trainer(MeanTeacherTrainerWithInvertibleAugmentations):
 10    """Trainer for semi-supervised learning and domain adaptation following the UniMatch v2 framework.
 11
 12    UniMatch v2 was introduced by Yang et al. in https://arxiv.org/abs/2410.10777v2.
 13    It uses a teacher model derived from the student model via EMA of weights to predict
 14    pseudo-labels on unlabeled data. Three augmented views are generated per sample - one weak
 15    (for the teacher) and two strong (for the student) - and the student loss is computed as the
 16    average over both strong-view predictions against the shared weak-view pseudo-label.
 17    We support two training strategies:
 18    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
 19    - Training only on the unsupervised data.
 20
 21    This class expects the following data loaders:
 22    - unsupervised_train_loader: Returns a single (raw) input per sample. The trainer applies
 23      weak and two strong augmentations internally via the augmenter.
 24    - supervised_train_loader (optional): Returns input and labels.
 25    - unsupervised_val_loader (optional): Same format as unsupervised_train_loader.
 26    - supervised_val_loader (optional): Same format as supervised_train_loader.
 27    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 28
 29    The augmenter must be a `UniMatchv2Augmenters` instance providing three invertible transforms:
 30    `.weak` for the teacher view, `.strong1` and `.strong2` for the two student views. The
 31    corresponding inverse transforms map predictions and pseudo-labels back into a shared
 32    reference frame before the loss is computed.
 33
 34    The following arguments can be used to customize the pseudo labeling:
 35    - pseudo_labeler: to compute the pseudo-labels
 36        - Parameters: teacher, teacher_input
 37        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
 38    - unsupervised_loss: the loss between stacked student predictions and pseudo-labels
 39        - Parameters: prediction (stacked [pred_s1_inv, pred_s2_inv]), pseudo_labels, label_filter, pred_dim
 40        - Returns: loss
 41    - supervised_loss (optional): the supervised loss function
 42        - Parameters: prediction, labels
 43        - Returns: loss
 44    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 45        - Parameters: prediction (stacked), pseudo_labels, label_filter, pred_dim
 46        - Returns: loss, metric
 47    - supervised_loss_and_metric (optional): the supervised loss function and metric
 48        - Parameters: prediction, labels
 49        - Returns: loss, metric
 50    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 51
 52    Note: adjust the batch size of the 'unsupervised_train_loader' relative to
 53    'supervised_train_loader' to control the ratio of supervised to unsupervised training samples.
 54
 55    Args:
 56        model: The model to be trained.
 57        unsupervised_train_loader: The loader for unsupervised training (returns raw inputs only).
 58        unsupervised_loss: The loss for unsupervised training.
 59        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
 60        augmenter: `UniMatchv2Augmenters` instance providing `.weak`, `.strong1`, and `.strong2`
 61            invertible transforms with corresponding inverse transforms.
 62        complementary_dropout: If True, applies complementary feature dropout to the encoder
 63            features before decoding, creating two complementary student views. Requires a
 64            UNETR-compatible model architecture.
 65        supervised_train_loader: The loader for supervised training.
 66        supervised_loss: The loss for supervised training.
 67        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
 68        supervised_loss_and_metric: The loss and metric for supervised training.
 69        logger: The logger. Defaults to `UniMatchv2TensorboardLogger`.
 70        momentum: The momentum value for the exponential moving weight average of the teacher model.
 71        reinit_teacher: Whether to reinit the teacher model before starting the training.
 72        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
 73        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
 74    """
 75
 76    def __init__(
 77        self, complementary_dropout, **kwargs
 78    ):
 79        super().__init__(**kwargs)
 80        self.complementary_dropout = complementary_dropout
 81
 82        self.teacher.eval()
 83
 84    def unetr_decoder_prediction(self, model, features, input_shape, original_shape):
 85
 86        z9 = model.deconv1(features)
 87        z6 = model.deconv2(z9)
 88        z3 = model.deconv3(z6)
 89        z0 = model.deconv4(z3)
 90
 91        updated_from_encoder = [z9, z6, z3]
 92
 93        x = model.base(features)
 94        x = model.decoder(x, encoder_inputs=updated_from_encoder)
 95        x = model.deconv_out(x)
 96
 97        x = torch.cat([x, z0], dim=1)
 98        x = model.decoder_head(x)
 99
100        x = model.out_conv(x)
101        if model.final_activation is not None:
102            x = model.final_activation(x)
103
104        x = model.postprocess_masks(x, input_shape, original_shape)
105        return x
106
107    def predict_with_comp_drop(self, model, input_):
108        batch_size = input_.shape[0]
109        original_shape = input_.shape[2:]
110
111        x, input_shape = model.preprocess(input_)
112
113        if len(original_shape) == 2:
114            encoder_output = model.encoder(x)
115            if isinstance(encoder_output[-1], list):
116                features, _ = encoder_output
117            else:
118                features = encoder_output
119        if len(original_shape) == 3:
120            depth = input_.shape[-3]
121            features = torch.stack([model.encoder(x[:, :, i])[0] for i in range(depth)], dim=2)
122
123        features_dim = features.shape[1]
124
125        binom = torch.distributions.binomial.Binomial(probs=0.5)
126
127        dropout_mask1 = binom.sample((int(batch_size/2), features_dim)).to(input_.device) * 2.0
128        if len(original_shape) == 2:
129            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1)
130        if len(original_shape) == 3:
131            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
132
133        dropout_mask2 = 2.0 - dropout_mask1
134        dropout_mask = torch.cat([dropout_mask1, dropout_mask2])
135
136        # NOTE: in the UniMatch v2 code some samples of the batch stay unchanged!
137        # Keep some samples unchanged (code block not tested)
138        # dropout_prob = 0.5
139        # num_kept = int(batch_size * (1 - dropout_prob))
140        # kept_indexes = torch.randperm(batch_size, device=input_.device)[:num_kept]
141
142        # dropout_mask1[kept_indexes, :] = 1.0
143        # dropout_mask2[kept_indexes, :] = 1.0
144
145        dropped_features = features * dropout_mask
146
147        pred = self.unetr_decoder_prediction(model, dropped_features, input_shape, original_shape)
148
149        return pred
150
    def _train_epoch_unsupervised(
        self, progress, forward_context, backprop
    ):
        """Run one training epoch on unlabeled data only.

        Per batch: a weakly augmented view is fed to the teacher (without
        gradients) to produce pseudo-labels, two strongly augmented views are fed
        to the student, all outputs are mapped back into a shared reference frame
        via the inverse augmentations, and the unsupervised loss between the
        stacked student predictions and the pseudo-labels is backpropagated.

        Args:
            progress: Progress bar that is advanced once per iteration.
            forward_context: Context manager wrapping forward passes (e.g. autocast).
            backprop: Callable that backpropagates the given loss and steps the optimizer.

        Returns:
            The average wall-clock time per iteration in seconds.
        """
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for x_u in self.unsupervised_train_loader:
            # Re-draw the random augmentation parameters for this batch.
            self.augmenter.reset_all()
            x_u = x_u.to(self.device, non_blocking=True)

            # One weak view for the teacher, two strong views for the student.
            x_u_w = self.augmenter.weak.transform(x_u)
            x_u_s1, x_u_s2 = self.augmenter.strong1.transform(x_u), self.augmenter.strong2.transform(x_u)

            # Compute the pseudo labels (unsupervised teacher prediction)
            with forward_context(), torch.no_grad():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_u_w)
                # Invert the weak augmentation so labels live in the shared frame.
                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.weak.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # Perform unsupervised training
            with forward_context():
                # Both strong views are processed in a single forward pass.
                if self.complementary_dropout:
                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_u_s1, x_u_s2))).chunk(2)
                else:
                    pred_s1, pred_s2 = self.model(torch.cat((x_u_s1, x_u_s2))).chunk(2)
                # Invert the strong augmentations to align with the pseudo-labels.
                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)
                unsupervised_loss = self.unsupervised_loss(
                    torch.stack((pred_s1_inv, pred_s2_inv)),
                    pseudo_labels_inv,
                    label_filter_inv,
                    pred_dim=2,
                )

            backprop(unsupervised_loss)

            if self.logger is not None:
                self.logger.log_train_unsupervised(
                    self._iteration,
                    unsupervised_loss,
                    x_u,
                    pred_s1_inv,
                    pred_s2_inv,
                    pseudo_labels_inv,
                    label_filter_inv,
                )
                self.logger.log_train_augmentations(
                    self._iteration,
                    x_u_w,
                    x_u_s1,
                    x_u_s2,
                    pseudo_labels,
                    pred_s1,
                    pred_s2,
                )

                # Log the learning rate of the first parameter group.
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()  # EMA update of the teacher

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter
228
229    def _train_epoch_semisupervised(
230        self, progress, forward_context, backprop
231    ):
232        train_loader = zip(self.supervised_train_loader, self.unsupervised_train_loader)
233        self.model.train()
234
235        n_iter = 0
236        t_per_iter = time.time()
237
238        for i, ((x_s, y_s), x_u) in enumerate(train_loader):
239            self.augmenter.reset_all()
240
241            x_s, y_s = x_s.to(self.device, non_blocking=True), y_s.to(self.device, non_blocking=True)
242            x_u = x_u.to(self.device, non_blocking=True)
243
244            x_u_w = self.augmenter.weak.transform(x_u)
245            x_u_s1, x_u_s2 = self.augmenter.strong1.transform(x_u), self.augmenter.strong2.transform(x_u)
246
247            self.optimizer.zero_grad()
248            # supervised loss (supervised student prediction)
249            pred_s = self.model(x_s)
250            supervised_loss = self.supervised_loss(pred_s, y_s)
251
252            backprop(supervised_loss)
253
254            # Compute the pseudo labels (unsupervised teacher prediction)
255            with forward_context(), torch.no_grad():
256                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_u_w)
257                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
258                label_filter_inv = (
259                    self.augmenter.weak.reverse_transform(label_filter)
260                    if label_filter is not None else None
261                )
262
263            # Perform unsupervised training
264            self.optimizer.zero_grad()
265            with forward_context():
266                if self.complementary_dropout:
267                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_u_s1, x_u_s2))).chunk(2)
268                else:
269                    pred_s1, pred_s2 = self.model(torch.cat((x_u_s1, x_u_s2))).chunk(2)
270                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
271                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)
272                unsupervised_loss = self.unsupervised_loss(
273                    torch.stack((pred_s1_inv, pred_s2_inv)),
274                    pseudo_labels_inv,
275                    label_filter_inv,
276                    pred_dim=2,
277                )
278
279            backprop(unsupervised_loss)
280
281            if self.logger is not None:
282                self.logger.log_train_supervised(
283                    self._iteration, supervised_loss, x_s, y_s, pred_s
284                )
285                self.logger.log_train_unsupervised(
286                    self._iteration,
287                    unsupervised_loss,
288                    x_u,
289                    pred_s1_inv,
290                    pred_s2_inv,
291                    pseudo_labels_inv,
292                    label_filter_inv,
293                )
294                self.logger.log_train_augmentations(
295                    self._iteration,
296                    x_u_w,
297                    x_u_s1,
298                    x_u_s2,
299                    pseudo_labels,
300                    pred_s1,
301                    pred_s2,
302                )
303
304                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
305                self.logger.log_lr(self._iteration, lr)
306                if self.pseudo_labeler.confidence_threshold is not None:
307                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
308
309            with torch.no_grad():
310                self._momentum_update()  # EMA update of the teacher
311
312            self._iteration += 1
313            n_iter += 1
314            if self._iteration >= self.max_iteration:
315                break
316            progress.update(1)
317
318        t_per_iter = (time.time() - t_per_iter) / n_iter
319        return t_per_iter
320
321    def _validate_supervised(self, forward_context):
322        metric_val = 0.0
323        loss_val = 0.0
324
325        for x, y in self.supervised_val_loader:
326            x, y = (
327                x.to(self.device, non_blocking=True),
328                y.to(self.device, non_blocking=True)
329            )
330
331            with forward_context():
332                pred = self.model(x)
333                loss, metric = self.supervised_loss_and_metric(pred, y)
334                loss_val += loss.item()
335            metric_val += metric.item()
336
337        metric_val /= len(self.supervised_val_loader)
338        loss_val /= len(self.supervised_val_loader)
339
340        if self.logger is not None:
341            self.logger.log_validation_supervised(
342                self._iteration, metric_val, loss_val, x, y, pred
343            )
344
345        return metric_val
346
    def _validate_unsupervised(self, forward_context):
        """Run unsupervised validation and return the metric averaged over batches.

        Mirrors the unsupervised training step without backpropagation: the
        teacher predicts pseudo-labels on the weak view, the student predicts on
        both strong views, and loss/metric are computed in the shared inverted
        frame. Afterwards the pseudo-labeler is stepped with the averaged metric
        (e.g. to adapt its confidence threshold).
        """
        metric_val = 0.0
        loss_val = 0.0

        for x in self.unsupervised_val_loader:
            # Re-draw the random augmentation parameters for this batch.
            self.augmenter.reset_all()
            x = x.to(self.device, non_blocking=True)

            # apply augmentations
            x_w = self.augmenter.weak.transform(x)
            x_s1, x_s2 = self.augmenter.strong1.transform(x), self.augmenter.strong2.transform(x)

            # Compute the pseudo labels (unsupervised teacher prediction)
            with forward_context():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_w)
                # Invert the weak augmentation so labels live in the shared frame.
                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.weak.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

                # Both strong views are processed in a single forward pass.
                if self.complementary_dropout:
                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_s1, x_s2))).chunk(2)
                else:
                    pred_s1, pred_s2 = self.model(torch.cat((x_s1, x_s2))).chunk(2)
                # Invert the strong augmentations to align with the pseudo-labels.
                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)

                loss, metric = self.unsupervised_loss_and_metric(
                    torch.stack((pred_s1_inv, pred_s2_inv)),
                    pseudo_labels_inv,
                    label_filter_inv,
                    pred_dim=2,
                )
            loss_val += loss.item()
            metric_val += metric.item()

        metric_val /= len(self.unsupervised_val_loader)
        loss_val /= len(self.unsupervised_val_loader)

        if self.logger is not None:
            # The last batch's tensors are logged alongside the averaged scores.
            self.logger.log_validation_unsupervised(
                self._iteration,
                metric_val,
                loss_val,
                x,
                pred_s1_inv,
                pred_s2_inv,
                pseudo_labels_inv,
                label_filter_inv,
            )

            self.logger.log_validation_augmentations(
                self._iteration,
                x_w,
                x_s1,
                x_s2,
                pseudo_labels,
                pred_s1,
                pred_s2,
            )

        # Let the pseudo-labeler adapt its state based on the validation result.
        self.pseudo_labeler.step(metric_val, self._epoch)

        return metric_val
412
413    def _validate_impl(self, forward_context):
414        self.model.eval()
415
416        with torch.no_grad():
417
418            if self.supervised_val_loader is None:
419                supervised_metric = None
420            else:
421                supervised_metric = self._validate_supervised(forward_context)
422
423            if self.unsupervised_val_loader is None:
424                unsupervised_metric = None
425            else:
426                unsupervised_metric = self._validate_unsupervised(forward_context)
427
428        if unsupervised_metric is None:
429            metric = supervised_metric
430        elif supervised_metric is None:
431            metric = unsupervised_metric
432        else:
433            metric = (supervised_metric + unsupervised_metric) / 2
434
435        return metric
 10class UniMatchv2Trainer(MeanTeacherTrainerWithInvertibleAugmentations):
 11    """Trainer for semi-supervised learning and domain adaptation following the UniMatch v2 framework.
 12
 13    UniMatch v2 was introduced by Yang et al. in https://arxiv.org/abs/2410.10777v2.
 14    It uses a teacher model derived from the student model via EMA of weights to predict
 15    pseudo-labels on unlabeled data. Three augmented views are generated per sample - one weak
 16    (for the teacher) and two strong (for the student) - and the student loss is computed as the
 17    average over both strong-view predictions against the shared weak-view pseudo-label.
 18    We support two training strategies:
 19    - Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
 20    - Training only on the unsupervised data.
 21
 22    This class expects the following data loaders:
 23    - unsupervised_train_loader: Returns a single (raw) input per sample. The trainer applies
 24      weak and two strong augmentations internally via the augmenter.
 25    - supervised_train_loader (optional): Returns input and labels.
 26    - unsupervised_val_loader (optional): Same format as unsupervised_train_loader.
 27    - supervised_val_loader (optional): Same format as supervised_train_loader.
 28    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 29
 30    The augmenter must be a `UniMatchv2Augmenters` instance providing three invertible transforms:
 31    `.weak` for the teacher view, `.strong1` and `.strong2` for the two student views. The
 32    corresponding inverse transforms map predictions and pseudo-labels back into a shared
 33    reference frame before the loss is computed.
 34
 35    The following arguments can be used to customize the pseudo labeling:
 36    - pseudo_labeler: to compute the pseudo-labels
 37        - Parameters: teacher, teacher_input
 38        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
 39    - unsupervised_loss: the loss between stacked student predictions and pseudo-labels
 40        - Parameters: prediction (stacked [pred_s1_inv, pred_s2_inv]), pseudo_labels, label_filter, pred_dim
 41        - Returns: loss
 42    - supervised_loss (optional): the supervised loss function
 43        - Parameters: prediction, labels
 44        - Returns: loss
 45    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
 46        - Parameters: prediction (stacked), pseudo_labels, label_filter, pred_dim
 47        - Returns: loss, metric
 48    - supervised_loss_and_metric (optional): the supervised loss function and metric
 49        - Parameters: prediction, labels
 50        - Returns: loss, metric
 51    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
 52
 53    Note: adjust the batch size of the 'unsupervised_train_loader' relative to
 54    'supervised_train_loader' to control the ratio of supervised to unsupervised training samples.
 55
 56    Args:
 57        model: The model to be trained.
 58        unsupervised_train_loader: The loader for unsupervised training (returns raw inputs only).
 59        unsupervised_loss: The loss for unsupervised training.
 60        pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
 61        augmenter: `UniMatchv2Augmenters` instance providing `.weak`, `.strong1`, and `.strong2`
 62            invertible transforms with corresponding inverse transforms.
 63        complementary_dropout: If True, applies complementary feature dropout to the encoder
 64            features before decoding, creating two complementary student views. Requires a
 65            UNETR-compatible model architecture.
 66        supervised_train_loader: The loader for supervised training.
 67        supervised_loss: The loss for supervised training.
 68        unsupervised_loss_and_metric: The loss and metric for unsupervised training.
 69        supervised_loss_and_metric: The loss and metric for supervised training.
 70        logger: The logger. Defaults to `UniMatchv2TensorboardLogger`.
 71        momentum: The momentum value for the exponential moving weight average of the teacher model.
 72        reinit_teacher: Whether to reinit the teacher model before starting the training.
 73        sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
 74        kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
 75    """
 76
 77    def __init__(
 78        self, complementary_dropout, **kwargs
 79    ):
 80        super().__init__(**kwargs)
 81        self.complementary_dropout = complementary_dropout
 82
 83        self.teacher.eval()
 84
 85    def unetr_decoder_prediction(self, model, features, input_shape, original_shape):
 86
 87        z9 = model.deconv1(features)
 88        z6 = model.deconv2(z9)
 89        z3 = model.deconv3(z6)
 90        z0 = model.deconv4(z3)
 91
 92        updated_from_encoder = [z9, z6, z3]
 93
 94        x = model.base(features)
 95        x = model.decoder(x, encoder_inputs=updated_from_encoder)
 96        x = model.deconv_out(x)
 97
 98        x = torch.cat([x, z0], dim=1)
 99        x = model.decoder_head(x)
100
101        x = model.out_conv(x)
102        if model.final_activation is not None:
103            x = model.final_activation(x)
104
105        x = model.postprocess_masks(x, input_shape, original_shape)
106        return x
107
108    def predict_with_comp_drop(self, model, input_):
109        batch_size = input_.shape[0]
110        original_shape = input_.shape[2:]
111
112        x, input_shape = model.preprocess(input_)
113
114        if len(original_shape) == 2:
115            encoder_output = model.encoder(x)
116            if isinstance(encoder_output[-1], list):
117                features, _ = encoder_output
118            else:
119                features = encoder_output
120        if len(original_shape) == 3:
121            depth = input_.shape[-3]
122            features = torch.stack([model.encoder(x[:, :, i])[0] for i in range(depth)], dim=2)
123
124        features_dim = features.shape[1]
125
126        binom = torch.distributions.binomial.Binomial(probs=0.5)
127
128        dropout_mask1 = binom.sample((int(batch_size/2), features_dim)).to(input_.device) * 2.0
129        if len(original_shape) == 2:
130            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1)
131        if len(original_shape) == 3:
132            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
133
134        dropout_mask2 = 2.0 - dropout_mask1
135        dropout_mask = torch.cat([dropout_mask1, dropout_mask2])
136
137        # NOTE: in the UniMatch v2 code some samples of the batch stay unchanged!
138        # Keep some samples unchanged (code block not tested)
139        # dropout_prob = 0.5
140        # num_kept = int(batch_size * (1 - dropout_prob))
141        # kept_indexes = torch.randperm(batch_size, device=input_.device)[:num_kept]
142
143        # dropout_mask1[kept_indexes, :] = 1.0
144        # dropout_mask2[kept_indexes, :] = 1.0
145
146        dropped_features = features * dropout_mask
147
148        pred = self.unetr_decoder_prediction(model, dropped_features, input_shape, original_shape)
149
150        return pred
151
152    def _train_epoch_unsupervised(
153        self, progress, forward_context, backprop
154    ):
155        self.model.train()
156
157        n_iter = 0
158        t_per_iter = time.time()
159
160        for x_u in self.unsupervised_train_loader:
161            self.augmenter.reset_all()
162            x_u = x_u.to(self.device, non_blocking=True)
163
164            x_u_w = self.augmenter.weak.transform(x_u)
165            x_u_s1, x_u_s2 = self.augmenter.strong1.transform(x_u), self.augmenter.strong2.transform(x_u)
166
167            # Compute the pseudo labels (unsupervised teacher prediction)
168            with forward_context(), torch.no_grad():
169                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_u_w)
170                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
171                label_filter_inv = (
172                    self.augmenter.weak.reverse_transform(label_filter)
173                    if label_filter is not None else None
174                )
175
176            # Perform unsupervised training
177            with forward_context():
178                if self.complementary_dropout:
179                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_u_s1, x_u_s2))).chunk(2)
180                else:
181                    pred_s1, pred_s2 = self.model(torch.cat((x_u_s1, x_u_s2))).chunk(2)
182                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
183                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)
184                unsupervised_loss = self.unsupervised_loss(
185                    torch.stack((pred_s1_inv, pred_s2_inv)),
186                    pseudo_labels_inv,
187                    label_filter_inv,
188                    pred_dim=2,
189                )
190
191            backprop(unsupervised_loss)
192
193            if self.logger is not None:
194                self.logger.log_train_unsupervised(
195                    self._iteration,
196                    unsupervised_loss,
197                    x_u,
198                    pred_s1_inv,
199                    pred_s2_inv,
200                    pseudo_labels_inv,
201                    label_filter_inv,
202                )
203                self.logger.log_train_augmentations(
204                    self._iteration,
205                    x_u_w,
206                    x_u_s1,
207                    x_u_s2,
208                    pseudo_labels,
209                    pred_s1,
210                    pred_s2,
211                )
212
213                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
214                self.logger.log_lr(self._iteration, lr)
215                if self.pseudo_labeler.confidence_threshold is not None:
216                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
217
218            with torch.no_grad():
219                self._momentum_update()  # EMA update of the teacher
220
221            self._iteration += 1
222            n_iter += 1
223            if self._iteration >= self.max_iteration:
224                break
225            progress.update(1)
226
227        t_per_iter = (time.time() - t_per_iter) / n_iter
228        return t_per_iter
229
    def _train_epoch_semisupervised(
        self, progress, forward_context, backprop
    ):
        """Run one epoch of joint supervised and unsupervised (UniMatch v2) training.

        Each iteration performs two separate optimization steps:
        1. A supervised step on the labeled batch.
        2. An unsupervised step where student predictions on two strongly augmented
           views are trained against the teacher pseudo-labels computed on the weakly
           augmented view; predictions and pseudo-labels are mapped back into a shared
           reference frame via the inverse augmentations before the loss.
        After both steps the teacher weights are updated as an EMA of the student.

        Args:
            progress: Progress bar, updated once per iteration.
            forward_context: Context manager wrapping forward passes (e.g. autocast).
            backprop: Callable that backpropagates a loss and steps the optimizer.

        Returns:
            The average wall-clock time per iteration.
        """
        # zip stops at the shorter loader, so the epoch length is
        # min(len(supervised_train_loader), len(unsupervised_train_loader)) batches.
        train_loader = zip(self.supervised_train_loader, self.unsupervised_train_loader)
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for i, ((x_s, y_s), x_u) in enumerate(train_loader):
            # Draw fresh random parameters for the weak and both strong augmentations.
            self.augmenter.reset_all()

            x_s, y_s = x_s.to(self.device, non_blocking=True), y_s.to(self.device, non_blocking=True)
            x_u = x_u.to(self.device, non_blocking=True)

            # Three views of the unlabeled batch: weak (teacher), strong1/strong2 (student).
            x_u_w = self.augmenter.weak.transform(x_u)
            x_u_s1, x_u_s2 = self.augmenter.strong1.transform(x_u), self.augmenter.strong2.transform(x_u)

            self.optimizer.zero_grad()
            # supervised loss (supervised student prediction)
            pred_s = self.model(x_s)
            supervised_loss = self.supervised_loss(pred_s, y_s)

            backprop(supervised_loss)

            # Compute the pseudo labels (unsupervised teacher prediction).
            # No gradients are needed for the teacher; it is only updated via EMA below.
            with forward_context(), torch.no_grad():
                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_u_w)
                # Map pseudo-labels (and the optional filter) back into the reference frame.
                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
                label_filter_inv = (
                    self.augmenter.weak.reverse_transform(label_filter)
                    if label_filter is not None else None
                )

            # Perform unsupervised training
            self.optimizer.zero_grad()
            with forward_context():
                # Both strong views go through one batched forward pass and are split again.
                if self.complementary_dropout:
                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_u_s1, x_u_s2))).chunk(2)
                else:
                    pred_s1, pred_s2 = self.model(torch.cat((x_u_s1, x_u_s2))).chunk(2)
                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)
                # Both inverted student predictions are stacked and scored against the
                # shared pseudo-labels; pred_dim=2 presumably marks the prediction axis
                # layout expected by the loss — confirm against the loss implementation.
                unsupervised_loss = self.unsupervised_loss(
                    torch.stack((pred_s1_inv, pred_s2_inv)),
                    pseudo_labels_inv,
                    label_filter_inv,
                    pred_dim=2,
                )

            backprop(unsupervised_loss)

            if self.logger is not None:
                self.logger.log_train_supervised(
                    self._iteration, supervised_loss, x_s, y_s, pred_s
                )
                self.logger.log_train_unsupervised(
                    self._iteration,
                    unsupervised_loss,
                    x_u,
                    pred_s1_inv,
                    pred_s2_inv,
                    pseudo_labels_inv,
                    label_filter_inv,
                )
                self.logger.log_train_augmentations(
                    self._iteration,
                    x_u_w,
                    x_u_s1,
                    x_u_s2,
                    pseudo_labels,
                    pred_s1,
                    pred_s2,
                )

                # Log the learning rate of the first parameter group.
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_lr(self._iteration, lr)
                if self.pseudo_labeler.confidence_threshold is not None:
                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

            with torch.no_grad():
                self._momentum_update()  # EMA update of the teacher

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        # NOTE(review): raises ZeroDivisionError if the epoch ran zero iterations
        # (empty loader) — presumably the loaders are never empty; confirm.
        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter
321
322    def _validate_supervised(self, forward_context):
323        metric_val = 0.0
324        loss_val = 0.0
325
326        for x, y in self.supervised_val_loader:
327            x, y = (
328                x.to(self.device, non_blocking=True),
329                y.to(self.device, non_blocking=True)
330            )
331
332            with forward_context():
333                pred = self.model(x)
334                loss, metric = self.supervised_loss_and_metric(pred, y)
335                loss_val += loss.item()
336            metric_val += metric.item()
337
338        metric_val /= len(self.supervised_val_loader)
339        loss_val /= len(self.supervised_val_loader)
340
341        if self.logger is not None:
342            self.logger.log_validation_supervised(
343                self._iteration, metric_val, loss_val, x, y, pred
344            )
345
346        return metric_val
347
348    def _validate_unsupervised(self, forward_context):
349        metric_val = 0.0
350        loss_val = 0.0
351
352        for x in self.unsupervised_val_loader:
353            self.augmenter.reset_all()
354            x = x.to(self.device, non_blocking=True)
355
356            # apply augmentations
357            x_w = self.augmenter.weak.transform(x)
358            x_s1, x_s2 = self.augmenter.strong1.transform(x), self.augmenter.strong2.transform(x)
359
360            # Compute the pseudo labels (unsupervised teacher prediction)
361            with forward_context():
362                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, x_w)
363                pseudo_labels_inv = self.augmenter.weak.reverse_transform(pseudo_labels)
364                label_filter_inv = (
365                    self.augmenter.weak.reverse_transform(label_filter)
366                    if label_filter is not None else None
367                )
368
369                if self.complementary_dropout:
370                    pred_s1, pred_s2 = self.predict_with_comp_drop(self.model, torch.cat((x_s1, x_s2))).chunk(2)
371                else:
372                    pred_s1, pred_s2 = self.model(torch.cat((x_s1, x_s2))).chunk(2)
373                pred_s1_inv = self.augmenter.strong1.reverse_transform(pred_s1)
374                pred_s2_inv = self.augmenter.strong2.reverse_transform(pred_s2)
375
376                loss, metric = self.unsupervised_loss_and_metric(
377                    torch.stack((pred_s1_inv, pred_s2_inv)),
378                    pseudo_labels_inv,
379                    label_filter_inv,
380                    pred_dim=2,
381                )
382            loss_val += loss.item()
383            metric_val += metric.item()
384
385        metric_val /= len(self.unsupervised_val_loader)
386        loss_val /= len(self.unsupervised_val_loader)
387
388        if self.logger is not None:
389            self.logger.log_validation_unsupervised(
390                self._iteration,
391                metric_val,
392                loss_val,
393                x,
394                pred_s1_inv,
395                pred_s2_inv,
396                pseudo_labels_inv,
397                label_filter_inv,
398            )
399
400            self.logger.log_validation_augmentations(
401                self._iteration,
402                x_w,
403                x_s1,
404                x_s2,
405                pseudo_labels,
406                pred_s1,
407                pred_s2,
408            )
409
410        self.pseudo_labeler.step(metric_val, self._epoch)
411
412        return metric_val
413
414    def _validate_impl(self, forward_context):
415        self.model.eval()
416
417        with torch.no_grad():
418
419            if self.supervised_val_loader is None:
420                supervised_metric = None
421            else:
422                supervised_metric = self._validate_supervised(forward_context)
423
424            if self.unsupervised_val_loader is None:
425                unsupervised_metric = None
426            else:
427                unsupervised_metric = self._validate_unsupervised(forward_context)
428
429        if unsupervised_metric is None:
430            metric = supervised_metric
431        elif supervised_metric is None:
432            metric = unsupervised_metric
433        else:
434            metric = (supervised_metric + unsupervised_metric) / 2
435
436        return metric

Trainer for semi-supervised learning and domain adaptation following the UniMatch v2 framework.

UniMatch v2 was introduced by Yang et al. in https://arxiv.org/abs/2410.10777v2. It uses a teacher model derived from the student model via EMA of weights to predict pseudo-labels on unlabeled data. Three augmented views are generated per sample - one weak (for the teacher) and two strong (for the student) - and the student loss is computed as the average over both strong-view predictions against the shared weak-view pseudo-label. We support two training strategies:

  • Joint training on labeled and unlabeled data (with a supervised and unsupervised loss function).
  • Training only on the unsupervised data.

This class expects the following data loaders:

  • unsupervised_train_loader: Returns a single (raw) input per sample. The trainer applies weak and two strong augmentations internally via the augmenter.
  • supervised_train_loader (optional): Returns input and labels.
  • unsupervised_val_loader (optional): Same format as unsupervised_train_loader.
  • supervised_val_loader (optional): Same format as supervised_train_loader.

At least one of unsupervised_val_loader and supervised_val_loader must be given.

The augmenter must be a UniMatchv2Augmenters instance providing three invertible transforms: .weak for the teacher view, .strong1 and .strong2 for the two student views. The corresponding inverse transforms map predictions and pseudo-labels back into a shared reference frame before the loss is computed.

The following arguments can be used to customize the pseudo labeling:

  • pseudo_labeler: to compute the pseudo-labels
    • Parameters: teacher, teacher_input
    • Returns: pseudo_labels, label_filter (the label filter can, for example, be a mask, a weight, or None)
  • unsupervised_loss: the loss between stacked student predictions and pseudo-labels
    • Parameters: prediction (stacked [pred_s1_inv, pred_s2_inv]), pseudo_labels, label_filter, pred_dim
    • Returns: loss
  • supervised_loss (optional): the supervised loss function
    • Parameters: prediction, labels
    • Returns: loss
  • unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
    • Parameters: prediction (stacked), pseudo_labels, label_filter, pred_dim
    • Returns: loss, metric
  • supervised_loss_and_metric (optional): the supervised loss function and metric
    • Parameters: prediction, labels
    • Returns: loss, metric

At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.

Note: adjust the batch size of the 'unsupervised_train_loader' relative to 'supervised_train_loader' to control the ratio of supervised to unsupervised training samples.

Arguments:
  • model: The model to be trained.
  • unsupervised_train_loader: The loader for unsupervised training (returns raw inputs only).
  • unsupervised_loss: The loss for unsupervised training.
  • pseudo_labeler: The pseudo labeler that predicts labels in unsupervised training.
  • augmenter: UniMatchv2Augmenters instance providing .weak, .strong1, and .strong2 invertible transforms with corresponding inverse transforms.
  • complementary_dropout: If True, applies complementary feature dropout to the encoder features before decoding, creating two complementary student views. Requires a UNETR-compatible model architecture.
  • supervised_train_loader: The loader for supervised training.
  • supervised_loss: The loss for supervised training.
  • unsupervised_loss_and_metric: The loss and metric for unsupervised training.
  • supervised_loss_and_metric: The loss and metric for supervised training.
  • logger: The logger. Defaults to UniMatchv2TensorboardLogger.
  • momentum: The momentum value for the exponential moving weight average of the teacher model.
  • reinit_teacher: Whether to reinit the teacher model before starting the training.
  • sampler: A sampler for rejecting pseudo-labels according to a defined criterion.
  • kwargs: Additional keyword arguments for torch_em.trainer.DefaultTrainer.
UniMatchv2Trainer(complementary_dropout, **kwargs)
    def __init__(
        self, complementary_dropout: bool, **kwargs
    ):
        """Initialize the UniMatch v2 trainer.

        Args:
            complementary_dropout: Whether to apply complementary feature dropout to
                the encoder features before decoding (see predict_with_comp_drop).
            **kwargs: Keyword arguments forwarded to
                MeanTeacherTrainerWithInvertibleAugmentations (model, loaders,
                losses, pseudo labeler, augmenter, logger, ...).
        """
        # The parent constructor sets up model, teacher, loaders, losses, etc.,
        # so it must run before the teacher is touched below.
        super().__init__(**kwargs)
        self.complementary_dropout = complementary_dropout

        # The teacher only produces pseudo-labels and is updated via EMA of the
        # student weights, so it stays in eval mode.
        self.teacher.eval()
complementary_dropout
def unetr_decoder_prediction(self, model, features, input_shape, original_shape):
 85    def unetr_decoder_prediction(self, model, features, input_shape, original_shape):
 86
 87        z9 = model.deconv1(features)
 88        z6 = model.deconv2(z9)
 89        z3 = model.deconv3(z6)
 90        z0 = model.deconv4(z3)
 91
 92        updated_from_encoder = [z9, z6, z3]
 93
 94        x = model.base(features)
 95        x = model.decoder(x, encoder_inputs=updated_from_encoder)
 96        x = model.deconv_out(x)
 97
 98        x = torch.cat([x, z0], dim=1)
 99        x = model.decoder_head(x)
100
101        x = model.out_conv(x)
102        if model.final_activation is not None:
103            x = model.final_activation(x)
104
105        x = model.postprocess_masks(x, input_shape, original_shape)
106        return x
def predict_with_comp_drop(self, model, input_):
    def predict_with_comp_drop(self, model, input_):
        """Predict with complementary channel dropout on the encoder features.

        The batch is expected to contain the two strong views concatenated along the
        batch dimension (the callers pass ``torch.cat((x_s1, x_s2))`` and split the
        result with ``.chunk(2)``): the first half of the batch gets a random
        per-channel dropout mask and the second half gets the complementary mask.

        Args:
            model: The UNETR-compatible model used for prediction.
            input_: The input batch; its first and second half are the two views.

        Returns:
            The prediction for the full batch.
        """
        batch_size = input_.shape[0]
        original_shape = input_.shape[2:]

        x, input_shape = model.preprocess(input_)

        if len(original_shape) == 2:
            encoder_output = model.encoder(x)
            # Some encoders return (features, intermediates) instead of bare features.
            if isinstance(encoder_output[-1], list):
                features, _ = encoder_output
            else:
                features = encoder_output
        if len(original_shape) == 3:
            # For 3d inputs the encoder runs per slice along the depth axis and the
            # per-slice features are stacked back into a volume.
            depth = input_.shape[-3]
            features = torch.stack([model.encoder(x[:, :, i])[0] for i in range(depth)], dim=2)
        # NOTE(review): `features` would be unbound for any other input rank — confirm
        # inputs are always 2d or 3d.

        features_dim = features.shape[1]

        # Bernoulli(0.5) mask per (sample, channel); the factor 2 looks like
        # inverted-dropout rescaling for p=0.5 — kept channels carry value 2,
        # dropped channels 0.
        binom = torch.distributions.binomial.Binomial(probs=0.5)

        dropout_mask1 = binom.sample((int(batch_size/2), features_dim)).to(input_.device) * 2.0
        # Broadcast the per-channel mask over the spatial (and depth) dimensions.
        if len(original_shape) == 2:
            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1)
        if len(original_shape) == 3:
            dropout_mask1 = dropout_mask1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        # The second half of the batch keeps exactly the channels the first half drops.
        dropout_mask2 = 2.0 - dropout_mask1
        dropout_mask = torch.cat([dropout_mask1, dropout_mask2])

        # NOTE: in the UniMatch v2 code some samples of the batch stay unchanged!
        # Keep some samples unchanged (code block not tested)
        # dropout_prob = 0.5
        # num_kept = int(batch_size * (1 - dropout_prob))
        # kept_indexes = torch.randperm(batch_size, device=input_.device)[:num_kept]

        # dropout_mask1[kept_indexes, :] = 1.0
        # dropout_mask2[kept_indexes, :] = 1.0

        dropped_features = features * dropout_mask

        # Decode the masked features with the model's decoder path.
        pred = self.unetr_decoder_prediction(model, dropped_features, input_shape, original_shape)

        return pred