torch_em.classification.classification_logger
"""Tensorboard logging utilities for classification training."""
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib.backends.backend_agg import FigureCanvasAgg
from sklearn.metrics import ConfusionMatrixDisplay
from torch_em.trainer.logger_base import TorchEmLogger
from torch_em.transform.raw import normalize


def confusion_matrix(y_true, y_pred, class_labels=None, title=None, save_path=None, **plot_kwargs):
    """Plot a row-normalized confusion matrix and either save it or return it as an image.

    Args:
        y_true: The ground-truth labels, forwarded to `ConfusionMatrixDisplay.from_predictions`.
        y_pred: The predicted labels, forwarded to `ConfusionMatrixDisplay.from_predictions`.
        class_labels: Optional display labels for the matrix axes.
        title: Optional title of the plot.
        save_path: If given, the figure is written to this path and nothing is returned.
        plot_kwargs: Additional keyword arguments for plotting the confusion matrix.

    Returns:
        None if `save_path` is given. Otherwise the rendered figure as an RGB array
        with the channel axis first.
    """
    fig, ax = plt.subplots(1)
    try:
        # Plot directly into our axis. Without `ax` sklearn opens a second figure, which
        # would then be the one that gets saved / closed instead of `fig`.
        ConfusionMatrixDisplay.from_predictions(
            y_true, y_pred, normalize="true", display_labels=class_labels, ax=ax, **plot_kwargs
        )
        if title is not None:
            ax.set_title(title)

        if save_path is not None:
            fig.savefig(save_path)
            return None

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        image = np.asarray(canvas.buffer_rgba())[..., :3]  # drop the alpha channel
        return image.transpose((2, 0, 1))
    finally:
        # Always release the figure, also on the save path and on errors.
        plt.close(fig)


# TODO get the class names
def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
    """Render a batch of images, annotated with targets and predictions, into one image.

    Args:
        images: Tensor with 4 or 5 dimensions and 1 or 3 channels on axis 1.
            For 5 dimensional input the central slice along axis 2 is shown.
        target: Optional tensor of targets, indexed per image for the title.
        prediction: Optional tensor of predictions; the index of the maximum along axis 1
            is shown in the title.
        images_per_row: The number of columns of the grid.
        kwargs: Accepted for call compatibility and ignored.

    Returns:
        The rendered grid as an RGB array with the channel axis first.
    """
    assert images.ndim in (4, 5)
    assert images.shape[1] in (1, 3), f"{images.shape}"

    z = images.shape[2] // 2 if images.ndim == 5 else None

    n_images = images.shape[0]
    # Round up, and keep at least one row so that an empty batch gives an empty grid.
    n_rows = max(1, -(-n_images // images_per_row))

    images = images.detach().cpu().numpy()
    if target is not None:
        target = target.detach().cpu().numpy()
    if prediction is not None:
        prediction = prediction.max(1)[1].detach().cpu().numpy()

    # squeeze=False keeps `axes` two dimensional for a single row or a single column.
    fig, axes = plt.subplots(n_rows, images_per_row, squeeze=False)
    try:
        # `axes.flat` runs row by row, so the flat index is the image index.
        for i, ax in enumerate(axes.flat):
            ax.set_axis_off()
            if i >= n_images:  # unused cell in the last row
                continue

            im = images[i] if z is None else images[i, :, z]
            im = im.transpose((1, 2, 0))
            im = normalize(im, axis=(0, 1))
            if im.shape[-1] == 3:  # rgb
                ax.imshow(im)
            else:
                ax.imshow(im[..., 0], cmap="gray")

            if target is None and prediction is None:
                continue

            # TODO get the class name, and if we have both target
            # and prediction check whether they agree or not and do stuff
            title = ""
            if target is not None:
                title += f"t: {target[i]} "
            if prediction is not None:
                title += f"p: {prediction[i]}"
            ax.set_title(title, fontsize=8)

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        image = np.asarray(canvas.buffer_rgba())[..., :3]  # drop the alpha channel
        return image.transpose((2, 0, 1))
    finally:
        plt.close(fig)


class ClassificationLogger(TorchEmLogger):
    """Logger that writes classification training and validation results to tensorboard.

    Scalars, image grids created by `make_grid` and the confusion matrix are written
    with a `SummaryWriter` to `<save_root>/logs/<trainer.name>`, or to
    `./logs/<trainer.name>` if `save_root` is None.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        # `import torch` does not load the tensorboard submodule, so import it explicitly
        # instead of relying on some other module having done so.
        from torch.utils.tensorboard import SummaryWriter

        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, pred, name, step):
        """Write the image grid for inputs `x`, targets `y` and predictions `pred` under `name`."""
        scale_each = False
        # The extra keyword arguments are swallowed by `make_grid` and have no effect.
        grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=scale_each)
        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """Log the training loss and learning rate, plus an image grid every `log_image_interval` steps.

        The argument `log_gradients` is accepted but not used.
        """
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, prediction, "train", step)

    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
        """Log the validation loss, metric and image grid; and the confusion matrix if labels are given.

        The confusion matrix is only logged if both `y_true` and `y_pred` are passed.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, prediction, "validation", step)
        if y_true is not None and y_pred is not None:
            cm = confusion_matrix(y_true, y_pred)
            self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm, global_step=step)
def
confusion_matrix( y_true, y_pred, class_labels=None, title=None, save_path=None, **plot_kwargs):
def confusion_matrix(y_true, y_pred, class_labels=None, title=None, save_path=None, **plot_kwargs):
    """Plot a row-normalized confusion matrix and either save it or return it as an image.

    Args:
        y_true: The ground-truth labels, forwarded to `ConfusionMatrixDisplay.from_predictions`.
        y_pred: The predicted labels, forwarded to `ConfusionMatrixDisplay.from_predictions`.
        class_labels: Optional display labels for the matrix axes.
        title: Optional title of the plot.
        save_path: If given, the figure is written to this path and nothing is returned.
        plot_kwargs: Additional keyword arguments for plotting the confusion matrix.

    Returns:
        None if `save_path` is given. Otherwise the rendered figure as an RGB array
        with the channel axis first.
    """
    fig, ax = plt.subplots(1)
    try:
        # Plot directly into our axis. Without `ax` sklearn opens a second figure, which
        # would then be the one that gets saved / closed instead of `fig`.
        ConfusionMatrixDisplay.from_predictions(
            y_true, y_pred, normalize="true", display_labels=class_labels, ax=ax, **plot_kwargs
        )
        if title is not None:
            ax.set_title(title)

        if save_path is not None:
            fig.savefig(save_path)
            return None

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        image = np.asarray(canvas.buffer_rgba())[..., :3]  # drop the alpha channel
        return image.transpose((2, 0, 1))
    finally:
        # Always release the figure, also on the save path and on errors.
        plt.close(fig)
def
make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
    """Render a batch of images, annotated with targets and predictions, into one image.

    Args:
        images: Tensor with 4 or 5 dimensions and 1 or 3 channels on axis 1.
            For 5 dimensional input the central slice along axis 2 is shown.
        target: Optional tensor of targets, indexed per image for the title.
        prediction: Optional tensor of predictions; the index of the maximum along axis 1
            is shown in the title.
        images_per_row: The number of columns of the grid.
        kwargs: Accepted for call compatibility and ignored.

    Returns:
        The rendered grid as an RGB array with the channel axis first.
    """
    assert images.ndim in (4, 5)
    assert images.shape[1] in (1, 3), f"{images.shape}"

    z = images.shape[2] // 2 if images.ndim == 5 else None

    n_images = images.shape[0]
    # Round up, and keep at least one row so that an empty batch gives an empty grid.
    n_rows = max(1, -(-n_images // images_per_row))

    images = images.detach().cpu().numpy()
    if target is not None:
        target = target.detach().cpu().numpy()
    if prediction is not None:
        prediction = prediction.max(1)[1].detach().cpu().numpy()

    # squeeze=False keeps `axes` two dimensional for a single row or a single column.
    fig, axes = plt.subplots(n_rows, images_per_row, squeeze=False)
    try:
        # `axes.flat` runs row by row, so the flat index is the image index.
        for i, ax in enumerate(axes.flat):
            ax.set_axis_off()
            if i >= n_images:  # unused cell in the last row
                continue

            im = images[i] if z is None else images[i, :, z]
            im = im.transpose((1, 2, 0))
            im = normalize(im, axis=(0, 1))
            if im.shape[-1] == 3:  # rgb
                ax.imshow(im)
            else:
                ax.imshow(im[..., 0], cmap="gray")

            if target is None and prediction is None:
                continue

            # TODO get the class name, and if we have both target
            # and prediction check whether they agree or not and do stuff
            title = ""
            if target is not None:
                title += f"t: {target[i]} "
            if prediction is not None:
                title += f"p: {prediction[i]}"
            ax.set_title(title, fontsize=8)

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        image = np.asarray(canvas.buffer_rgba())[..., :3]  # drop the alpha channel
        return image.transpose((2, 0, 1))
    finally:
        plt.close(fig)
class ClassificationLogger(TorchEmLogger):
    """Logger that writes classification training and validation results to tensorboard.

    Scalars, image grids created by `make_grid` and the confusion matrix are written
    with a `SummaryWriter` to `<save_root>/logs/<trainer.name>`, or to
    `./logs/<trainer.name>` if `save_root` is None.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        # `import torch` does not load the tensorboard submodule, so import it explicitly
        # instead of relying on some other module having done so.
        from torch.utils.tensorboard import SummaryWriter

        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, pred, name, step):
        """Write the image grid for inputs `x`, targets `y` and predictions `pred` under `name`."""
        scale_each = False
        # The extra keyword arguments are swallowed by `make_grid` and have no effect.
        grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=scale_each)
        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """Log the training loss and learning rate, plus an image grid every `log_image_interval` steps.

        The argument `log_gradients` is accepted but not used.
        """
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, prediction, "train", step)

    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
        """Log the validation loss, metric and image grid; and the confusion matrix if labels are given.

        The confusion matrix is only logged if both `y_true` and `y_pred` are passed.
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, prediction, "validation", step)
        if y_true is not None and y_pred is not None:
            cm = confusion_matrix(y_true, y_pred)
            self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm, global_step=step)
ClassificationLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Create the log directory and the tensorboard writer for `trainer`.

    Args:
        trainer: The trainer; its `name` and `log_image_interval` attributes are read.
        save_root: Root folder for the logs. If None, the logs go to `./logs/<trainer.name>`,
            otherwise to `<save_root>/logs/<trainer.name>`.
        unused_kwargs: Accepted and ignored.
    """
    # `import torch` does not load the tensorboard submodule, so import it explicitly
    # instead of relying on some other module having done so.
    from torch.utils.tensorboard import SummaryWriter

    super().__init__(trainer, save_root)
    self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
        os.path.join(save_root, "logs", trainer.name)
    os.makedirs(self.log_dir, exist_ok=True)

    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval
def
log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
    """Write the training scalars for `step` and, at the image interval, the image grid.

    The loss and learning rate are logged on every call. The image grid for `x`, `y`
    and `prediction` is only added when `step` is a multiple of `self.log_image_interval`.
    The argument `log_gradients` is accepted but not used.
    """
    writer = self.tb
    for tag, value in (("train/loss", loss), ("train/learning_rate", lr)):
        writer.add_scalar(tag=tag, scalar_value=value, global_step=step)

    if step % self.log_image_interval != 0:
        return
    self.add_image(x, y, prediction, "train", step)
def
log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
    """Write the validation scalars and image grid for `step`, plus an optional confusion matrix.

    The loss, the metric and the image grid for `x`, `y` and `prediction` are always logged.
    The confusion matrix image is only added when both `y_true` and `y_pred` are given.
    """
    writer = self.tb
    writer.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
    writer.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
    self.add_image(x, y, prediction, "validation", step)

    if y_true is None or y_pred is None:
        return
    matrix_image = confusion_matrix(y_true, y_pred)
    writer.add_image(tag="validation/confusion_matrix", img_tensor=matrix_image, global_step=step)