torch_em.self_training.logger

  1import os
  2
  3import torch_em
  4
  5from torchvision.utils import make_grid
  6from torch.utils.tensorboard import SummaryWriter
  7
  8
  9class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
 10    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
 11
 12    Args:
 13        trainer: The instantiated trainer class.
 14        save_root: The root directory for saving the checkpoints and logs.
 15    """
 16    def __init__(self, trainer, save_root, **unused_kwargs):
 17        super().__init__(trainer, save_root)
 18        self.my_root = save_root
 19        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
 20            os.path.join(self.my_root, "logs", trainer.name)
 21        os.makedirs(self.log_dir, exist_ok=True)
 22
 23        self.tb = SummaryWriter(self.log_dir)
 24        self.log_image_interval = trainer.log_image_interval
 25
 26    def _add_supervised_images(self, step, name, x, y, pred):
 27        if x.ndim == 5:
 28            assert y.ndim == pred.ndim == 5
 29            zindex = x.shape[2] // 2
 30            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]
 31
 32        grid = make_grid(
 33            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
 34            padding=8
 35        )
 36        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)
 37
 38    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
 39        if x1.ndim == 5:
 40            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
 41            zindex = x1.shape[2] // 2
 42            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
 43            pseudo_labels = pseudo_labels[:, :, zindex]
 44            if label_filter is not None:
 45                assert label_filter.ndim == 5
 46                label_filter = label_filter[:, :, zindex]
 47
 48        images = [
 49            torch_em.transform.raw.normalize(x1[0]),
 50            torch_em.transform.raw.normalize(x2[0]),
 51            pred[0, 0:1], pseudo_labels[0, 0:1],
 52        ]
 53        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
 54        if label_filter is not None:
 55            images.append(label_filter[0, 0:1])
 56            name += "-labelfilter"
 57        grid = make_grid(images, nrow=2, padding=8)
 58        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)
 59
 60    def log_combined_loss(self, step, loss):
 61        """@private
 62        """
 63        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)
 64
 65    def log_lr(self, step, lr):
 66        """@private
 67        """
 68        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
 69
 70    def log_train_supervised(self, step, loss, x, y, pred):
 71        """@private
 72        """
 73        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
 74        if step % self.log_image_interval == 0:
 75            self._add_supervised_images(step, "validation", x, y, pred)
 76
 77    def log_validation_supervised(self, step, metric, loss, x, y, pred):
 78        """@private
 79        """
 80        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
 81        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
 82        self._add_supervised_images(step, "validation", x, y, pred)
 83
 84    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
 85        """@private
 86        """
 87        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
 88        if step % self.log_image_interval == 0:
 89            self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
 90
 91    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
 92        """@private
 93        """
 94        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
 95        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
 96        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
 97
 98    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
 99        """@private
100        """
101        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
102        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
103        if gt_metric is not None:
104            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
 10class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
 11    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
 12
 13    Args:
 14        trainer: The instantiated trainer class.
 15        save_root: The root directory for saving the checkpoints and logs.
 16    """
 17    def __init__(self, trainer, save_root, **unused_kwargs):
 18        super().__init__(trainer, save_root)
 19        self.my_root = save_root
 20        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
 21            os.path.join(self.my_root, "logs", trainer.name)
 22        os.makedirs(self.log_dir, exist_ok=True)
 23
 24        self.tb = SummaryWriter(self.log_dir)
 25        self.log_image_interval = trainer.log_image_interval
 26
 27    def _add_supervised_images(self, step, name, x, y, pred):
 28        if x.ndim == 5:
 29            assert y.ndim == pred.ndim == 5
 30            zindex = x.shape[2] // 2
 31            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]
 32
 33        grid = make_grid(
 34            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
 35            padding=8
 36        )
 37        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)
 38
 39    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
 40        if x1.ndim == 5:
 41            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
 42            zindex = x1.shape[2] // 2
 43            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
 44            pseudo_labels = pseudo_labels[:, :, zindex]
 45            if label_filter is not None:
 46                assert label_filter.ndim == 5
 47                label_filter = label_filter[:, :, zindex]
 48
 49        images = [
 50            torch_em.transform.raw.normalize(x1[0]),
 51            torch_em.transform.raw.normalize(x2[0]),
 52            pred[0, 0:1], pseudo_labels[0, 0:1],
 53        ]
 54        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
 55        if label_filter is not None:
 56            images.append(label_filter[0, 0:1])
 57            name += "-labelfilter"
 58        grid = make_grid(images, nrow=2, padding=8)
 59        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)
 60
 61    def log_combined_loss(self, step, loss):
 62        """@private
 63        """
 64        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)
 65
 66    def log_lr(self, step, lr):
 67        """@private
 68        """
 69        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
 70
 71    def log_train_supervised(self, step, loss, x, y, pred):
 72        """@private
 73        """
 74        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
 75        if step % self.log_image_interval == 0:
 76            self._add_supervised_images(step, "validation", x, y, pred)
 77
 78    def log_validation_supervised(self, step, metric, loss, x, y, pred):
 79        """@private
 80        """
 81        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
 82        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
 83        self._add_supervised_images(step, "validation", x, y, pred)
 84
 85    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
 86        """@private
 87        """
 88        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
 89        if step % self.log_image_interval == 0:
 90            self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
 91
 92    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
 93        """@private
 94        """
 95        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
 96        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
 97        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
 98
 99    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
100        """@private
101        """
102        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
103        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
104        if gt_metric is not None:
105            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)

Logger for self-training via torch_em.self_training.FixMatch or torch_em.self_training.MeanTeacher.

Arguments:
  • trainer: The instantiated trainer class.
  • save_root: The root directory for saving the checkpoints and logs.
SelfTrainingTensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Create the log directory and the TensorBoard writer.

    Args:
        trainer: The instantiated trainer; its `name` and `log_image_interval` are read.
        save_root: Root directory for checkpoints and logs. If None, logs are written to
            `./logs/<trainer.name>` relative to the current working directory.
        unused_kwargs: Accepted for interface compatibility and ignored.
    """
    super().__init__(trainer, save_root)
    self.my_root = save_root

    if save_root is None:
        target_dir = f"./logs/{trainer.name}"
    else:
        target_dir = os.path.join(save_root, "logs", trainer.name)
    self.log_dir = target_dir
    os.makedirs(target_dir, exist_ok=True)

    self.tb = SummaryWriter(target_dir)
    self.log_image_interval = trainer.log_image_interval
my_root
log_dir
tb
log_image_interval