torch_em.self_training.logger
import os

import torch
import torch_em

from torchvision.utils import make_grid


class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Tensorboard logger for self-training trainers.

    Logs the combined loss, the learning rate, and separate supervised / unsupervised
    losses, metrics and image grids for training and validation.

    Args:
        trainer: The trainer to log for. ``trainer.name`` names the log directory and
            ``trainer.log_image_interval`` sets how often training images are written.
        save_root: Root directory for the logs. If None, logs are written to
            ``./logs/<trainer.name>``, otherwise to ``<save_root>/logs/<trainer.name>``.
        unused_kwargs: Additional keyword arguments; ignored.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        if self.my_root is None:
            self.log_dir = f"./logs/{trainer.name}"
        else:
            self.log_dir = os.path.join(self.my_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        # `import torch` does not load the optional `torch.utils.tensorboard` submodule,
        # so import it explicitly instead of relying on another module having done so.
        from torch.utils.tensorboard import SummaryWriter

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Log a grid of input, labels and prediction for the first sample of the batch.

        Args:
            step: The global step for the tensorboard entry.
            name: Tag prefix, e.g. "train" or "validation".
            x: The input batch.
            y: The label batch.
            pred: The prediction batch.
        """
        if x.ndim == 5:
            # Volumetric batch (assumed batch, channel, z, y, x): show the central slice along axis 2.
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        # Only the first channel of labels and prediction is shown.
        grid = make_grid(
            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
            padding=8
        )
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        """Log a grid of both augmented inputs, prediction, pseudo-labels and optional label filter.

        Only the first sample of the batch is shown.

        Args:
            step: The global step for the tensorboard entry.
            name: Tag prefix, e.g. "train" or "validation".
            x1: The first augmented view of the input batch.
            x2: The second augmented view of the input batch.
            pred: The prediction batch.
            pseudo_labels: The pseudo-label batch.
            label_filter: Optional extra image batch. If given, it is appended to the grid
                and "-labelfilter" is appended to the image tag.
        """
        if x1.ndim == 5:
            # Volumetric batch (assumed batch, channel, z, y, x): show the central slice along axis 2.
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            zindex = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]

        images = [
            torch_em.transform.raw.normalize(x1[0]),
            torch_em.transform.raw.normalize(x2[0]),
            pred[0, 0:1], pseudo_labels[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
        if label_filter is not None:
            images.append(label_filter[0, 0:1])
            # Extend the tag itself; `name` is not read again once `im_name` has been built.
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=2, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """Log the combined training loss."""
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """Log the current learning rate."""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """Log the supervised training loss, plus images every ``log_image_interval`` steps."""
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            # Training images go under the "train" prefix so they don't mix with validation images.
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """Log the supervised validation loss, metric and images (images on every call)."""
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """Log the unsupervised training loss, plus images every ``log_image_interval`` steps."""
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            # Training images go under the "train" prefix so they don't mix with validation images.
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """Log the unsupervised validation loss, metric and images (images on every call)."""
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
        """Log the overall validation loss and metric, and the ground-truth metric if given.

        The arguments ``xt``, ``xt1``, ``xt2``, ``y``, ``z``, ``gt`` and ``samples`` are
        accepted but not used by this logger.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Tensorboard logger for self-training trainers.

    Logs the combined loss, the learning rate, and separate supervised / unsupervised
    losses, metrics and image grids for training and validation.

    Args:
        trainer: The trainer to log for. ``trainer.name`` names the log directory and
            ``trainer.log_image_interval`` sets how often training images are written.
        save_root: Root directory for the logs. If None, logs are written to
            ``./logs/<trainer.name>``, otherwise to ``<save_root>/logs/<trainer.name>``.
        unused_kwargs: Additional keyword arguments; ignored.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        if self.my_root is None:
            self.log_dir = f"./logs/{trainer.name}"
        else:
            self.log_dir = os.path.join(self.my_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        # `import torch` does not load the optional `torch.utils.tensorboard` submodule,
        # so import it explicitly instead of relying on another module having done so.
        from torch.utils.tensorboard import SummaryWriter

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Log a grid of input, labels and prediction for the first sample of the batch.

        Args:
            step: The global step for the tensorboard entry.
            name: Tag prefix, e.g. "train" or "validation".
            x: The input batch.
            y: The label batch.
            pred: The prediction batch.
        """
        if x.ndim == 5:
            # Volumetric batch (assumed batch, channel, z, y, x): show the central slice along axis 2.
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        # Only the first channel of labels and prediction is shown.
        grid = make_grid(
            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
            padding=8
        )
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        """Log a grid of both augmented inputs, prediction, pseudo-labels and optional label filter.

        Only the first sample of the batch is shown.

        Args:
            step: The global step for the tensorboard entry.
            name: Tag prefix, e.g. "train" or "validation".
            x1: The first augmented view of the input batch.
            x2: The second augmented view of the input batch.
            pred: The prediction batch.
            pseudo_labels: The pseudo-label batch.
            label_filter: Optional extra image batch. If given, it is appended to the grid
                and "-labelfilter" is appended to the image tag.
        """
        if x1.ndim == 5:
            # Volumetric batch (assumed batch, channel, z, y, x): show the central slice along axis 2.
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            zindex = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]

        images = [
            torch_em.transform.raw.normalize(x1[0]),
            torch_em.transform.raw.normalize(x2[0]),
            pred[0, 0:1], pseudo_labels[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
        if label_filter is not None:
            images.append(label_filter[0, 0:1])
            # Extend the tag itself; `name` is not read again once `im_name` has been built.
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=2, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """Log the combined training loss."""
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """Log the current learning rate."""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """Log the supervised training loss, plus images every ``log_image_interval`` steps."""
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            # Training images go under the "train" prefix so they don't mix with validation images.
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """Log the supervised validation loss, metric and images (images on every call)."""
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """Log the unsupervised training loss, plus images every ``log_image_interval`` steps."""
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            # Training images go under the "train" prefix so they don't mix with validation images.
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """Log the unsupervised validation loss, metric and images (images on every call)."""
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
        """Log the overall validation loss and metric, and the ground-truth metric if given.

        The arguments ``xt``, ``xt1``, ``xt2``, ``y``, ``z``, ``gt`` and ``samples`` are
        accepted but not used by this logger.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
SelfTrainingTensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Set up the log directory and the tensorboard writer.

    Args:
        trainer: The trainer to log for. ``trainer.name`` names the log directory and
            ``trainer.log_image_interval`` sets how often training images are written.
        save_root: Root directory for the logs. If None, logs are written to
            ``./logs/<trainer.name>``, otherwise to ``<save_root>/logs/<trainer.name>``.
        unused_kwargs: Additional keyword arguments; ignored.
    """
    super().__init__(trainer, save_root)
    self.my_root = save_root
    if self.my_root is None:
        self.log_dir = f"./logs/{trainer.name}"
    else:
        self.log_dir = os.path.join(self.my_root, "logs", trainer.name)
    os.makedirs(self.log_dir, exist_ok=True)

    # `import torch` does not load the optional `torch.utils.tensorboard` submodule,
    # so import it explicitly instead of relying on another module having done so.
    from torch.utils.tensorboard import SummaryWriter

    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval
def
log_validation_supervised(self, step, metric, loss, x, y, pred):
def log_validation_supervised(self, step, metric, loss, x, y, pred):
    """Write validation scalars and an image grid for the supervised (labeled) data.

    Args:
        step: The global step used for every tensorboard entry.
        metric: The supervised validation metric value.
        loss: The supervised validation loss value.
        x: The input batch.
        y: The label batch.
        pred: The prediction batch.
    """
    scalars = (
        ("validation/supervised/loss", loss),
        ("validation/supervised/metric", metric),
    )
    for tag, value in scalars:
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)
    # Images are logged unconditionally here (no log_image_interval check).
    self._add_supervised_images(step, "validation", x, y, pred)
def
log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
    """Write the unsupervised training loss and, periodically, an image grid.

    The loss is logged on every call. Images are only logged when ``step`` is a multiple
    of ``self.log_image_interval``.

    Args:
        step: The global step used for every tensorboard entry.
        loss: The unsupervised training loss value.
        x1: The first augmented view of the input batch.
        x2: The second augmented view of the input batch.
        pred: The prediction batch.
        pseudo_labels: The pseudo-label batch.
        label_filter: Optional extra image batch shown in the grid; None to omit it.
    """
    self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
    if step % self.log_image_interval == 0:
        # Use the "train" prefix so training images are not mixed into the validation image tags.
        self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)
def
log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
    """Write validation scalars and an image grid for the unsupervised (unlabeled) data.

    Args:
        step: The global step used for every tensorboard entry.
        metric: The unsupervised validation metric value.
        loss: The unsupervised validation loss value.
        x1: The first augmented view of the input batch.
        x2: The second augmented view of the input batch.
        pred: The prediction batch.
        pseudo_labels: The pseudo-label batch.
        label_filter: Optional extra image batch shown in the grid; None to omit it.
    """
    write_scalar = self.tb.add_scalar
    write_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
    write_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
    # Images are logged unconditionally here (no log_image_interval check).
    self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
def
log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
    """Write the overall validation loss and metric, plus the ground-truth metric if given.

    Args:
        step: The global step used for every tensorboard entry.
        metric: The validation metric value.
        loss: The validation loss value.
        xt, xt1, xt2, y, z, gt, samples: Accepted but not used by this method.
        gt_metric: Optional ground-truth metric; logged only when it is not None
            (so a value of 0 is still logged).
    """
    entries = [("validation/loss", loss), ("validation/metric", metric)]
    if gt_metric is not None:
        entries.append(("validation/gt_metric", gt_metric))
    for tag, value in entries:
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)