torch_em.trainer.tensorboard_logger
import os
import warnings

import numpy as np
import torch

from elf.segmentation.embeddings import embedding_pca
from skimage.segmentation import mark_boundaries
from torchvision.utils import make_grid

from .logger_base import TorchEmLogger

# The tensorboard import only works if the tensorboard package is available,
# so we wrap it in a try / except and check for it when the logger is created.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None

from ..util import ensure_tensor
from ..loss import EMBEDDING_LOSSES


def normalize_im(im):
    """@private

    Shift `im` to a minimum of 0 and scale it to a maximum of 1.

    Out-of-place arithmetic is used on purpose: the inputs are typically
    `.cpu()` / `.detach()` views that share storage with the training data,
    prediction or gradients, so in-place updates would silently modify them.
    Constant images (zero value range) are returned as all zeros instead of NaN.
    """
    im = ensure_tensor(im, dtype=torch.float32)
    im = im - im.min()
    max_val = im.max()
    if max_val > 0:
        im = im / max_val
    return im


def make_grid_image(image, y, prediction, selection, gradients=None):
    """@private

    Build a grid image showing the input, the target and the prediction
    (and optionally the gradients) for the sample picked by `selection`.

    Returns:
        The grid image tensor and the tag suffix to log it under.
    """
    target_image = normalize_im(y[selection].cpu())
    pred_image = normalize_im(prediction[selection].detach().cpu())

    # Only show the first input channel.
    if image.shape[0] > 1:
        image = image[0:1]

    n_channels = pred_image.shape[0]
    n_channels_target = target_image.shape[0]
    if n_channels_target == n_channels == 1:
        nrow = 8
        images = [image, target_image, pred_image]
    elif n_channels_target == 1:
        # Repeat the single target channel so that each grid row lines up
        # with one prediction channel.
        nrow = n_channels
        images = nrow * [image]
        images += (nrow * [target_image])
        images += [channel.unsqueeze(0) for channel in pred_image]
    else:
        nrow = n_channels
        images = nrow * [image]
        images += [channel.unsqueeze(0) for channel in target_image]
        images += [channel.unsqueeze(0) for channel in pred_image]

    if gradients is not None:
        grad_image = normalize_im(gradients[selection].cpu())
        if n_channels == 1:
            images.append(grad_image)
        else:
            images += [channel.unsqueeze(0) for channel in grad_image]

    im = make_grid(images, nrow=nrow, padding=4)
    name = "raw_targets_predictions"
    if gradients is not None:
        name += "_gradients"
    return im, name


def make_embedding_image(image, y, prediction, selection, gradients=None):
    """@private

    Build a grid image showing the input, the segmentation boundaries overlaid
    on the input and a PCA projection of the predicted embeddings.

    Raises:
        NotImplementedError: if gradients are passed (not supported here).
    """
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if gradients is not None:
        raise NotImplementedError("Not implemented")
    image = image.numpy()

    seg = y[selection].cpu().numpy()
    seg = mark_boundaries(image[0], seg[0])  # need to get rid of singleton channel
    seg = seg.transpose((2, 0, 1))  # to channel first

    pred = prediction[selection].detach().cpu().numpy()
    pca = embedding_pca(pred)

    image = np.repeat(image, 3, axis=0)  # to rgb
    images = [torch.from_numpy(im) for im in (image, seg, pca)]
    im = make_grid(images, padding=4)
    name = "raw_segmentation_embedding"
    return im, name


class TensorboardLogger(TorchEmLogger):
    """Logger to write training progress to tensorboard.

    Args:
        trainer: The instantiated trainer.
        save_root: The root directory for writing checkpoints and log files.
            If None, logs are written to `./logs/<trainer.name>`.
    """
    def __init__(self, trainer, save_root: str, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)

        try:
            os.makedirs(self.log_dir, exist_ok=True)
        except PermissionError:
            warnings.warn(
                f"The log dir at {self.log_dir} could not be created. "
                "The most likely reason for this is that you copied the checkpoint somewhere else, "
                "so we skip this error to enable loading the model from this checkpoint."
            )
            # NOTE: the logger is left partially initialized here (no `tb`);
            # it is only usable for loading, not for logging.
            return

        if SummaryWriter is None:
            msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
            raise RuntimeError(msg)
        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

        # derive which visualisation method is appropriate, based on the loss function
        if type(trainer.loss) in EMBEDDING_LOSSES:
            self.have_embeddings = True
            self.make_image = make_embedding_image
        else:
            self.have_embeddings = False
            self.make_image = make_grid_image

    def log_images(self, step, x, y, prediction, name, gradients=None):
        """@private
        """
        # For 4D inputs take the first sample; otherwise take the first sample
        # and the central slice along axis 2.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
        image = normalize_im(x[selection].cpu())
        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)

        im, im_name = self.make_image(image, y, prediction, selection, gradients)
        im_name = f"{name}/{im_name}"
        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """@private
        """
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

        # the embedding visualisation function currently doesn't support gradients,
        # so we can't log them even if log_gradients is true
        log_grads = log_gradients and not self.have_embeddings

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_grads else None
            self.log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        """@private
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.log_images(step, x, y, prediction, "validation")
class TensorboardLogger(TorchEmLogger):
    """Logger to write training progress to tensorboard.

    Args:
        trainer: The instantiated trainer.
        save_root: The root directory for writing checkpoints and log files.
            If None, logs are written to `./logs/<trainer.name>`.
    """
    def __init__(self, trainer, save_root: str, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)

        try:
            os.makedirs(self.log_dir, exist_ok=True)
        except PermissionError:
            warnings.warn(
                f"The log dir at {self.log_dir} could not be created. "
                "The most likely reason for this is that you copied the checkpoint somewhere else, "
                "so we skip this error to enable loading the model from this checkpoint."
            )
            # NOTE: the logger is left partially initialized here (no `tb`);
            # it is only usable for loading, not for logging.
            return

        if SummaryWriter is None:
            msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
            raise RuntimeError(msg)
        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

        # derive which visualisation method is appropriate, based on the loss function
        if type(trainer.loss) in EMBEDDING_LOSSES:
            self.have_embeddings = True
            self.make_image = make_embedding_image
        else:
            self.have_embeddings = False
            self.make_image = make_grid_image

    def log_images(self, step, x, y, prediction, name, gradients=None):
        """@private
        """
        # For 4D inputs take the first sample; otherwise take the first sample
        # and the central slice along axis 2.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
        image = normalize_im(x[selection].cpu())
        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)

        im, im_name = self.make_image(image, y, prediction, selection, gradients)
        im_name = f"{name}/{im_name}"
        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """@private
        """
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

        # the embedding visualisation function currently doesn't support gradients,
        # so we can't log them even if log_gradients is true
        log_grads = log_gradients and not self.have_embeddings

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_grads else None
            self.log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        """@private
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.log_images(step, x, y, prediction, "validation")
Logger to write training progress to tensorboard.
Arguments:
- trainer: The instantiated trainer.
- save_root: The root directory for writing checkpoints and log files.
TensorboardLogger(trainer, save_root: str, **unused_kwargs)
def __init__(self, trainer, save_root: str, **unused_kwargs):
    """Create the tensorboard writer under `<save_root>/logs/<trainer.name>`.

    If `save_root` is None, `./logs/<trainer.name>` is used instead.
    """
    super().__init__(trainer, save_root)
    self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
        os.path.join(save_root, "logs", trainer.name)

    try:
        os.makedirs(self.log_dir, exist_ok=True)
    except PermissionError:
        warnings.warn(
            f"The log dir at {self.log_dir} could not be created. "
            "The most likely reason for this is that you copied the checkpoint somewhere else, "
            "so we skip this error to enable loading the model from this checkpoint."
        )
        # NOTE: the logger is left partially initialized here (no `tb`);
        # it is only usable for loading, not for logging.
        return

    if SummaryWriter is None:
        msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
        raise RuntimeError(msg)
    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval

    # derive which visualisation method is appropriate, based on the loss function
    if type(trainer.loss) in EMBEDDING_LOSSES:
        self.have_embeddings = True
        self.make_image = make_embedding_image
    else:
        self.have_embeddings = False
        self.make_image = make_grid_image