torch_em.self_training.logger
import os

import torch_em
import torch

from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter


class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
    Also supports logging training with invertible augmentations.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """
    @staticmethod
    def _get_image_channel(x):
        # Display only the first channel of multi-channel images, so that every
        # entry passed to make_grid has a single channel and the same shape.
        return x[0, 0:1] if x.shape[1] > 1 else x[0]

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
            os.path.join(self.my_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x.ndim == 5:
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        num_channels = y.shape[1]

        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [y[0, c:c+1] for c in range(num_channels)] +
            [pred[0, c:c+1] for c in range(num_channels)]
        )
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x1.ndim == 5:
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            zindex = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]

        num_channels = pred.shape[1]

        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x1))] +
            [torch_em.transform.raw.normalize(self._get_image_channel(x2))] +
            # Pad the first grid row with blank tiles so it lines up with the
            # prediction / pseudo-label rows below (empty if num_channels <= 2).
            [torch.zeros_like(self._get_image_channel(x1))] * (num_channels - 2) +
            [pred[0, c:c+1] for c in range(num_channels)] +
            [pseudo_labels[0, c:c+1] for c in range(num_channels)]
        )
        im_name = f"{name}/unsupervised/image-prediction-pseudolabels"
        # If the trainer uses invertible augmentations, the untransformed images
        # and inverted pred/labels are logged for better visual comparison,
        # otherwise the transformed images are logged.
        if label_filter is not None:
            images.extend([label_filter[0, c:c+1] for c in range(num_channels)])
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private
        """
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """@private
        """
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, gt_metric=None):
        """@private
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)

    def log_ct(self, step, ct):
        """Log the confidence threshold `ct` as a scalar."""
        self.tb.add_scalar(tag="train/confidence_threshold", scalar_value=ct, global_step=step)

    def _add_augmented_images(self, step, name, xu1, xu2, pseudo_labels, pred):
        # For 3d (5-dim) data only the central z-slice is logged.
        if xu1.ndim == 5:
            assert xu2.ndim == pseudo_labels.ndim == pred.ndim == 5
            zindex = xu1.shape[2] // 2
            xu1 = xu1[:, :, zindex]
            xu2 = xu2[:, :, zindex]
            pred = pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]

        images = [
            # Fixed: use _get_image_channel instead of plain batch indexing, so that
            # multi-channel inputs can be put in one grid with the single-channel
            # pseudo-labels / predictions (make_grid requires identical shapes).
            torch_em.transform.raw.normalize(self._get_image_channel(xu1)),
            torch_em.transform.raw.normalize(self._get_image_channel(xu2)),
            pseudo_labels[0, 0:1],
            pred[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug1-aug2-pseudolabels-prediction"
        grid = make_grid(images, nrow=2, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_train_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
        """Log the two augmented views together with pseudo-labels and prediction."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(step, "train_augmentations", xu1, xu2, pseudo_labels, pred)

    def log_validation_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
        """Log the two augmented views together with pseudo-labels and prediction."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(step, "validation_augmentations", xu1, xu2, pseudo_labels, pred)


class UniMatchv2TensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.UniMatchv2Trainer`.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        self.log_dir = (
            f"./logs/{trainer.name}"
            if self.my_root is None
            else os.path.join(self.my_root, "logs", trainer.name)
        )
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    @staticmethod
    def _get_image_channel(x):
        # Same helper as in SelfTrainingTensorboardLogger: display only the first
        # channel of multi-channel images so all make_grid entries have one channel.
        return x[0, 0:1] if x.shape[1] > 1 else x[0]

    def _add_supervised_images(self, step, name, x, y, pred):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x.ndim == 5:
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        num_channels = y.shape[1]

        images = (
            # Fixed: previously x[0] was used directly, which breaks make_grid for
            # multi-channel inputs (all grid entries must have the same shape).
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [y[0, c:c+1] for c in range(num_channels)] +
            [pred[0, c:c+1] for c in range(num_channels)]
        )
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(
            tag=f"{name}/supervised/input-labels-prediction",
            img_tensor=grid,
            global_step=step,
        )

    def _add_unsupervised_images(self, step, name, x, pred_s1, pred_s2, pseudo_labels, label_filter):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x.ndim == 5:
            assert pred_s1.ndim == pred_s2.ndim == pseudo_labels.ndim == 5
            zindex = x.shape[2] // 2
            x = x[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]
        num_channels = pred_s1.shape[1]

        images = (
            # Fixed: previously x[0] was used directly, which breaks make_grid for
            # multi-channel inputs (all grid entries must have the same shape).
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [pred_s1[0, c:c+1] for c in range(num_channels)] +
            [pred_s2[0, c:c+1] for c in range(num_channels)] +
            [pseudo_labels[0, c:c+1] for c in range(num_channels)]
        )

        im_name = f"{name}/unsupervised/image-pred_s1-pred_s2-pseudolabels"
        if label_filter is not None:
            images.extend([label_filter[0, c:c+1] for c in range(num_channels)])
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private"""
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """@private"""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x, pred_s1, pred_s2, pseudo_labels, label_filter=None):
        """@private"""
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(step, "train", x, pred_s1, pred_s2, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x, pred_s1, pred_s2, pseudo_labels, label_filter=None):
        """@private"""
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x, pred_s1, pred_s2, pseudo_labels, label_filter)

    def log_ct(self, step, ct):
        """Log the confidence threshold `ct` as a scalar."""
        self.tb.add_scalar(tag="train/confidence_threshold", scalar_value=ct, global_step=step)

    # LOG AUGMENTATIONS FOR DEBUGGING ###
    def _add_augmented_images(self, step, name, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x_u_w.ndim == 5:
            assert (
                x_u_s1.ndim == x_u_s2.ndim == pseudo_labels.ndim == pred_s1.ndim == pred_s2.ndim == 5
            )
            zindex = x_u_w.shape[2] // 2
            x_u_w = x_u_w[:, :, zindex]
            x_u_s1, x_u_s2 = x_u_s1[:, :, zindex], x_u_s2[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]

        images = [
            # Fixed: use _get_image_channel so multi-channel inputs can share a
            # grid with the single-channel pseudo-labels / predictions.
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_w)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s1)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s2)),
            pseudo_labels[0, 0:1],
            pred_s1[0, 0:1],
            pred_s2[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug_w-aug_s1-aug_s2-pseudolabels-pred_s1-pred_s2"
        grid = make_grid(images, nrow=3, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_train_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
        """Log the weak and strong augmented views together with pseudo-labels and predictions."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step, "train_augmentations", x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
            )

    def log_validation_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
        """Log the weak and strong augmented views together with pseudo-labels and predictions."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step, "validation_augmentations", x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
            )
class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
    Also supports logging training with invertible augmentations.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """
    @staticmethod
    def _get_image_channel(x):
        # Keep only the first channel for multi-channel data.
        if x.shape[1] > 1:
            return x[0, 0:1]
        return x[0]

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        if self.my_root is None:
            self.log_dir = f"./logs/{trainer.name}"
        else:
            self.log_dir = os.path.join(self.my_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)
        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        # Reduce 3d volumes to their central z-slice before plotting.
        if x.ndim == 5:
            assert y.ndim == pred.ndim == 5
            z = x.shape[2] // 2
            x, y, pred = x[:, :, z], y[:, :, z], pred[:, :, z]

        n_channels = y.shape[1]
        raw = torch_em.transform.raw.normalize(self._get_image_channel(x))
        panels = [raw] * n_channels
        panels += [y[0, c:c+1] for c in range(n_channels)]
        panels += [pred[0, c:c+1] for c in range(n_channels)]
        image_grid = make_grid(panels, nrow=n_channels, padding=8)
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=image_grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        # Reduce 3d volumes to their central z-slice before plotting.
        if x1.ndim == 5:
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            z = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, z], x2[:, :, z], pred[:, :, z]
            pseudo_labels = pseudo_labels[:, :, z]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, z]

        n_channels = pred.shape[1]
        panels = [
            torch_em.transform.raw.normalize(self._get_image_channel(x1)),
            torch_em.transform.raw.normalize(self._get_image_channel(x2)),
        ]
        panels += [torch.zeros_like(self._get_image_channel(x1))] * (n_channels - 2)
        panels += [pred[0, c:c+1] for c in range(n_channels)]
        panels += [pseudo_labels[0, c:c+1] for c in range(n_channels)]
        tag_name = f"{name}/unsupervised/image-prediction-pseudolabels"
        # For trainers with invertible augmentations the untransformed images and
        # the inverted pred/labels are logged for better visual comparison,
        # otherwise the transformed images are logged.
        if label_filter is not None:
            panels += [label_filter[0, c:c+1] for c in range(n_channels)]
            tag_name += "-labelfilter"
        image_grid = make_grid(panels, nrow=n_channels, padding=8)
        self.tb.add_image(tag=tag_name, img_tensor=image_grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private"""
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """@private"""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private"""
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private"""
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, gt_metric=None):
        """@private"""
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)

    def log_ct(self, step, ct):
        self.tb.add_scalar(tag="train/confidence_threshold", scalar_value=ct, global_step=step)

    def _add_augmented_images(self, step, name, xu1, xu2, pseudo_labels, pred):
        # Reduce 3d volumes to their central z-slice before plotting.
        if xu1.ndim == 5:
            assert xu2.ndim == pseudo_labels.ndim == pred.ndim == 5
            z = xu1.shape[2] // 2
            xu1, xu2 = xu1[:, :, z], xu2[:, :, z]
            pred, pseudo_labels = pred[:, :, z], pseudo_labels[:, :, z]

        panels = [
            torch_em.transform.raw.normalize(xu1[0]),
            torch_em.transform.raw.normalize(xu2[0]),
            pseudo_labels[0, 0:1],
            pred[0, 0:1],
        ]
        image_grid = make_grid(panels, nrow=2, padding=8)
        self.tb.add_image(
            tag=f"{name}/unsupervised/aug1-aug2-pseudolabels-prediction",
            img_tensor=image_grid,
            global_step=step,
        )

    def log_train_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
        if step % self.log_image_interval == 0:
            self._add_augmented_images(step, "train_augmentations", xu1, xu2, pseudo_labels, pred)

    def log_validation_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
        if step % self.log_image_interval == 0:
            self._add_augmented_images(step, "validation_augmentations", xu1, xu2, pseudo_labels, pred)
Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
Also supports logging training with invertible augmentations.
Arguments:
- trainer: The instantiated trainer class.
- save_root: The root directory for saving the checkpoints and logs.
SelfTrainingTensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Create the log directory and set up the tensorboard writer."""
    super().__init__(trainer, save_root)
    self.my_root = save_root
    if self.my_root is None:
        self.log_dir = f"./logs/{trainer.name}"
    else:
        self.log_dir = os.path.join(self.my_root, "logs", trainer.name)
    os.makedirs(self.log_dir, exist_ok=True)
    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval
Inherited Members
class UniMatchv2TensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.UniMatchv2Trainer`.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        self.log_dir = (
            f"./logs/{trainer.name}"
            if self.my_root is None
            else os.path.join(self.my_root, "logs", trainer.name)
        )
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    @staticmethod
    def _get_image_channel(x):
        # Display only the first channel of multi-channel images, so that every
        # entry passed to make_grid has a single channel and the same shape.
        return x[0, 0:1] if x.shape[1] > 1 else x[0]

    def _add_supervised_images(self, step, name, x, y, pred):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x.ndim == 5:
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        num_channels = y.shape[1]

        images = (
            # Fixed: previously x[0] was used directly, which breaks make_grid for
            # multi-channel inputs (all grid entries must have the same shape).
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [y[0, c:c+1] for c in range(num_channels)] +
            [pred[0, c:c+1] for c in range(num_channels)]
        )
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(
            tag=f"{name}/supervised/input-labels-prediction",
            img_tensor=grid,
            global_step=step,
        )

    def _add_unsupervised_images(self, step, name, x, pred_s1, pred_s2, pseudo_labels, label_filter):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x.ndim == 5:
            assert pred_s1.ndim == pred_s2.ndim == pseudo_labels.ndim == 5
            zindex = x.shape[2] // 2
            x = x[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]
        num_channels = pred_s1.shape[1]

        images = (
            # Fixed: previously x[0] was used directly, which breaks make_grid for
            # multi-channel inputs (all grid entries must have the same shape).
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [pred_s1[0, c:c+1] for c in range(num_channels)] +
            [pred_s2[0, c:c+1] for c in range(num_channels)] +
            [pseudo_labels[0, c:c+1] for c in range(num_channels)]
        )

        im_name = f"{name}/unsupervised/image-pred_s1-pred_s2-pseudolabels"
        if label_filter is not None:
            images.extend([label_filter[0, c:c+1] for c in range(num_channels)])
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private"""
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """@private"""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x, pred_s1, pred_s2, pseudo_labels, label_filter=None):
        """@private"""
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(step, "train", x, pred_s1, pred_s2, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x, pred_s1, pred_s2, pseudo_labels, label_filter=None):
        """@private"""
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x, pred_s1, pred_s2, pseudo_labels, label_filter)

    def log_ct(self, step, ct):
        """Log the confidence threshold `ct` as a scalar."""
        self.tb.add_scalar(tag="train/confidence_threshold", scalar_value=ct, global_step=step)

    # LOG AUGMENTATIONS FOR DEBUGGING ###
    def _add_augmented_images(self, step, name, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
        # For 3d (5-dim) data only the central z-slice is logged.
        if x_u_w.ndim == 5:
            assert (
                x_u_s1.ndim == x_u_s2.ndim == pseudo_labels.ndim == pred_s1.ndim == pred_s2.ndim == 5
            )
            zindex = x_u_w.shape[2] // 2
            x_u_w = x_u_w[:, :, zindex]
            x_u_s1, x_u_s2 = x_u_s1[:, :, zindex], x_u_s2[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]

        images = [
            # Fixed: use _get_image_channel so multi-channel inputs can share a
            # grid with the single-channel pseudo-labels / predictions.
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_w)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s1)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s2)),
            pseudo_labels[0, 0:1],
            pred_s1[0, 0:1],
            pred_s2[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug_w-aug_s1-aug_s2-pseudolabels-pred_s1-pred_s2"
        grid = make_grid(images, nrow=3, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_train_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
        """Log the weak and strong augmented views together with pseudo-labels and predictions."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step, "train_augmentations", x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
            )

    def log_validation_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
        """Log the weak and strong augmented views together with pseudo-labels and predictions."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step, "validation_augmentations", x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
            )
Logger for self-training via `torch_em.self_training.UniMatchv2Trainer`.
Arguments:
- trainer: The instantiated trainer class.
- save_root: The root directory for saving the checkpoints and logs.
UniMatchv2TensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Create the log directory and set up the tensorboard writer."""
    super().__init__(trainer, save_root)
    self.my_root = save_root
    if self.my_root is None:
        self.log_dir = f"./logs/{trainer.name}"
    else:
        self.log_dir = os.path.join(self.my_root, "logs", trainer.name)
    os.makedirs(self.log_dir, exist_ok=True)

    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval
def
log_train_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
def log_train_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
    """Forward the augmented views to the image logger at the configured interval."""
    if step % self.log_image_interval != 0:
        return
    self._add_augmented_images(
        step, "train_augmentations", x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    )
def
log_validation_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
def log_validation_augmentations(self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2):
    """Forward the augmented views to the image logger at the configured interval."""
    if step % self.log_image_interval != 0:
        return
    self._add_augmented_images(
        step, "validation_augmentations", x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    )