torch_em.self_training.logger

 1import os
 2
 3import torch
 4import torch_em
 5
 6from torchvision.utils import make_grid
 7
 8
 9class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
10    def __init__(self, trainer, save_root, **unused_kwargs):
11        super().__init__(trainer, save_root)
12        self.my_root = save_root
13        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
14            os.path.join(self.my_root, "logs", trainer.name)
15        os.makedirs(self.log_dir, exist_ok=True)
16
17        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
18        self.log_image_interval = trainer.log_image_interval
19
20    def _add_supervised_images(self, step, name, x, y, pred):
21        if x.ndim == 5:
22            assert y.ndim == pred.ndim == 5
23            zindex = x.shape[2] // 2
24            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]
25
26        grid = make_grid(
27            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
28            padding=8
29        )
30        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)
31
32    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
33        if x1.ndim == 5:
34            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
35            zindex = x1.shape[2] // 2
36            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
37            pseudo_labels = pseudo_labels[:, :, zindex]
38            if label_filter is not None:
39                assert label_filter.ndim == 5
40                label_filter = label_filter[:, :, zindex]
41
42        images = [
43            torch_em.transform.raw.normalize(x1[0]),
44            torch_em.transform.raw.normalize(x2[0]),
45            pred[0, 0:1], pseudo_labels[0, 0:1],
46        ]
47        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
48        if label_filter is not None:
49            images.append(label_filter[0, 0:1])
50            name += "-labelfilter"
51        grid = make_grid(images, nrow=2, padding=8)
52        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)
53
54    def log_combined_loss(self, step, loss):
55        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)
56
57    def log_lr(self, step, lr):
58        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
59
60    def log_train_supervised(self, step, loss, x, y, pred):
61        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
62        if step % self.log_image_interval == 0:
63            self._add_supervised_images(step, "validation", x, y, pred)
64
65    def log_validation_supervised(self, step, metric, loss, x, y, pred):
66        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
67        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
68        self._add_supervised_images(step, "validation", x, y, pred)
69
70    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
71        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
72        if step % self.log_image_interval == 0:
73            self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
74
75    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
76        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
77        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
78        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
79
80    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
81        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
82        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
83        if gt_metric is not None:
84            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Tensorboard logger for self-training trainers.

    Logs scalar losses/metrics and image grids for the supervised branch
    (input, labels, prediction) and the unsupervised branch (two augmented
    views, prediction, pseudo-labels and, optionally, the label filter).
    Event files go to ``<save_root>/logs/<trainer.name>``, or to
    ``./logs/<trainer.name>`` when ``save_root`` is None.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
            os.path.join(self.my_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        # ``import torch`` does not load the ``torch.utils.tensorboard`` submodule,
        # so import it explicitly instead of relying on another module having done so.
        from torch.utils.tensorboard import SummaryWriter

        self.tb = SummaryWriter(self.log_dir)
        # Training-step images are only written every ``log_image_interval`` steps.
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Write an input / labels / prediction grid under ``{name}/supervised/...``.

        Only the first batch element is shown, and only the first channel of
        ``y`` and ``pred``. For 5D inputs (presumably ``(B, C, Z, Y, X)`` -- TODO
        confirm) the central slice along axis 2 is used.
        """
        if x.ndim == 5:
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        grid = make_grid(
            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
            padding=8
        )
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        """Write a grid of both augmentations, prediction, pseudo-labels (and label filter).

        Only the first batch element is shown, and only the first channel of
        ``pred``, ``pseudo_labels`` and ``label_filter``. For 5D inputs the
        central slice along axis 2 is used. When ``label_filter`` is given, it
        is added as an extra panel and the tag gets a ``-labelfilter`` suffix.
        """
        if x1.ndim == 5:
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            zindex = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]

        images = [
            torch_em.transform.raw.normalize(x1[0]),
            torch_em.transform.raw.normalize(x2[0]),
            pred[0, 0:1], pseudo_labels[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
        if label_filter is not None:
            images.append(label_filter[0, 0:1])
            # The suffix must go on the tag itself so the extra panel is reflected in its name.
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=2, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """Write the combined training loss at ``step``."""
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """Write the current learning rate at ``step``."""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """Write the supervised training loss; images every ``log_image_interval`` steps."""
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            # Training images belong under the "train" prefix, not "validation".
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """Write supervised validation loss and metric plus an image grid."""
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """Write the unsupervised training loss; images every ``log_image_interval`` steps."""
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            # Training images belong under the "train" prefix, not "validation".
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """Write unsupervised validation loss and metric plus an image grid."""
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
        """Write validation loss, metric and, if given, the ground-truth metric.

        The tensor arguments (``xt``, ``xt1``, ``xt2``, ``y``, ``z``, ``gt``,
        ``samples``) are accepted but not used here.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
SelfTrainingTensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Set up the log directory and the tensorboard writer for ``trainer``.

    Args:
        trainer: The trainer being logged; ``trainer.name`` selects the log
            sub-directory and ``trainer.log_image_interval`` is stored on the logger.
        save_root: Root directory for the logs, or None to log under ``./logs``.
        unused_kwargs: Ignored.
    """
    super().__init__(trainer, save_root)
    self.my_root = save_root
    self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
        os.path.join(self.my_root, "logs", trainer.name)
    os.makedirs(self.log_dir, exist_ok=True)

    # ``import torch`` does not load the ``torch.utils.tensorboard`` submodule,
    # so import it explicitly instead of relying on another module having done so.
    from torch.utils.tensorboard import SummaryWriter

    self.tb = SummaryWriter(self.log_dir)
    # Training-step images are only written every ``log_image_interval`` steps.
    self.log_image_interval = trainer.log_image_interval
my_root
log_dir
tb
log_image_interval
def log_combined_loss(self, step, loss):
def log_combined_loss(self, step, loss):
    """Record the combined (supervised + unsupervised) training loss at ``step``."""
    writer = self.tb
    writer.add_scalar(global_step=step, scalar_value=loss, tag="train/combined_loss")
def log_lr(self, step, lr):
def log_lr(self, step, lr):
    """Record the current learning rate ``lr`` at ``step``."""
    writer = self.tb
    writer.add_scalar(global_step=step, scalar_value=lr, tag="train/learning_rate")
def log_train_supervised(self, step, loss, x, y, pred):
def log_train_supervised(self, step, loss, x, y, pred):
    """Write the supervised training loss, plus an image grid every ``log_image_interval`` steps.

    Images are logged under the ``train`` prefix so they do not mix with the
    validation images written by ``log_validation_supervised``.
    """
    self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
    if step % self.log_image_interval == 0:
        self._add_supervised_images(step, "train", x, y, pred)
def log_validation_supervised(self, step, metric, loss, x, y, pred):
def log_validation_supervised(self, step, metric, loss, x, y, pred):
    """Record supervised validation loss and metric, then always write an image grid."""
    scalars = (
        ("validation/supervised/loss", loss),
        ("validation/supervised/metric", metric),
    )
    for tag, value in scalars:
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)
    self._add_supervised_images(step, "validation", x, y, pred)
def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
    """Write the unsupervised training loss, plus an image grid every ``log_image_interval`` steps.

    Images are logged under the ``train`` prefix so they do not mix with the
    validation images written by ``log_validation_unsupervised``.
    """
    self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
    if step % self.log_image_interval == 0:
        self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)
def log_validation_unsupervised( self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
    """Record unsupervised validation loss and metric, then always write an image grid."""
    scalars = (
        ("validation/unsupervised/loss", loss),
        ("validation/unsupervised/metric", metric),
    )
    for tag, value in scalars:
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)
    self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
def log_validation( self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
    """Record validation loss and metric, and the ground-truth metric when provided.

    The tensor arguments (``xt``, ``xt1``, ``xt2``, ``y``, ``z``, ``gt``,
    ``samples``) are accepted but not used by this method.
    """
    scalars = [("validation/loss", loss), ("validation/metric", metric)]
    if gt_metric is not None:
        scalars.append(("validation/gt_metric", gt_metric))
    for tag, value in scalars:
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)