torch_em.trainer.tensorboard_logger
import os
import warnings

import numpy as np
import torch

from elf.segmentation.embeddings import embedding_pca
from skimage.segmentation import mark_boundaries
from torchvision.utils import make_grid

from .logger_base import TorchEmLogger

# tensorboard import only works if the tensorboard package is available, so
# we wrap this in a try except
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None

from ..util import ensure_tensor
from ..loss import EMBEDDING_LOSSES


def normalize_im(im):
    """Rescale an image to the range [0, 1] as a float32 tensor.

    Args:
        im: The image data; it is converted with `ensure_tensor`.

    Returns:
        A tensor with its minimum shifted to 0 and, unless the image is
        constant, its maximum scaled to 1. A constant image yields all zeros.
    """
    im = ensure_tensor(im, dtype=torch.float32)
    # Use out-of-place arithmetic: `ensure_tensor` is defined elsewhere and may
    # hand back the caller's own tensor (or a view of it), in which case
    # in-place updates would silently modify the data that is being logged.
    im = im - im.min()
    max_val = im.max()
    # A constant image has max 0 after the shift; dividing would produce NaNs.
    if max_val > 0:
        im = im / max_val
    return im


def make_grid_image(image, y, prediction, selection, gradients=None):
    """Arrange raw image, targets and predictions (and gradients) in one grid.

    Args:
        image: The already normalized raw image, channel first. Only the first
            channel is shown if it has more than one.
        y: The batch of targets; indexed with `selection`.
        prediction: The batch of predictions; indexed with `selection`.
        selection: The index that picks the sample (and slice) to display.
        gradients: Optional batch of gradients; indexed with `selection`.

    Returns:
        The grid image and the name under which it should be logged.
    """
    target_image = normalize_im(y[selection].cpu())
    pred_image = normalize_im(prediction[selection].detach().cpu())

    if image.shape[0] > 1:
        image = image[0:1]

    n_channels = pred_image.shape[0]
    n_channels_target = target_image.shape[0]
    if n_channels_target == n_channels == 1:
        # everything is single channel: show the images side by side in one row
        nrow = 8
        images = [image, target_image, pred_image]
    elif n_channels_target == 1:
        # one column per prediction channel; raw image and target are repeated
        nrow = n_channels
        images = nrow * [image]
        images += (nrow * [target_image])
        images += [channel.unsqueeze(0) for channel in pred_image]
    else:
        # multi-channel target: each target and prediction channel is a tile
        nrow = n_channels
        images = nrow * [image]
        images += [channel.unsqueeze(0) for channel in target_image]
        images += [channel.unsqueeze(0) for channel in pred_image]

    if gradients is not None:
        grad_image = normalize_im(gradients[selection].cpu())
        if n_channels == 1:
            images.append(grad_image)
        else:
            images += [channel.unsqueeze(0) for channel in grad_image]

    im = make_grid(images, nrow=nrow, padding=4)
    name = "raw_targets_predictions"
    if gradients is not None:
        name += "_gradients"
    return im, name


def make_embedding_image(image, y, prediction, selection, gradients=None):
    """Arrange raw image, segmentation boundaries and embedding PCA in one grid.

    Args:
        image: The already normalized raw image as a tensor, channel first.
        y: The batch of label images; indexed with `selection`.
        prediction: The batch of predicted embeddings; indexed with `selection`.
        selection: The index that picks the sample (and slice) to display.
        gradients: Must be None, displaying gradients is not implemented.

    Returns:
        The grid image and the name under which it should be logged.
    """
    assert gradients is None, "Not implemented"
    image = image.numpy()

    seg = y[selection].cpu().numpy()
    seg = mark_boundaries(image[0], seg[0])  # need to get rid of singleton channel
    seg = seg.transpose((2, 0, 1))  # to channel first

    pred = prediction[selection].detach().cpu().numpy()
    pca = embedding_pca(pred)

    image = np.repeat(image, 3, axis=0)  # to rgb
    images = [torch.from_numpy(im) for im in (image, seg, pca)]
    im = make_grid(images, padding=4)
    name = "raw_segmentation_embedding"
    if gradients is not None:
        name += "_gradients"
    return im, name


class TensorboardLogger(TorchEmLogger):
    """Logger that writes scalars and images of a training run to tensorboard."""

    def __init__(self, trainer, save_root, **unused_kwargs):
        """Set up the log directory and the tensorboard writer.

        Args:
            trainer: The trainer; its `name`, `log_image_interval` and `loss`
                attributes are read here.
            save_root: Root folder for the logs. If None, the logs are written
                to "./logs/<trainer.name>".
            **unused_kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If the tensorboard package is not available.
        """
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)

        try:
            os.makedirs(self.log_dir, exist_ok=True)
        except PermissionError:
            warnings.warn(
                f"The log dir at {self.log_dir} could not be created. "
                "The most likely reason for this is that you copied the checkpoint somewhere else, "
                "so we skip this error to enable loading the model from this checkpoint."
            )
            # NOTE: the writer and the image settings below stay unset in this
            # case, so the log_* methods can't be used on such an instance.
            return

        if SummaryWriter is None:
            msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
            raise RuntimeError(msg)
        # use the name guarded by the import check at the top of the module
        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

        # derive which visualisation method is appropriate, based on the loss function
        if type(trainer.loss) in EMBEDDING_LOSSES:
            self.have_embeddings = True
            self.make_image = make_embedding_image
        else:
            self.have_embeddings = False
            self.make_image = make_grid_image

    def log_images(self, step, x, y, prediction, name, gradients=None):
        """Log the input image and the overview grid for one sample.

        Args:
            step: The global step for the tensorboard entries.
            x: The batch of inputs.
            y: The batch of targets.
            prediction: The batch of predictions.
            name: Prefix of the tensorboard tags, e.g. "train".
            gradients: Optional batch of gradients to display.
        """
        # 4d input: take the first sample; otherwise take the first sample and
        # the central index along axis 2 (presumably the central slice of a volume)
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())
        self.tb.add_image(tag=f"{name}/input",
                          img_tensor=image,
                          global_step=step)

        im, im_name = self.make_image(image, y, prediction, selection, gradients)
        im_name = f"{name}/{im_name}"
        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """Log loss and learning rate, and images every `log_image_interval` steps.

        Args:
            step: The global step for the tensorboard entries.
            loss: The training loss value.
            lr: The current learning rate.
            x: The batch of inputs.
            y: The batch of targets.
            prediction: The batch of predictions.
            log_gradients: Whether to display `prediction.grad` in the image grid.
        """
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

        # the embedding visualisation function currently doesn't support gradients,
        # so we can't log them even if log_gradients is true
        log_grads = log_gradients
        if self.have_embeddings:
            log_grads = False

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_grads else None
            self.log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        """Log validation loss, validation metric and images.

        Args:
            step: The global step for the tensorboard entries.
            metric: The validation metric value.
            loss: The validation loss value.
            x: The batch of inputs.
            y: The batch of targets.
            prediction: The batch of predictions.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.log_images(step, x, y, prediction, "validation")
def
normalize_im(im):
def
make_grid_image(image, y, prediction, selection, gradients=None):
def make_grid_image(image, y, prediction, selection, gradients=None):
    """Build one overview grid from raw image, targets and predictions.

    Args:
        image: The already normalized raw image, channel first. If it has
            several channels only the first one is displayed.
        y: The batch of targets; indexed with `selection`.
        prediction: The batch of predictions; indexed with `selection`.
        selection: The index that picks the sample (and slice) to display.
        gradients: Optional batch of gradients; indexed with `selection`
            and appended to the grid if given.

    Returns:
        The grid image and the name under which it should be logged.
    """
    def _split_channels(tensor):
        # one single-channel tile per channel of the tensor
        return [plane.unsqueeze(0) for plane in tensor]

    targets = normalize_im(y[selection].cpu())
    preds = normalize_im(prediction[selection].detach().cpu())

    raw = image[0:1] if image.shape[0] > 1 else image

    n_pred = preds.shape[0]
    n_target = targets.shape[0]
    single_pred = n_pred == 1

    if n_target == 1 and single_pred:
        # everything is single channel: all tiles fit into a single row
        per_row = 8
        tiles = [raw, targets, preds]
    else:
        # one column per prediction channel; the raw image is repeated per
        # column, a single-channel target as well
        per_row = n_pred
        if n_target == 1:
            target_tiles = per_row * [targets]
        else:
            target_tiles = _split_channels(targets)
        tiles = per_row * [raw] + target_tiles + _split_channels(preds)

    has_grads = gradients is not None
    if has_grads:
        grads = normalize_im(gradients[selection].cpu())
        tiles.extend([grads] if single_pred else _split_channels(grads))

    grid = make_grid(tiles, nrow=per_row, padding=4)
    tag = "raw_targets_predictions" + ("_gradients" if has_grads else "")
    return grid, tag
def
make_embedding_image(image, y, prediction, selection, gradients=None):
def make_embedding_image(image, y, prediction, selection, gradients=None):
    """Build one overview grid from raw image, label boundaries and embeddings.

    Args:
        image: The already normalized raw image as a tensor, channel first.
        y: The batch of label images; indexed with `selection`.
        prediction: The batch of predicted embeddings; indexed with `selection`.
        selection: The index that picks the sample (and slice) to display.
        gradients: Must be None, displaying gradients is not implemented.

    Returns:
        The grid image and the name under which it should be logged.
    """
    assert gradients is None, "Not implemented"
    raw = image.numpy()

    # draw the label boundaries onto the first raw channel
    labels = y[selection].cpu().numpy()
    overlay = mark_boundaries(raw[0], labels[0])  # index 0 drops the singleton channel
    overlay = overlay.transpose((2, 0, 1))  # channel last -> channel first

    # reduce the embeddings with a pca so that they can be shown as an image
    embeddings = prediction[selection].detach().cpu().numpy()
    embedding_view = embedding_pca(embeddings)

    raw_rgb = np.repeat(raw, 3, axis=0)  # replicate the channel axis to get rgb
    panels = list(map(torch.from_numpy, (raw_rgb, overlay, embedding_view)))
    grid = make_grid(panels, padding=4)

    tag = "raw_segmentation_embedding"
    if gradients is not None:
        tag += "_gradients"
    return grid, tag
class TensorboardLogger(TorchEmLogger):
    """Logger that writes scalars and images of a training run to tensorboard."""

    def __init__(self, trainer, save_root, **unused_kwargs):
        """Set up the log directory and the tensorboard writer.

        Args:
            trainer: The trainer; its `name`, `log_image_interval` and `loss`
                attributes are read here.
            save_root: Root folder for the logs. If None, the logs are written
                to "./logs/<trainer.name>".
            **unused_kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If the tensorboard package is not available.
        """
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)

        try:
            os.makedirs(self.log_dir, exist_ok=True)
        except PermissionError:
            warnings.warn(
                f"The log dir at {self.log_dir} could not be created. "
                "The most likely reason for this is that you copied the checkpoint somewhere else, "
                "so we skip this error to enable loading the model from this checkpoint."
            )
            # NOTE: the writer and the image settings below stay unset in this
            # case, so the log_* methods can't be used on such an instance.
            return

        if SummaryWriter is None:
            msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
            raise RuntimeError(msg)
        # use the name guarded by the import check at module level
        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

        # derive which visualisation method is appropriate, based on the loss function
        if type(trainer.loss) in EMBEDDING_LOSSES:
            self.have_embeddings = True
            self.make_image = make_embedding_image
        else:
            self.have_embeddings = False
            self.make_image = make_grid_image

    def log_images(self, step, x, y, prediction, name, gradients=None):
        """Log the input image and the overview grid for one sample.

        Args:
            step: The global step for the tensorboard entries.
            x: The batch of inputs.
            y: The batch of targets.
            prediction: The batch of predictions.
            name: Prefix of the tensorboard tags, e.g. "train".
            gradients: Optional batch of gradients to display.
        """
        # 4d input: take the first sample; otherwise take the first sample and
        # the central index along axis 2 (presumably the central slice of a volume)
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())
        self.tb.add_image(tag=f"{name}/input",
                          img_tensor=image,
                          global_step=step)

        im, im_name = self.make_image(image, y, prediction, selection, gradients)
        im_name = f"{name}/{im_name}"
        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """Log loss and learning rate, and images every `log_image_interval` steps.

        Args:
            step: The global step for the tensorboard entries.
            loss: The training loss value.
            lr: The current learning rate.
            x: The batch of inputs.
            y: The batch of targets.
            prediction: The batch of predictions.
            log_gradients: Whether to display `prediction.grad` in the image grid.
        """
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

        # the embedding visualisation function currently doesn't support gradients,
        # so we can't log them even if log_gradients is true
        log_grads = log_gradients
        if self.have_embeddings:
            log_grads = False

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_grads else None
            self.log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        """Log validation loss, validation metric and images.

        Args:
            step: The global step for the tensorboard entries.
            metric: The validation metric value.
            loss: The validation loss value.
            x: The batch of inputs.
            y: The batch of targets.
            prediction: The batch of predictions.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.log_images(step, x, y, prediction, "validation")
TensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Set up the log directory and the tensorboard writer.

    Args:
        trainer: The trainer; its `name`, `log_image_interval` and `loss`
            attributes are read here.
        save_root: Root folder for the logs. If None, the logs are written
            to "./logs/<trainer.name>".
        **unused_kwargs: Accepted and ignored.

    Raises:
        RuntimeError: If the tensorboard package is not available.
    """
    super().__init__(trainer, save_root)
    self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
        os.path.join(save_root, "logs", trainer.name)

    try:
        os.makedirs(self.log_dir, exist_ok=True)
    except PermissionError:
        warnings.warn(
            f"The log dir at {self.log_dir} could not be created. "
            "The most likely reason for this is that you copied the checkpoint somewhere else, "
            "so we skip this error to enable loading the model from this checkpoint."
        )
        # NOTE: the writer and the image settings below stay unset in this
        # case, so the log_* methods can't be used on such an instance.
        return

    if SummaryWriter is None:
        msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
        raise RuntimeError(msg)
    # use the name guarded by the import check at module level
    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval

    # derive which visualisation method is appropriate, based on the loss function
    if type(trainer.loss) in EMBEDDING_LOSSES:
        self.have_embeddings = True
        self.make_image = make_embedding_image
    else:
        self.have_embeddings = False
        self.make_image = make_grid_image
def
log_images(self, step, x, y, prediction, name, gradients=None):
def log_images(self, step, x, y, prediction, name, gradients=None):
    """Write the input image and the overview grid of one sample to tensorboard.

    Args:
        step: The global step for the tensorboard entries.
        x: The batch of inputs.
        y: The batch of targets.
        prediction: The batch of predictions.
        name: Prefix of the tensorboard tags, e.g. "train".
        gradients: Optional batch of gradients, passed on to the grid function.
    """
    if x.ndim == 4:
        # take the first sample of the batch
        selection = np.s_[0]
    else:
        # take the first sample and the central index along axis 2
        # (presumably the central slice of a volume)
        selection = np.s_[0, :, x.shape[2] // 2]

    image = normalize_im(x[selection].cpu())
    self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)

    # self.make_image is the grid function that was chosen in __init__
    grid, grid_name = self.make_image(image, y, prediction, selection, gradients)
    self.tb.add_image(tag=f"{name}/{grid_name}", img_tensor=grid, global_step=step)
def
log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
    """Write the training scalars, and periodically the images, to tensorboard.

    Args:
        step: The global step for the tensorboard entries.
        loss: The training loss value.
        lr: The current learning rate.
        x: The batch of inputs.
        y: The batch of targets.
        prediction: The batch of predictions.
        log_gradients: Whether to display `prediction.grad` in the image grid.
    """
    for tag, value in (("train/loss", loss), ("train/learning_rate", lr)):
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)

    # the embedding visualisation function currently doesn't support gradients,
    # so they are dropped in that case even if log_gradients is true
    with_gradients = False if self.have_embeddings else log_gradients

    # images are only written every log_image_interval steps
    if step % self.log_image_interval != 0:
        return

    gradients = prediction.grad if with_gradients else None
    self.log_images(step, x, y, prediction, "train", gradients=gradients)