torch_em.classification.classification_logger

  1import os
  2
  3import matplotlib.pyplot as plt
  4import numpy as np
  5import torch
  6
  7from matplotlib.backends.backend_agg import FigureCanvasAgg
  8from sklearn.metrics import ConfusionMatrixDisplay
  9from torch_em.trainer.logger_base import TorchEmLogger
 10from torch_em.transform.raw import normalize
 11
 12
 13def confusion_matrix(y_true, y_pred, class_labels=None, title=None, save_path=None, **plot_kwargs):
 14    fig, ax = plt.subplots(1)
 15
 16    if save_path is None:
 17        canvas = FigureCanvasAgg(fig)
 18
 19    disp = ConfusionMatrixDisplay.from_predictions(
 20        y_true, y_pred, normalize="true", display_labels=class_labels
 21    )
 22    disp.plot(ax=ax, **plot_kwargs)
 23
 24    if title is not None:
 25        ax.set_title(title)
 26    if save_path is not None:
 27        plt.savefig(save_path)
 28        return
 29
 30    canvas.draw()
 31    image = np.asarray(canvas.buffer_rgba())[..., :3]
 32    image = image.transpose((2, 0, 1))
 33    plt.close()
 34    return image
 35
 36
 37# TODO get the class names
 38def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
 39    assert images.ndim in (4, 5)
 40    assert images.shape[1] in (1, 3), f"{images.shape}"
 41
 42    if images.ndim == 5:
 43        is_3d = True
 44        z = images.shape[2] // 2
 45    else:
 46        is_3d = False
 47
 48    n_images = images.shape[0]
 49    n_rows = n_images // images_per_row
 50    if n_images % images_per_row != 0:
 51        n_rows += 1
 52
 53    images = images.detach().cpu().numpy()
 54    if target is not None:
 55        target = target.detach().cpu().numpy()
 56    if prediction is not None:
 57        prediction = prediction.max(1)[1].detach().cpu().numpy()
 58
 59    fig, axes = plt.subplots(n_rows, images_per_row)
 60    canvas = FigureCanvasAgg(fig)
 61    for r in range(n_rows):
 62        for c in range(images_per_row):
 63            i = r * images_per_row + c
 64            if i == len(images):
 65                break
 66            ax = axes[r, c] if n_rows > 1 else axes[r]
 67            ax.set_axis_off()
 68            im = images[i, :, z] if is_3d else images[i]
 69            im = im.transpose((1, 2, 0))
 70            im = normalize(im, axis=(0, 1))
 71            if im.shape[-1] == 3:  # rgb
 72                ax.imshow(im)
 73            else:
 74                ax.imshow(im[..., 0], cmap="gray")
 75
 76            if target is None and prediction is None:
 77                continue
 78
 79            # TODO get the class name, and if we have both target
 80            # and prediction check whether they agree or not and do stuff
 81            title = ""
 82            if target is not None:
 83                title += f"t: {target[i]} "
 84            if prediction is not None:
 85                title += f"p: {prediction[i]}"
 86            ax.set_title(title, fontsize=8)
 87
 88    canvas.draw()
 89    image = np.asarray(canvas.buffer_rgba())[..., :3]
 90    image = image.transpose((2, 0, 1))
 91    plt.close()
 92    return image
 93
 94
 95class ClassificationLogger(TorchEmLogger):
 96    def __init__(self, trainer, save_root, **unused_kwargs):
 97        super().__init__(trainer, save_root)
 98        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
 99            os.path.join(save_root, "logs", trainer.name)
100        os.makedirs(self.log_dir, exist_ok=True)
101
102        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
103        self.log_image_interval = trainer.log_image_interval
104
105    def add_image(self, x, y, pred, name, step):
106        scale_each = False
107        grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=scale_each)
108        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)
109
110    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
111        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
112        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
113        if step % self.log_image_interval == 0:
114            self.add_image(x, y, prediction, "train", step)
115
116    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
117        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
118        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
119        self.add_image(x, y, prediction, "validation", step)
120        if y_true is not None and y_pred is not None:
121            cm = confusion_matrix(y_true, y_pred)
122            self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm, global_step=step)
def confusion_matrix(y_true, y_pred, class_labels=None, title=None, save_path=None, **plot_kwargs):
    """Plot a row-normalized confusion matrix and either save it or return it as an image.

    Args:
        y_true: The ground-truth class labels.
        y_pred: The predicted class labels.
        class_labels: Optional display names for the classes (passed as `display_labels`).
        title: Optional title for the plot.
        save_path: If given, the figure is written to this path and nothing is returned.
        plot_kwargs: Extra keyword arguments forwarded to sklearn for plotting (e.g. `cmap`, `values_format`).

    Returns:
        None if `save_path` is given, otherwise the rendered plot as a uint8 array of shape (3, height, width).
    """
    fig, ax = plt.subplots(1)
    try:
        # Pass `ax` so that sklearn draws into our figure instead of opening (and leaking) a second one.
        ConfusionMatrixDisplay.from_predictions(
            y_true, y_pred, normalize="true", display_labels=class_labels, ax=ax, **plot_kwargs
        )
        if title is not None:
            ax.set_title(title)

        if save_path is not None:
            # Save this figure explicitly; `plt.savefig` would save whichever figure is current.
            fig.savefig(save_path)
            return None

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # Drop the alpha channel and copy out of the renderer buffer (channels-first for tensorboard).
        image = np.asarray(canvas.buffer_rgba())[..., :3]
        return np.ascontiguousarray(image.transpose((2, 0, 1)))
    finally:
        # Always release the figure, also when saving or when plotting raises.
        plt.close(fig)
def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
    """Render a batch of images, optionally annotated with labels, as one RGB grid image.

    Args:
        images: Image batch of shape (batch, channels, height, width) or
            (batch, channels, depth, height, width) with 1 or 3 channels.
            For 5D input only the central slice along axis 2 is shown.
        target: Optional target class labels, one per image.
        prediction: Optional per-class predictions; the argmax over axis 1 is shown.
        images_per_row: Number of panels per grid row.
        kwargs: Ignored; accepted so that callers may pass extra options
            (e.g. `padding`, `normalize`, `scale_each`).

    Returns:
        The rendered grid as a uint8 array of shape (3, height, width).
    """
    assert images.ndim in (4, 5)
    assert images.shape[1] in (1, 3), f"{images.shape}"

    is_3d = images.ndim == 5
    z = images.shape[2] // 2 if is_3d else None

    n_images = images.shape[0]
    # Ceil division; at least one row so that plt.subplots also works for an empty batch.
    n_rows = max(1, (n_images + images_per_row - 1) // images_per_row)

    images = images.detach().cpu().numpy()
    if target is not None:
        target = target.detach().cpu().numpy()
    if prediction is not None:
        prediction = prediction.max(1)[1].detach().cpu().numpy()

    # squeeze=False always yields a 2D array of axes, so axes[row, col] is valid for a single
    # row or a single column as well.
    fig, axes = plt.subplots(n_rows, images_per_row, squeeze=False)
    try:
        canvas = FigureCanvasAgg(fig)
        # Hide every panel up front so unused panels in the last row stay blank.
        for ax in axes.flat:
            ax.set_axis_off()

        for i in range(n_images):
            ax = axes[i // images_per_row, i % images_per_row]
            im = images[i, :, z] if is_3d else images[i]
            im = normalize(im.transpose((1, 2, 0)), axis=(0, 1))
            if im.shape[-1] == 3:  # rgb
                ax.imshow(im)
            else:
                ax.imshow(im[..., 0], cmap="gray")

            if target is None and prediction is None:
                continue

            # TODO get the class names, and if we have both target
            # and prediction check whether they agree or not and do stuff
            title = ""
            if target is not None:
                title += f"t: {target[i]} "
            if prediction is not None:
                title += f"p: {prediction[i]}"
            ax.set_title(title, fontsize=8)

        canvas.draw()
        # Drop the alpha channel and copy out of the renderer buffer (channels-first for tensorboard).
        image = np.asarray(canvas.buffer_rgba())[..., :3]
        return np.ascontiguousarray(image.transpose((2, 0, 1)))
    finally:
        plt.close(fig)
class ClassificationLogger(TorchEmLogger):
    """Tensorboard logger for classification training.

    Logs the training loss and learning rate, the validation loss and metric, image grids
    annotated with targets and predictions, and optionally a validation confusion matrix.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        """
        Args:
            trainer: The trainer; its `name` and `log_image_interval` attributes are read.
            save_root: Root directory for the logs; `./logs/<trainer.name>` is used if None.
            unused_kwargs: Ignored.
        """
        super().__init__(trainer, save_root)
        if save_root is None:
            self.log_dir = f"./logs/{trainer.name}"
        else:
            self.log_dir = os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        # Import the submodule explicitly: `import torch` alone does not load `torch.utils.tensorboard`.
        from torch.utils.tensorboard import SummaryWriter

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, pred, name, step):
        """Log a grid of the inputs `x` annotated with targets `y` and predictions `pred`."""
        grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=False)
        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """Log training loss and learning rate; log an image grid every `log_image_interval` steps.

        `log_gradients` is accepted for interface compatibility but not used.
        """
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, prediction, "train", step)

    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
        """Log validation loss, metric and an image grid.

        If both `y_true` and `y_pred` are given, a confusion matrix image is logged as well.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, prediction, "validation", step)
        if y_true is not None and y_pred is not None:
            cm = confusion_matrix(y_true, y_pred)
            self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm, global_step=step)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Set up the log directory and the tensorboard writer.

    Args:
        trainer: The trainer; its `name` and `log_image_interval` attributes are read.
        save_root: Root directory for the logs; `./logs/<trainer.name>` is used if None.
        unused_kwargs: Ignored.
    """
    super().__init__(trainer, save_root)
    if save_root is None:
        self.log_dir = f"./logs/{trainer.name}"
    else:
        self.log_dir = os.path.join(save_root, "logs", trainer.name)
    os.makedirs(self.log_dir, exist_ok=True)

    # Import the submodule explicitly: `import torch` alone does not load `torch.utils.tensorboard`.
    from torch.utils.tensorboard import SummaryWriter

    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval
log_dir
tb
log_image_interval
def add_image(self, x, y, pred, name, step):
    """Render `x` with targets `y` and predictions `pred` as a grid and write it to tensorboard.

    The image is stored under the tag `<name>/images_and_predictions` at `step`.
    """
    grid_options = dict(padding=4, normalize=True, scale_each=False)
    rendered = make_grid(x, y, pred, **grid_options)
    image_tag = f"{name}/images_and_predictions"
    self.tb.add_image(tag=image_tag, img_tensor=rendered, global_step=step)
def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
    """Write the training loss and learning rate; every `log_image_interval` steps also log an image grid.

    `log_gradients` is accepted for interface compatibility but not used.
    """
    scalar_entries = (("train/loss", loss), ("train/learning_rate", lr))
    for tag, value in scalar_entries:
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)
    if not step % self.log_image_interval:
        self.add_image(x, y, prediction, "train", step)
def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
    """Write validation loss, metric and an image grid to tensorboard.

    When both `y_true` and `y_pred` are provided, a confusion matrix image is written too.
    """
    for tag, value in (("validation/loss", loss), ("validation/metric", metric)):
        self.tb.add_scalar(tag=tag, scalar_value=value, global_step=step)
    self.add_image(x, y, prediction, "validation", step)

    if y_true is None or y_pred is None:
        return
    matrix_image = confusion_matrix(y_true, y_pred)
    self.tb.add_image(tag="validation/confusion_matrix", img_tensor=matrix_image, global_step=step)