torch_em.trainer.tensorboard_logger

  1import os
  2import warnings
  3
  4import numpy as np
  5import torch
  6
  7from elf.segmentation.embeddings import embedding_pca
  8from skimage.segmentation import mark_boundaries
  9from torchvision.utils import make_grid
 10
 11from .logger_base import TorchEmLogger
 12
 13# tensorboard import only works if the tensorboard package is available, so
 14# we wrap this in a try except
 15try:
 16    from torch.utils.tensorboard import SummaryWriter
 17except ImportError:
 18    SummaryWriter = None
 19
 20from ..util import ensure_tensor
 21from ..loss import EMBEDDING_LOSSES
 22
 23
 24def normalize_im(im):
 25    """@private
 26    """
 27    im = ensure_tensor(im, dtype=torch.float32)
 28    im -= im.min()
 29    im /= im.max()
 30    return im
 31
 32
 33def make_grid_image(image, y, prediction, selection, gradients=None):
 34    """@private
 35    """
 36    target_image = normalize_im(y[selection].cpu())
 37    pred_image = normalize_im(prediction[selection].detach().cpu())
 38
 39    if image.shape[0] > 1:
 40        image = image[0:1]
 41
 42    n_channels = pred_image.shape[0]
 43    n_channels_target = target_image.shape[0]
 44    if n_channels_target == n_channels == 1:
 45        nrow = 8
 46        images = [image, target_image, pred_image]
 47    elif n_channels_target == 1:
 48        nrow = n_channels
 49        images = nrow * [image]
 50        images += (nrow * [target_image])
 51        images += [channel.unsqueeze(0) for channel in pred_image]
 52    else:
 53        nrow = n_channels
 54        images = nrow * [image]
 55        images += [channel.unsqueeze(0) for channel in target_image]
 56        images += [channel.unsqueeze(0) for channel in pred_image]
 57
 58    if gradients is not None:
 59        grad_image = normalize_im(gradients[selection].cpu())
 60        if n_channels == 1:
 61            images.append(grad_image)
 62        else:
 63            images += [channel.unsqueeze(0) for channel in grad_image]
 64
 65    im = make_grid(images, nrow=nrow, padding=4)
 66    name = "raw_targets_predictions"
 67    if gradients is not None:
 68        name += "_gradients"
 69    return im, name
 70
 71
 72def make_embedding_image(image, y, prediction, selection, gradients=None):
 73    """@private
 74    """
 75    assert gradients is None, "Not implemented"
 76    image = image.numpy()
 77
 78    seg = y[selection].cpu().numpy()
 79    seg = mark_boundaries(image[0], seg[0])  # need to get rid of singleton channel
 80    seg = seg.transpose((2, 0, 1))  # to channel first
 81
 82    pred = prediction[selection].detach().cpu().numpy()
 83    pca = embedding_pca(pred)
 84
 85    image = np.repeat(image, 3, axis=0)  # to rgb
 86    images = [torch.from_numpy(im) for im in (image, seg, pca)]
 87    im = make_grid(images, padding=4)
 88    name = "raw_segmentation_embedding"
 89    if gradients is not None:
 90        name += "_gradients"
 91    return im, name
 92
 93
 94class TensorboardLogger(TorchEmLogger):
 95    """Logger to write training progress to tensorboard.
 96
 97    Args:
 98        trainer: The instantiated trainer.
 99        save_root: The root directury for writing checkpoints and log files.
100    """
101    def __init__(self, trainer, save_root: str, **unused_kwargs):
102        super().__init__(trainer, save_root)
103        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
104            os.path.join(save_root, "logs", trainer.name)
105
106        try:
107            os.makedirs(self.log_dir, exist_ok=True)
108        except PermissionError:
109            warnings.warn(
110                f"The log dir at {self.log_dir} could not be created."
111                "The most likely reason for this is that you copied the checkpoint somewhere else,"
112                "so we skip this error to enable loading the model from this checkpoint."
113            )
114            return
115
116        if SummaryWriter is None:
117            msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
118            raise RuntimeError(msg)
119        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
120        self.log_image_interval = trainer.log_image_interval
121
122        # derive which visualisation method is appropriate, based on the loss function
123        if type(trainer.loss) in EMBEDDING_LOSSES:
124            self.have_embeddings = True
125            self.make_image = make_embedding_image
126        else:
127            self.have_embeddings = False
128            self.make_image = make_grid_image
129
130    def log_images(self, step, x, y, prediction, name, gradients=None):
131        """@private
132        """
133        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
134        image = normalize_im(x[selection].cpu())
135        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
136
137        im, im_name = self.make_image(image, y, prediction, selection, gradients)
138        im_name = f"{name}/{im_name}"
139        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)
140
141    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
142        """@private
143        """
144        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
145        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
146
147        # the embedding visualisation function currently doesn't support gradients,
148        # so we can't log them even if log_gradients is true
149        log_grads = log_gradients
150        if self.have_embeddings:
151            log_grads = False
152
153        if step % self.log_image_interval == 0:
154            gradients = prediction.grad if log_grads else None
155            self.log_images(step, x, y, prediction, "train", gradients=gradients)
156
157    def log_validation(self, step, metric, loss, x, y, prediction):
158        """@private
159        """
160        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
161        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
162        self.log_images(step, x, y, prediction, "validation")
class TensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
 95class TensorboardLogger(TorchEmLogger):
 96    """Logger to write training progress to tensorboard.
 97
 98    Args:
 99        trainer: The instantiated trainer.
100        save_root: The root directury for writing checkpoints and log files.
101    """
102    def __init__(self, trainer, save_root: str, **unused_kwargs):
103        super().__init__(trainer, save_root)
104        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
105            os.path.join(save_root, "logs", trainer.name)
106
107        try:
108            os.makedirs(self.log_dir, exist_ok=True)
109        except PermissionError:
110            warnings.warn(
111                f"The log dir at {self.log_dir} could not be created."
112                "The most likely reason for this is that you copied the checkpoint somewhere else,"
113                "so we skip this error to enable loading the model from this checkpoint."
114            )
115            return
116
117        if SummaryWriter is None:
118            msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
119            raise RuntimeError(msg)
120        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
121        self.log_image_interval = trainer.log_image_interval
122
123        # derive which visualisation method is appropriate, based on the loss function
124        if type(trainer.loss) in EMBEDDING_LOSSES:
125            self.have_embeddings = True
126            self.make_image = make_embedding_image
127        else:
128            self.have_embeddings = False
129            self.make_image = make_grid_image
130
131    def log_images(self, step, x, y, prediction, name, gradients=None):
132        """@private
133        """
134        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
135        image = normalize_im(x[selection].cpu())
136        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
137
138        im, im_name = self.make_image(image, y, prediction, selection, gradients)
139        im_name = f"{name}/{im_name}"
140        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)
141
142    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
143        """@private
144        """
145        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
146        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
147
148        # the embedding visualisation function currently doesn't support gradients,
149        # so we can't log them even if log_gradients is true
150        log_grads = log_gradients
151        if self.have_embeddings:
152            log_grads = False
153
154        if step % self.log_image_interval == 0:
155            gradients = prediction.grad if log_grads else None
156            self.log_images(step, x, y, prediction, "train", gradients=gradients)
157
158    def log_validation(self, step, metric, loss, x, y, prediction):
159        """@private
160        """
161        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
162        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
163        self.log_images(step, x, y, prediction, "validation")

Logger to write training progress to tensorboard.

Arguments:
  • trainer: The instantiated trainer.
  • save_root: The root directory for writing checkpoints and log files.
TensorboardLogger(trainer, save_root: str, **unused_kwargs)
102    def __init__(self, trainer, save_root: str, **unused_kwargs):
103        super().__init__(trainer, save_root)
104        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
105            os.path.join(save_root, "logs", trainer.name)
106
107        try:
108            os.makedirs(self.log_dir, exist_ok=True)
109        except PermissionError:
110            warnings.warn(
111                f"The log dir at {self.log_dir} could not be created."
112                "The most likely reason for this is that you copied the checkpoint somewhere else,"
113                "so we skip this error to enable loading the model from this checkpoint."
114            )
115            return
116
117        if SummaryWriter is None:
118            msg = "Need tensorboard package to use logger. Install it via 'conda install -c conda-forge tensorboard'"
119            raise RuntimeError(msg)
120        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
121        self.log_image_interval = trainer.log_image_interval
122
123        # derive which visualisation method is appropriate, based on the loss function
124        if type(trainer.loss) in EMBEDDING_LOSSES:
125            self.have_embeddings = True
126            self.make_image = make_embedding_image
127        else:
128            self.have_embeddings = False
129            self.make_image = make_grid_image
log_dir
tb
log_image_interval