torch_em.classification.classification_logger
import os

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.backends.backend_agg import FigureCanvasAgg
from sklearn.metrics import ConfusionMatrixDisplay
from torch.utils.tensorboard import SummaryWriter
from torch_em.trainer.logger_base import TorchEmLogger
from torch_em.transform.raw import normalize


def _render_to_chw(canvas):
    """Draw the figure attached to the Agg `canvas` and return it as an RGB array of shape (3, H, W)."""
    canvas.draw()
    # Drop the alpha channel; np.array copies, so the result does not alias the render buffer.
    rgb = np.array(canvas.buffer_rgba())[..., :3]
    return rgb.transpose((2, 0, 1))


def confusion_matrix(y_true, y_pred, class_labels=None, title=None, save_path=None, **plot_kwargs):
    """@private

    Plot the confusion matrix of `y_pred` against `y_true`, normalized over the true labels.

    Args:
        y_true: The ground-truth class labels.
        y_pred: The predicted class labels.
        class_labels: Optional display names for the classes.
        title: Optional title for the plot.
        save_path: If given, the figure is written to this path and None is returned.
        plot_kwargs: Further keyword arguments forwarded to
            `ConfusionMatrixDisplay.from_predictions` (e.g. `cmap`, `values_format`).

    Returns:
        The rendered figure as an RGB array of shape (3, H, W), or None if `save_path` was given.
    """
    fig, ax = plt.subplots(1)
    try:
        # Draw directly into our axes. Without `ax`, `from_predictions` creates a second
        # figure of its own that is never closed. It is also the figure `plt.savefig` would
        # save, which has neither the title nor the plot_kwargs applied.
        ConfusionMatrixDisplay.from_predictions(
            y_true, y_pred, normalize="true", display_labels=class_labels, ax=ax, **plot_kwargs
        )
        if title is not None:
            ax.set_title(title)

        if save_path is not None:
            fig.savefig(save_path)
            return None
        return _render_to_chw(FigureCanvasAgg(fig))
    finally:
        # Always release the figure: when saving, when rendering, and when plotting fails.
        plt.close(fig)


def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
    """@private

    Render a batch of images as a grid figure.

    Args:
        images: Tensor of shape (N, C, H, W) or (N, C, Z, H, W) with C in (1, 3).
            For 5d input the central z-slice is shown.
        target: Optional per-image target labels, shown as "t: ..." in each panel title.
        prediction: Optional prediction tensor; its argmax over dim 1 is shown as "p: ...".
        images_per_row: Number of panels per grid row.
        kwargs: Accepted for call compatibility and ignored.

    Returns:
        The rendered grid as an RGB array of shape (3, H, W).
    """
    assert images.ndim in (4, 5)
    assert images.shape[1] in (1, 3), f"{images.shape}"

    is_3d = images.ndim == 5
    z = images.shape[2] // 2 if is_3d else None

    n_images = images.shape[0]
    n_rows = n_images // images_per_row
    if n_images % images_per_row != 0:
        n_rows += 1
    # plt.subplots rejects zero rows; an empty batch yields a single row of blank panels.
    n_rows = max(n_rows, 1)

    images = images.detach().cpu().numpy()
    if target is not None:
        target = target.detach().cpu().numpy()
    if prediction is not None:
        prediction = prediction.max(1)[1].detach().cpu().numpy()

    # squeeze=False always yields a 2d axes array. Previously a single-row grid indexed the 1d
    # array with the row index, so every image was drawn onto the first panel.
    fig, axes = plt.subplots(n_rows, images_per_row, squeeze=False)
    try:
        canvas = FigureCanvasAgg(fig)
        # axes.flat is row-major, so i == r * images_per_row + c as before.
        for i, ax in enumerate(axes.flat):
            ax.set_axis_off()
            if i >= n_images:
                continue  # leave the unused panels of the last row blank

            im = images[i, :, z] if is_3d else images[i]
            im = im.transpose((1, 2, 0))
            im = normalize(im, axis=(0, 1))
            if im.shape[-1] == 3:  # rgb
                ax.imshow(im)
            else:
                ax.imshow(im[..., 0], cmap="gray")

            if target is None and prediction is None:
                continue

            # TODO get the class name, and if we have both target
            # and prediction check whether they agree or not and do stuff
            title = ""
            if target is not None:
                title += f"t: {target[i]} "
            if prediction is not None:
                title += f"p: {prediction[i]}"
            ax.set_title(title, fontsize=8)

        return _render_to_chw(canvas)
    finally:
        plt.close(fig)


class ClassificationLogger(TorchEmLogger):
    """Logger for the classification trainer.

    Args:
        trainer: The trainer instance. Its `name` and `log_image_interval` attributes are read.
        save_root: Root folder for saving the checkpoints and logs.
            If None, the logs are written to `./logs/<trainer.name>`.
    """
    def __init__(self, trainer, save_root: str, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, pred, name, step):
        """@private
        """
        # make_grid accepts these torchvision-style kwargs but does not use them.
        grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=False)
        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """@private
        """
        # log_gradients is part of the logger interface but is not used here.
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, prediction, "train", step)

    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
        """@private
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, prediction, "validation", step)
        if y_true is not None and y_pred is not None:
            cm = confusion_matrix(y_true, y_pred)
            self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm, global_step=step)
class ClassificationLogger(TorchEmLogger):
    """Logger for the classification trainer.

    Writes scalar curves and image grids to a tensorboard event file.

    Args:
        trainer: The trainer instance. Its `name` and `log_image_interval` attributes are read.
        save_root: Root folder for saving the checkpoints and logs.
            If None, the logs are written to `./logs/<trainer.name>`.
    """
    def __init__(self, trainer, save_root: str, **unused_kwargs):
        super().__init__(trainer, save_root)
        if save_root is None:
            self.log_dir = f"./logs/{trainer.name}"
        else:
            self.log_dir = os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, pred, name, step):
        """@private
        """
        image_grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=False)
        self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=image_grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
        """@private
        """
        # log_gradients belongs to the logger interface; it is not used by this logger.
        for scalar_tag, value in (("train/loss", loss), ("train/learning_rate", lr)):
            self.tb.add_scalar(tag=scalar_tag, scalar_value=value, global_step=step)
        # Image grids are only written every `log_image_interval` steps.
        if step % self.log_image_interval != 0:
            return
        self.add_image(x, y, prediction, "train", step)

    def log_validation(self, step, metric, loss, x, y, prediction, y_true=None, y_pred=None):
        """@private
        """
        for scalar_tag, value in (("validation/loss", loss), ("validation/metric", metric)):
            self.tb.add_scalar(tag=scalar_tag, scalar_value=value, global_step=step)
        self.add_image(x, y, prediction, "validation", step)
        # The confusion matrix needs both label arrays; skip it otherwise.
        if y_true is None or y_pred is None:
            return
        cm_image = confusion_matrix(y_true, y_pred)
        self.tb.add_image(tag="validation/confusion_matrix", img_tensor=cm_image, global_step=step)
Logger for the classification trainer.
Arguments:
- trainer: The trainer instance.
- save_root: Root folder for saving the checkpoints and logs. If None, the logs are written to `./logs/<trainer.name>`.
ClassificationLogger(trainer, save_root: str, **unused_kwargs)
def __init__(self, trainer, save_root: str, **unused_kwargs):
    """Set up the log directory and the tensorboard writer.

    The logs go to `./logs/<trainer.name>` when `save_root` is None,
    and to `<save_root>/logs/<trainer.name>` otherwise.
    """
    super().__init__(trainer, save_root)
    if save_root is None:
        self.log_dir = f"./logs/{trainer.name}"
    else:
        self.log_dir = os.path.join(save_root, "logs", trainer.name)
    os.makedirs(self.log_dir, exist_ok=True)

    # Writer for all scalars/images, plus the step interval for logging training images.
    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval