torch_em.trainer.tensorboard_logger

import os
import warnings

import numpy as np
import torch

from elf.segmentation.embeddings import embedding_pca
from skimage.segmentation import mark_boundaries
from torchvision.utils import make_grid

from .logger_base import TorchEmLogger

# The tensorboard import only works if the tensorboard package is available,
# so we wrap it in a try-except.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None

from ..util import ensure_tensor
from ..loss import EMBEDDING_LOSSES


def normalize_im(im):
    im = ensure_tensor(im, dtype=torch.float32)
    im -= im.min()
    im /= im.max()
    return im


def make_grid_image(image, y, prediction, selection, gradients=None):
    target_image = normalize_im(y[selection].cpu())
    pred_image = normalize_im(prediction[selection].detach().cpu())

    if image.shape[0] > 1:
        image = image[0:1]

    n_channels = pred_image.shape[0]
    n_channels_target = target_image.shape[0]
    if n_channels_target == n_channels == 1:
        nrow = 8
        images = [image, target_image, pred_image]
    elif n_channels_target == 1:
        nrow = n_channels
        images = nrow * [image]
        images += (nrow * [target_image])
        images += [channel.unsqueeze(0) for channel in pred_image]
    else:
        nrow = n_channels
        images = nrow * [image]
        images += [channel.unsqueeze(0) for channel in target_image]
        images += [channel.unsqueeze(0) for channel in pred_image]

    if gradients is not None:
        grad_image = normalize_im(gradients[selection].cpu())
        if n_channels == 1:
            images.append(grad_image)
        else:
            images += [channel.unsqueeze(0) for channel in grad_image]

    im = make_grid(images, nrow=nrow, padding=4)
    name = "raw_targets_predictions"
    if gradients is not None:
        name += "_gradients"
    return im, name


def make_embedding_image(image, y, prediction, selection, gradients=None):
    assert gradients is None, "Not implemented"
    image = image.numpy()

    seg = y[selection].cpu().numpy()
    seg = mark_boundaries(image[0], seg[0])  # need to get rid of the singleton channel
    seg = seg.transpose((2, 0, 1))  # to channel first

    pred = prediction[selection].detach().cpu().numpy()
    pca = embedding_pca(pred)

    image = np.repeat(image, 3, axis=0)  # to rgb
    images = [torch.from_numpy(im) for im in (image, seg, pca)]
    im = make_grid(images, padding=4)
    name = "raw_segmentation_embedding"
    if gradients is not None:
        name += "_gradients"
    return im, name


class TensorboardLogger(TorchEmLogger):
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)

        try:
            os.makedirs(self.log_dir, exist_ok=True)
        except PermissionError:
            warnings.warn(
                f"The log dir at {self.log_dir} could not be created. "
                "The most likely reason for this is that you copied the checkpoint somewhere else, "
                "so we skip this error to enable loading the model from this checkpoint."
            )
            return

        if SummaryWriter is None:
            msg = "Need the tensorboard package to use this logger. Install it via 'conda install -c conda-forge tensorboard'"
            raise RuntimeError(msg)
        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

        # Derive which visualisation method is appropriate, based on the loss function.
        if type(trainer.loss) in EMBEDDING_LOSSES:
            self.have_embeddings = True
            self.make_image = make_embedding_image
        else:
            self.have_embeddings = False
            self.make_image = make_grid_image

    def log_images(self, step, x, y, prediction, name, gradients=None):
        # For 2d data log the first sample of the batch, for 3d data its central slice.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())
        self.tb.add_image(tag=f"{name}/input",
                          img_tensor=image,
                          global_step=step)

        im, im_name = self.make_image(image, y, prediction, selection, gradients)
        im_name = f"{name}/{im_name}"
        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

        # The embedding visualisation function currently doesn't support gradients,
        # so we can't log them even if log_gradients is True.
        log_grads = log_gradients
        if self.have_embeddings:
            log_grads = False

        if step % self.log_image_interval == 0:
            gradients = prediction.grad if log_grads else None
            self.log_images(step, x, y, prediction, "train", gradients=gradients)

    def log_validation(self, step, metric, loss, x, y, prediction):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.log_images(step, x, y, prediction, "validation")
def normalize_im(im):
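Casts the input to a float32 tensor and min-max normalizes it to the range [0, 1]: the minimum is subtracted and the result is divided by its (new) maximum. A minimal sketch of the expected behaviour, with purely illustrative values:

    import torch
    from torch_em.trainer.tensorboard_logger import normalize_im

    im = 255 * torch.rand(1, 64, 64)  # hypothetical single-channel image
    im_norm = normalize_im(im)
    print(im_norm.min().item(), im_norm.max().item())  # 0.0 and 1.0
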
def make_grid_image(image, y, prediction, selection, gradients=None):
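Builds a single overview image with torchvision.utils.make_grid: raw image, target and prediction are arranged in a grid, where each prediction (and multi-channel target) channel gets its own column and the raw image is repeated to fill its row. If gradients are given, they are appended to the grid and "_gradients" is appended to the returned tag name. A sketch with random tensors, just to illustrate the call signature and the selection that log_images uses for 2d data (all values are made up):

    import numpy as np
    import torch
    from torch_em.trainer.tensorboard_logger import make_grid_image, normalize_im

    x = torch.rand(1, 1, 64, 64)     # batch with one single-channel raw image
    y = torch.rand(1, 1, 64, 64)     # single-channel target
    pred = torch.rand(1, 2, 64, 64)  # two-channel prediction
    selection = np.s_[0]

    image = normalize_im(x[selection])
    grid, tag = make_grid_image(image, y, pred, selection)
    print(tag, tuple(grid.shape))    # "raw_targets_predictions" and a (3, H, W) grid
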
def make_embedding_image(image, y, prediction, selection, gradients=None):
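Visualisation used for embedding losses: the ground-truth segmentation is drawn onto the raw image as boundaries (skimage.segmentation.mark_boundaries), the predicted embedding channels are projected to an RGB image via a PCA (elf.segmentation.embeddings.embedding_pca), and raw image, boundary overlay and PCA projection are placed side by side in one grid. Logging gradients is not implemented for this visualisation.
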
class TensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
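Logger that writes the training progress of a torch_em trainer to tensorboard. Scalars (training loss and learning rate, validation loss and metric) are logged on every call, and image summaries of the raw data, targets and predictions are logged every log_image_interval training steps. The image visualisation is chosen based on the trainer's loss: for losses registered in torch_em.loss.EMBEDDING_LOSSES the embedding visualisation (make_embedding_image) is used, otherwise the channel grid (make_grid_image).
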
TensorboardLogger(trainer, save_root, **unused_kwargs)
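Creates the log directory (./logs/<trainer.name>, or <save_root>/logs/<trainer.name> if a save_root is given) and opens a SummaryWriter in it. In practice the class itself, not an instance, is handed to the trainer, which then constructs the logger with itself and its save_root. A sketch of a typical setup, assuming that model, train_loader and val_loader are defined elsewhere and that the keyword names of torch_em's default_segmentation_trainer are as in current torch_em versions:

    import torch_em
    from torch_em.trainer.tensorboard_logger import TensorboardLogger

    trainer = torch_em.default_segmentation_trainer(
        name="my-model",
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        logger=TensorboardLogger,   # this logger is the default
        log_image_interval=100,     # log image summaries every 100 steps
    )
    trainer.fit(iterations=10000)

After training, the summaries can be inspected with 'tensorboard --logdir ./logs'.
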
log_dir: The directory the tensorboard event files are written to, i.e. ./logs/<trainer.name> or <save_root>/logs/<trainer.name>.

tb: The torch.utils.tensorboard.SummaryWriter used for all logging calls.

log_image_interval: The number of training steps between image logs, taken over from the trainer.
def log_images(self, step, x, y, prediction, name, gradients=None):
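Logs the normalized input image under "<name>/input" and the combined visualisation returned by make_grid_image / make_embedding_image under "<name>/<im_name>". For 2d input (a 4d NCHW batch) the first sample of the batch is shown; for 3d input the central slice of the first sample is used.
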
def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
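Logs the training loss and learning rate at every step. Every log_image_interval steps the image visualisation is logged as well; if log_gradients is True and a non-embedding loss is used, the gradients stored on the prediction tensor are included in the grid.
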
def log_validation(self, step, metric, loss, x, y, prediction):
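Logs the validation loss and metric, and one image visualisation for the current validation batch.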