torch_em.self_training.logger
import os

import torch_em

from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter


class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.

    Scalars (losses, metrics, learning rate) are written on every call. Training images are only
    written every `trainer.log_image_interval` steps; validation images are written on every
    validation call.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        # Without a save root the logs go to ./logs/<trainer name>, relative to the working directory.
        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
            os.path.join(self.my_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Write a grid of the normalized input, first label channel and first prediction channel.

        Only the first sample of the batch is shown. For 5D tensors (presumably batch, channel,
        z, y, x - TODO confirm) the central slice along axis 2 is shown. The image is written under
        the tag `{name}/supervised/input-labels-prediction`.
        """
        if x.ndim == 5:
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        grid = make_grid(
            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
            padding=8
        )
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        """Write a grid of both augmented inputs, the prediction, the pseudo-labels and the optional label filter.

        Only the first sample (and the first channel of pred / pseudo_labels / label_filter) is shown.
        For 5D tensors the central slice along axis 2 is shown. If a label filter is given it is
        appended to the grid and the image tag gets the suffix "-labelfilter".
        """
        if x1.ndim == 5:
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            zindex = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]

        images = [
            torch_em.transform.raw.normalize(x1[0]),
            torch_em.transform.raw.normalize(x2[0]),
            pred[0, 0:1], pseudo_labels[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
        if label_filter is not None:
            images.append(label_filter[0, 0:1])
            # The suffix must go on the tag actually passed to add_image, so that grids with a
            # label filter (5 panels) are kept apart from grids without one (4 panels).
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=2, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private
        """
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """@private
        """
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        # Training images use the "train" prefix, like the scalar above, so they are not mixed
        # with the images written by log_validation_supervised.
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        # Training images use the "train" prefix, like the scalar above, so they are not mixed
        # with the images written by log_validation_unsupervised.
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
        """@private
        """
        # Only scalars are logged here; the tensor arguments are accepted but not used.
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.

    Scalars (losses, metrics, learning rate) are written on every call. Training images are only
    written every `trainer.log_image_interval` steps; validation images are written on every
    validation call.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        # Without a save root the logs go to ./logs/<trainer name>, relative to the working directory.
        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
            os.path.join(self.my_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Write a grid of the normalized input, first label channel and first prediction channel.

        Only the first sample of the batch is shown. For 5D tensors (presumably batch, channel,
        z, y, x - TODO confirm) the central slice along axis 2 is shown. The image is written under
        the tag `{name}/supervised/input-labels-prediction`.
        """
        if x.ndim == 5:
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        grid = make_grid(
            [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]],
            padding=8
        )
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        """Write a grid of both augmented inputs, the prediction, the pseudo-labels and the optional label filter.

        Only the first sample (and the first channel of pred / pseudo_labels / label_filter) is shown.
        For 5D tensors the central slice along axis 2 is shown. If a label filter is given it is
        appended to the grid and the image tag gets the suffix "-labelfilter".
        """
        if x1.ndim == 5:
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            zindex = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]

        images = [
            torch_em.transform.raw.normalize(x1[0]),
            torch_em.transform.raw.normalize(x2[0]),
            pred[0, 0:1], pseudo_labels[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
        if label_filter is not None:
            images.append(label_filter[0, 0:1])
            # The suffix must go on the tag actually passed to add_image, so that grids with a
            # label filter (5 panels) are kept apart from grids without one (4 panels).
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=2, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private
        """
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """@private
        """
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        # Training images use the "train" prefix, like the scalar above, so they are not mixed
        # with the images written by log_validation_supervised.
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        # Training images use the "train" prefix, like the scalar above, so they are not mixed
        # with the images written by log_validation_unsupervised.
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
        """@private
        """
        # Only scalars are logged here; the tensor arguments are accepted but not used.
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
Arguments:
- trainer: The instantiated trainer class.
- save_root: The root directory for saving the checkpoints and logs.
SelfTrainingTensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    # Set up the TensorBoard writer for this trainer; extra keyword arguments are ignored.
    super().__init__(trainer, save_root)
    self.my_root = save_root

    # Log location: <save_root>/logs/<trainer name>, or ./logs/<trainer name> when no root is given.
    if self.my_root is None:
        target_dir = f"./logs/{trainer.name}"
    else:
        target_dir = os.path.join(self.my_root, "logs", trainer.name)
    self.log_dir = target_dir
    os.makedirs(self.log_dir, exist_ok=True)

    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval