torch_em.classification.classification_logger

  1import os
  2
  3import matplotlib.pyplot as plt
  4import numpy as np
  5
  6from matplotlib.backends.backend_agg import FigureCanvasAgg
  7from sklearn.metrics import ConfusionMatrixDisplay
  8from torch.utils.tensorboard import SummaryWriter
  9from torch_em.trainer.logger_base import TorchEmLogger
 10from torch_em.transform.raw import normalize
 11
 12
 13def confusion_matrix(y_true, y_pred, class_labels=None, title=None, save_path=None, **plot_kwargs):
 14    """@private
 15    """
 16    fig, ax = plt.subplots(1)
 17
 18    if save_path is None:
 19        canvas = FigureCanvasAgg(fig)
 20
 21    disp = ConfusionMatrixDisplay.from_predictions(
 22        y_true, y_pred, normalize="true", display_labels=class_labels
 23    )
 24    disp.plot(ax=ax, **plot_kwargs)
 25
 26    if title is not None:
 27        ax.set_title(title)
 28    if save_path is not None:
 29        plt.savefig(save_path)
 30        return
 31
 32    canvas.draw()
 33    image = np.asarray(canvas.buffer_rgba())[..., :3]
 34    image = image.transpose((2, 0, 1))
 35    plt.close()
 36    return image
 37
 38
 39def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
 40    """@private
 41    """
 42    assert images.ndim in (4, 5)
 43    assert images.shape[1] in (1, 3), f"{images.shape}"
 44
 45    if images.ndim == 5:
 46        is_3d = True
 47        z = images.shape[2] // 2
 48    else:
 49        is_3d = False
 50
 51    n_images = images.shape[0]
 52    n_rows = n_images // images_per_row
 53    if n_images % images_per_row != 0:
 54        n_rows += 1
 55
 56    images = images.detach().cpu().numpy()
 57    if target is not None:
 58        target = target.detach().cpu().numpy()
 59    if prediction is not None:
 60        prediction = prediction.max(1)[1].detach().cpu().numpy()
 61
 62    fig, axes = plt.subplots(n_rows, images_per_row)
 63    canvas = FigureCanvasAgg(fig)
 64    for r in range(n_rows):
 65        for c in range(images_per_row):
 66            i = r * images_per_row + c
 67            if i == len(images):
 68                break
 69            ax = axes[r, c] if n_rows > 1 else axes[r]
 70            ax.set_axis_off()
 71            im = images[i, :, z] if is_3d else images[i]
 72            im = im.transpose((1, 2, 0))
 73            im = normalize(im, axis=(0, 1))
 74            if im.shape[-1] == 3:  # rgb
 75                ax.imshow(im)
 76            else:
 77                ax.imshow(im[..., 0], cmap="gray")
 78
 79            if target is None and prediction is None:
 80                continue
 81
 82            # TODO get the class name, and if we have both target
 83            # and prediction check whether they agree or not and do stuff
 84            title = ""
 85            if target is not None:
 86                title += f"t: {target[i]} "
 87            if prediction is not None:
 88                title += f"p: {prediction[i]}"
 89            ax.set_title(title, fontsize=8)
 90
 91    canvas.draw()
 92    image = np.asarray(canvas.buffer_rgba())[..., :3]
 93    image = image.transpose((2, 0, 1))
 94    plt.close()
 95    return image
 96
 97
 98class ClassificationLogger(TorchEmLogger):
 99    """Logger for classification trainer.
100
101    Args:
102        trainer: The trainer instance.
103        save_root: Root folder for saving the checkpoints and logs.
104    """
105    def __init__(self, trainer, save_root: str, **unused_kwargs):
106        super().__init__(trainer, save_root)
107        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
108            os.path.join(save_root, "logs", trainer.name)
109        os.makedirs(self.log_dir, exist_ok=True)
110
111        self.tb = SummaryWriter(self.log_dir)
112        self.log_image_interval = trainer.log_image_interval
113
114    def add_image(self, x, y, pred, name, step):
115        """@private
116        """
117        scale_each = False
118        grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=scale_each)
119        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)
120
121    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
122        """@private
123        """
124        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
125        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
126        if step % self.log_image_interval == 0:
127            self.add_image(x, y, prediction, "train", step)
128
129    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
130        """@private
131        """
132        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
133        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
134        self.add_image(x, y, prediction, "validation", step)
135        if y_true is not None and y_pred is not None:
136            cm = confusion_matrix(y_true, y_pred)
137            self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm, global_step=step)
class ClassificationLogger(torch_em.trainer.logger_base.TorchEmLogger):
 99class ClassificationLogger(TorchEmLogger):
100    """Logger for classification trainer.
101
102    Args:
103        trainer: The trainer instance.
104        save_root: Root folder for saving the checkpoints and logs.
105    """
106    def __init__(self, trainer, save_root: str, **unused_kwargs):
107        super().__init__(trainer, save_root)
108        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
109            os.path.join(save_root, "logs", trainer.name)
110        os.makedirs(self.log_dir, exist_ok=True)
111
112        self.tb = SummaryWriter(self.log_dir)
113        self.log_image_interval = trainer.log_image_interval
114
115    def add_image(self, x, y, pred, name, step):
116        """@private
117        """
118        scale_each = False
119        grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=scale_each)
120        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)
121
122    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
123        """@private
124        """
125        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
126        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
127        if step % self.log_image_interval == 0:
128            self.add_image(x, y, prediction, "train", step)
129
130    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
131        """@private
132        """
133        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
134        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
135        self.add_image(x, y, prediction, "validation", step)
136        if y_true is not None and y_pred is not None:
137            cm = confusion_matrix(y_true, y_pred)
138            self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm, global_step=step)

Logger for classification trainer.

Arguments:
  • trainer: The trainer instance.
  • save_root: Root folder for saving the checkpoints and logs. If None, the logs are written to ./logs/<trainer name>.
ClassificationLogger(trainer, save_root: str, **unused_kwargs)
def __init__(self, trainer, save_root: str, **unused_kwargs):
    """Set up the tensorboard writer and the log directory for `trainer`.

    Logs go to `<save_root>/logs/<trainer.name>`, or `./logs/<trainer.name>` if
    `save_root` is None. Extra keyword arguments are accepted and ignored.
    """
    super().__init__(trainer, save_root)
    if save_root is None:
        log_dir = f"./logs/{trainer.name}"
    else:
        log_dir = os.path.join(save_root, "logs", trainer.name)
    self.log_dir = log_dir
    os.makedirs(log_dir, exist_ok=True)

    self.tb = SummaryWriter(log_dir)
    # Training images are only written every `log_image_interval` steps.
    self.log_image_interval = trainer.log_image_interval
log_dir
tb
log_image_interval