torch_em.self_training.logger

  1import os
  2
  3import torch_em
  4import torch
  5
  6from torchvision.utils import make_grid
  7from torch.utils.tensorboard import SummaryWriter
  8
  9
 10class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
 11    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
 12    Also supports logging training with invertible augmentations.
 13
 14    Args:
 15        trainer: The instantiated trainer class.
 16        save_root: The root directory for saving the checkpoints and logs.
 17    """
 18    @staticmethod
 19    def _get_image_channel(x):
 20        return x[0, 0:1] if x.shape[1] > 1 else x[0]
 21
 22    def __init__(self, trainer, save_root, **unused_kwargs):
 23        super().__init__(trainer, save_root)
 24        self.my_root = save_root
 25        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
 26            os.path.join(self.my_root, "logs", trainer.name)
 27        os.makedirs(self.log_dir, exist_ok=True)
 28
 29        self.tb = SummaryWriter(self.log_dir)
 30        self.log_image_interval = trainer.log_image_interval
 31
 32    def _add_supervised_images(self, step, name, x, y, pred):
 33        if x.ndim == 5:
 34            assert y.ndim == pred.ndim == 5
 35            zindex = x.shape[2] // 2
 36            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]
 37
 38        num_channels = y.shape[1]
 39
 40        images = (
 41            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
 42            [y[0, c:c+1] for c in range(num_channels)] +
 43            [pred[0, c:c+1] for c in range(num_channels)]
 44        )
 45        grid = make_grid(images, nrow=num_channels, padding=8)
 46        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)
 47
 48    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
 49        if x1.ndim == 5:
 50            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
 51            zindex = x1.shape[2] // 2
 52            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
 53            pseudo_labels = pseudo_labels[:, :, zindex]
 54            if label_filter is not None:
 55                assert label_filter.ndim == 5
 56                label_filter = label_filter[:, :, zindex]
 57
 58        num_channels = pred.shape[1]
 59
 60        images = (
 61            [torch_em.transform.raw.normalize(self._get_image_channel(x1))] +
 62            [torch_em.transform.raw.normalize(self._get_image_channel(x2))] +
 63            [torch.zeros_like(self._get_image_channel(x1))] * (num_channels - 2) +
 64            [pred[0, c:c+1] for c in range(num_channels)] +
 65            [pseudo_labels[0, c:c+1] for c in range(num_channels)]
 66        )
 67        im_name = f"{name}/unsupervised/image-prediction-pseudolabels"
 68        # if trainer with invertible augmentations, untransformed images
 69        # and inverted pred/labels are logged for better visual comparison,
 70        # otherwise the transformed images are logged
 71        if label_filter is not None:
 72            images.extend([label_filter[0, c:c+1] for c in range(num_channels)])
 73            im_name += "-labelfilter"
 74        grid = make_grid(images, nrow=num_channels, padding=8)
 75        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)
 76
 77    def log_combined_loss(self, step, loss):
 78        """@private
 79        """
 80        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)
 81
 82    def log_lr(self, step, lr):
 83        """@private
 84        """
 85        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
 86
 87    def log_train_supervised(self, step, loss, x, y, pred):
 88        """@private
 89        """
 90        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
 91        if step % self.log_image_interval == 0:
 92            self._add_supervised_images(step, "train", x, y, pred)
 93
 94    def log_validation_supervised(self, step, metric, loss, x, y, pred):
 95        """@private
 96        """
 97        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
 98        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
 99        self._add_supervised_images(step, "validation", x, y, pred)
100
101    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
102        """@private
103        """
104        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
105        if step % self.log_image_interval == 0:
106            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)
107
108    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
109        """@private
110        """
111        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
112        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
113        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
114
115    def log_validation(self, step, metric, loss, gt_metric=None):
116        """@private
117        """
118        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
119        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
120        if gt_metric is not None:
121            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
122
123    def log_ct(self, step, ct):
124        self.tb.add_scalar(tag="train/confidence_threshold", scalar_value=ct, global_step=step)
125
126    def _add_augmented_images(
127        self, step, name, xu1, xu2, pseudo_labels, pred
128    ):
129        if xu1.ndim == 5:
130            assert (
131                xu2.ndim
132                == pseudo_labels.ndim
133                == pred.ndim
134                == 5
135            )
136            zindex = xu1.shape[2] // 2
137            xu1 = xu1[:, :, zindex]
138            xu2 = xu2[:, :, zindex]
139            pred = pred[:, :, zindex]
140            pseudo_labels = pseudo_labels[:, :, zindex]
141
142        images = [
143            torch_em.transform.raw.normalize(xu1[0]),
144            torch_em.transform.raw.normalize(xu2[0]),
145            pseudo_labels[0, 0:1],
146            pred[0, 0:1],
147        ]
148        im_name = (
149            f"{name}/unsupervised/aug1-aug2-pseudolabels-prediction"
150        )
151        grid = make_grid(images, nrow=2, padding=8)
152        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)
153
154    def log_train_augmentations(
155        self, step, xu1, xu2, pseudo_labels, pred
156    ):
157        if step % self.log_image_interval == 0:
158            self._add_augmented_images(
159                step,
160                "train_augmentations",
161                xu1,
162                xu2,
163                pseudo_labels,
164                pred,
165            )
166
167    def log_validation_augmentations(
168        self, step, xu1, xu2, pseudo_labels, pred
169    ):
170        if step % self.log_image_interval == 0:
171            self._add_augmented_images(
172                step,
173                "validation_augmentations",
174                xu1,
175                xu2,
176                pseudo_labels,
177                pred,
178            )
179
180
class UniMatchv2TensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.UniMatchv2Trainer`.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """

    @staticmethod
    def _get_image_channel(x):
        # Take the first sample of the batch. For multi-channel inputs only the first
        # channel is kept, so the image tile has the same shape as the single-channel
        # prediction / pseudo-label tiles it is stacked with in `make_grid`
        # (assumes x is (B, C, H, W) after any z-slicing).
        return x[0, 0:1] if x.shape[1] > 1 else x[0]

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        # Without a save_root the logs are written relative to the working directory.
        self.log_dir = (
            f"./logs/{trainer.name}"
            if self.my_root is None
            else os.path.join(self.my_root, "logs", trainer.name)
        )
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        # Throttles image logging: training images are only written when
        # `step % self.log_image_interval == 0`.
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Write an input / labels / prediction grid with one column per label channel."""
        if x.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        num_channels = y.shape[1]

        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [y[0, c:c+1] for c in range(num_channels)] +
            [pred[0, c:c+1] for c in range(num_channels)]
        )
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(
            tag=f"{name}/supervised/input-labels-prediction",
            img_tensor=grid,
            global_step=step,
        )

    def _add_unsupervised_images(
        self, step, name, x, pred_s1, pred_s2, pseudo_labels, label_filter
    ):
        """Write an image / pred_s1 / pred_s2 / pseudo-labels grid (plus label filter, if given).

        The grid has one column per prediction channel; the input row repeats the image.
        """
        if x.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert pred_s1.ndim == pred_s2.ndim == pseudo_labels.ndim == 5
            zindex = x.shape[2] // 2
            x = x[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]
        num_channels = pred_s1.shape[1]

        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [pred_s1[0, c:c+1] for c in range(num_channels)] +
            [pred_s2[0, c:c+1] for c in range(num_channels)] +
            [pseudo_labels[0, c:c+1] for c in range(num_channels)]
        )

        im_name = f"{name}/unsupervised/image-pred_s1-pred_s2-pseudolabels"
        if label_filter is not None:
            images.extend([label_filter[0, c:c+1] for c in range(num_channels)])
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private"""
        self.tb.add_scalar(
            tag="train/combined_loss", scalar_value=loss, global_step=step
        )

    def log_lr(self, step, lr):
        """@private"""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(
            tag="train/supervised/loss", scalar_value=loss, global_step=step
        )
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(
            tag="validation/supervised/loss", scalar_value=loss, global_step=step
        )
        self.tb.add_scalar(
            tag="validation/supervised/metric", scalar_value=metric, global_step=step
        )
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(
        self,
        step,
        loss,
        x,
        pred_s1,
        pred_s2,
        pseudo_labels,
        label_filter=None,
    ):
        """@private"""
        self.tb.add_scalar(
            tag="train/unsupervised/loss", scalar_value=loss, global_step=step
        )
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(
                step, "train", x, pred_s1, pred_s2, pseudo_labels, label_filter
            )

    def log_validation_unsupervised(
        self,
        step,
        metric,
        loss,
        x,
        pred_s1,
        pred_s2,
        pseudo_labels,
        label_filter=None,
    ):
        """@private"""
        self.tb.add_scalar(
            tag="validation/unsupervised/loss", scalar_value=loss, global_step=step
        )
        self.tb.add_scalar(
            tag="validation/unsupervised/metric", scalar_value=metric, global_step=step
        )
        self._add_unsupervised_images(
            step, "validation", x, pred_s1, pred_s2, pseudo_labels, label_filter
        )

    def log_ct(self, step, ct):
        """Log the current confidence threshold `ct` as a training scalar at `step`."""
        self.tb.add_scalar(
            tag="train/confidence_threshold", scalar_value=ct, global_step=step
        )

    # LOG AUGMENTATIONS FOR DEBUGGING ###
    def _add_augmented_images(
        self, step, name, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    ):
        """Write a 3-column grid: [aug_w, aug_s1, aug_s2] / [pseudo-labels, pred_s1, pred_s2].

        Only the first channel of the pseudo-labels and predictions is shown.
        """
        if x_u_w.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert (
                x_u_s1.ndim
                == x_u_s2.ndim
                == pseudo_labels.ndim
                == pred_s1.ndim
                == pred_s2.ndim
                == 5
            )
            zindex = x_u_w.shape[2] // 2
            x_u_w = x_u_w[:, :, zindex]
            x_u_s1, x_u_s2 = x_u_s1[:, :, zindex], x_u_s2[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]

        images = [
            # Single-channel input tiles so they can be stacked with the label tiles.
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_w)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s1)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s2)),
            pseudo_labels[0, 0:1],
            pred_s1[0, 0:1],
            pred_s2[0, 0:1],
        ]
        im_name = (
            f"{name}/unsupervised/aug_w-aug_s1-aug_s2-pseudolabels-pred_s1-pred_s2"
        )
        grid = make_grid(images, nrow=3, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_train_augmentations(
        self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    ):
        """Log the weak / strong training augmentations every `log_image_interval` steps."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step,
                "train_augmentations",
                x_u_w,
                x_u_s1,
                x_u_s2,
                pseudo_labels,
                pred_s1,
                pred_s2,
            )

    def log_validation_augmentations(
        self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    ):
        """Log the weak / strong validation augmentations every `log_image_interval` steps."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step,
                "validation_augmentations",
                x_u_w,
                x_u_s1,
                x_u_s2,
                pseudo_labels,
                pred_s1,
                pred_s2,
            )
class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.FixMatch` or `torch_em.self_training.MeanTeacher`.
    Also supports logging training with invertible augmentations.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """
    @staticmethod
    def _get_image_channel(x):
        # Take the first sample of the batch. For multi-channel inputs only the first
        # channel is kept, so the image tile has the same shape as the single-channel
        # label / prediction tiles it is stacked with in `make_grid`
        # (assumes x is (B, C, H, W) after any z-slicing).
        return x[0, 0:1] if x.shape[1] > 1 else x[0]

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        # Without a save_root the logs are written relative to the working directory.
        self.log_dir = (
            f"./logs/{trainer.name}"
            if self.my_root is None
            else os.path.join(self.my_root, "logs", trainer.name)
        )
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        # Throttles image logging: training images are only written when
        # `step % self.log_image_interval == 0`.
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Write an input / labels / prediction grid with one column per label channel."""
        if x.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        num_channels = y.shape[1]

        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [y[0, c:c+1] for c in range(num_channels)] +
            [pred[0, c:c+1] for c in range(num_channels)]
        )
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)

    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
        """Write a grid with both inputs, the prediction and the pseudo-labels.

        The grid has `pred.shape[1]` columns. The input tiles (x1, x2) are padded with
        blank tiles so that the prediction, pseudo-label and (optional) label-filter
        channels each start on their own row.
        """
        if x1.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5
            zindex = x1.shape[2] // 2
            x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]

        num_channels = pred.shape[1]

        # If the trainer uses invertible augmentations, untransformed images and
        # inverted pred/labels are passed in for better visual comparison;
        # otherwise the transformed images are logged.
        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x1))] +
            [torch_em.transform.raw.normalize(self._get_image_channel(x2))] +
            # Blank padding for the input row; none when there are fewer than 2 channels.
            [torch.zeros_like(self._get_image_channel(x1))] * max(num_channels - 2, 0) +
            [pred[0, c:c+1] for c in range(num_channels)] +
            [pseudo_labels[0, c:c+1] for c in range(num_channels)]
        )
        im_name = f"{name}/unsupervised/image-prediction-pseudolabels"
        if label_filter is not None:
            images.extend([label_filter[0, c:c+1] for c in range(num_channels)])
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private
        """
        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)

    def log_lr(self, step, lr):
        """@private
        """
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private
        """
        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
        """@private
        """
        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)

    def log_validation(self, step, metric, loss, gt_metric=None):
        """@private
        """
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        if gt_metric is not None:
            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)

    def log_ct(self, step, ct):
        """Log the current confidence threshold `ct` as a training scalar at `step`."""
        self.tb.add_scalar(tag="train/confidence_threshold", scalar_value=ct, global_step=step)

    def _add_augmented_images(self, step, name, xu1, xu2, pseudo_labels, pred):
        """Write a 2-column grid: [aug1, aug2] on the first row, [pseudo-labels, prediction] on the second.

        Only the first channel of the pseudo-labels and prediction is shown.
        """
        if xu1.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert xu2.ndim == pseudo_labels.ndim == pred.ndim == 5
            zindex = xu1.shape[2] // 2
            xu1 = xu1[:, :, zindex]
            xu2 = xu2[:, :, zindex]
            pred = pred[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]

        images = [
            # Use the same channel selection as the other grids, so multi-channel inputs
            # yield single-channel tiles that can be stacked with the label tiles.
            torch_em.transform.raw.normalize(self._get_image_channel(xu1)),
            torch_em.transform.raw.normalize(self._get_image_channel(xu2)),
            pseudo_labels[0, 0:1],
            pred[0, 0:1],
        ]
        im_name = f"{name}/unsupervised/aug1-aug2-pseudolabels-prediction"
        grid = make_grid(images, nrow=2, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_train_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
        """Log the augmented training views every `log_image_interval` steps."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(step, "train_augmentations", xu1, xu2, pseudo_labels, pred)

    def log_validation_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
        """Log the augmented validation views every `log_image_interval` steps."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(step, "validation_augmentations", xu1, xu2, pseudo_labels, pred)

Logger for self-training via torch_em.self_training.FixMatch or torch_em.self_training.MeanTeacher. Also supports logging training with invertible augmentations.

Arguments:
  • trainer: The instantiated trainer class.
  • save_root: The root directory for saving the checkpoints and logs.
SelfTrainingTensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Create the log directory and the tensorboard writer for `trainer`.

    Logs go to `./logs/<trainer.name>` when `save_root` is None and to
    `<save_root>/logs/<trainer.name>` otherwise; the directory is created if missing.
    """
    super().__init__(trainer, save_root)
    self.my_root = save_root
    if save_root is None:
        target_dir = f"./logs/{trainer.name}"
    else:
        target_dir = os.path.join(save_root, "logs", trainer.name)
    self.log_dir = target_dir
    os.makedirs(target_dir, exist_ok=True)

    self.tb = SummaryWriter(target_dir)
    self.log_image_interval = trainer.log_image_interval
my_root
log_dir
tb
log_image_interval
def log_ct(self, step, ct):
    """Record the confidence threshold `ct` as a training scalar at `step`."""
    writer = self.tb
    writer.add_scalar("train/confidence_threshold", ct, step)
def log_train_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
    """Log the augmented training views, but only every `log_image_interval` steps."""
    if step % self.log_image_interval != 0:
        return
    self._add_augmented_images(step, "train_augmentations", xu1, xu2, pseudo_labels, pred)
def log_validation_augmentations(self, step, xu1, xu2, pseudo_labels, pred):
    """Log the augmented validation views whenever `step` hits the image interval."""
    due = step % self.log_image_interval == 0
    if due:
        self._add_augmented_images(
            step, "validation_augmentations",
            xu1=xu1, xu2=xu2, pseudo_labels=pseudo_labels, pred=pred,
        )
class UniMatchv2TensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
    """Logger for self-training via `torch_em.self_training.UniMatchv2Trainer`.

    Args:
        trainer: The instantiated trainer class.
        save_root: The root directory for saving the checkpoints and logs.
    """

    @staticmethod
    def _get_image_channel(x):
        # Take the first sample of the batch. For multi-channel inputs only the first
        # channel is kept, so the image tile has the same shape as the single-channel
        # prediction / pseudo-label tiles it is stacked with in `make_grid`
        # (assumes x is (B, C, H, W) after any z-slicing).
        return x[0, 0:1] if x.shape[1] > 1 else x[0]

    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.my_root = save_root
        # Without a save_root the logs are written relative to the working directory.
        self.log_dir = (
            f"./logs/{trainer.name}"
            if self.my_root is None
            else os.path.join(self.my_root, "logs", trainer.name)
        )
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        # Throttles image logging: training images are only written when
        # `step % self.log_image_interval == 0`.
        self.log_image_interval = trainer.log_image_interval

    def _add_supervised_images(self, step, name, x, y, pred):
        """Write an input / labels / prediction grid with one column per label channel."""
        if x.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert y.ndim == pred.ndim == 5
            zindex = x.shape[2] // 2
            x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex]

        num_channels = y.shape[1]

        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [y[0, c:c+1] for c in range(num_channels)] +
            [pred[0, c:c+1] for c in range(num_channels)]
        )
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(
            tag=f"{name}/supervised/input-labels-prediction",
            img_tensor=grid,
            global_step=step,
        )

    def _add_unsupervised_images(
        self, step, name, x, pred_s1, pred_s2, pseudo_labels, label_filter
    ):
        """Write an image / pred_s1 / pred_s2 / pseudo-labels grid (plus label filter, if given).

        The grid has one column per prediction channel; the input row repeats the image.
        """
        if x.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert pred_s1.ndim == pred_s2.ndim == pseudo_labels.ndim == 5
            zindex = x.shape[2] // 2
            x = x[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]
            if label_filter is not None:
                assert label_filter.ndim == 5
                label_filter = label_filter[:, :, zindex]
        num_channels = pred_s1.shape[1]

        images = (
            [torch_em.transform.raw.normalize(self._get_image_channel(x))] * num_channels +
            [pred_s1[0, c:c+1] for c in range(num_channels)] +
            [pred_s2[0, c:c+1] for c in range(num_channels)] +
            [pseudo_labels[0, c:c+1] for c in range(num_channels)]
        )

        im_name = f"{name}/unsupervised/image-pred_s1-pred_s2-pseudolabels"
        if label_filter is not None:
            images.extend([label_filter[0, c:c+1] for c in range(num_channels)])
            im_name += "-labelfilter"
        grid = make_grid(images, nrow=num_channels, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_combined_loss(self, step, loss):
        """@private"""
        self.tb.add_scalar(
            tag="train/combined_loss", scalar_value=loss, global_step=step
        )

    def log_lr(self, step, lr):
        """@private"""
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)

    def log_train_supervised(self, step, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(
            tag="train/supervised/loss", scalar_value=loss, global_step=step
        )
        if step % self.log_image_interval == 0:
            self._add_supervised_images(step, "train", x, y, pred)

    def log_validation_supervised(self, step, metric, loss, x, y, pred):
        """@private"""
        self.tb.add_scalar(
            tag="validation/supervised/loss", scalar_value=loss, global_step=step
        )
        self.tb.add_scalar(
            tag="validation/supervised/metric", scalar_value=metric, global_step=step
        )
        self._add_supervised_images(step, "validation", x, y, pred)

    def log_train_unsupervised(
        self,
        step,
        loss,
        x,
        pred_s1,
        pred_s2,
        pseudo_labels,
        label_filter=None,
    ):
        """@private"""
        self.tb.add_scalar(
            tag="train/unsupervised/loss", scalar_value=loss, global_step=step
        )
        if step % self.log_image_interval == 0:
            self._add_unsupervised_images(
                step, "train", x, pred_s1, pred_s2, pseudo_labels, label_filter
            )

    def log_validation_unsupervised(
        self,
        step,
        metric,
        loss,
        x,
        pred_s1,
        pred_s2,
        pseudo_labels,
        label_filter=None,
    ):
        """@private"""
        self.tb.add_scalar(
            tag="validation/unsupervised/loss", scalar_value=loss, global_step=step
        )
        self.tb.add_scalar(
            tag="validation/unsupervised/metric", scalar_value=metric, global_step=step
        )
        self._add_unsupervised_images(
            step, "validation", x, pred_s1, pred_s2, pseudo_labels, label_filter
        )

    def log_ct(self, step, ct):
        """Log the current confidence threshold `ct` as a training scalar at `step`."""
        self.tb.add_scalar(
            tag="train/confidence_threshold", scalar_value=ct, global_step=step
        )

    # LOG AUGMENTATIONS FOR DEBUGGING ###
    def _add_augmented_images(
        self, step, name, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    ):
        """Write a 3-column grid: [aug_w, aug_s1, aug_s2] / [pseudo-labels, pred_s1, pred_s2].

        Only the first channel of the pseudo-labels and predictions is shown.
        """
        if x_u_w.ndim == 5:
            # Volumetric data: only the central slice along axis 2 is logged.
            assert (
                x_u_s1.ndim
                == x_u_s2.ndim
                == pseudo_labels.ndim
                == pred_s1.ndim
                == pred_s2.ndim
                == 5
            )
            zindex = x_u_w.shape[2] // 2
            x_u_w = x_u_w[:, :, zindex]
            x_u_s1, x_u_s2 = x_u_s1[:, :, zindex], x_u_s2[:, :, zindex]
            pred_s1, pred_s2 = pred_s1[:, :, zindex], pred_s2[:, :, zindex]
            pseudo_labels = pseudo_labels[:, :, zindex]

        images = [
            # Single-channel input tiles so they can be stacked with the label tiles.
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_w)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s1)),
            torch_em.transform.raw.normalize(self._get_image_channel(x_u_s2)),
            pseudo_labels[0, 0:1],
            pred_s1[0, 0:1],
            pred_s2[0, 0:1],
        ]
        im_name = (
            f"{name}/unsupervised/aug_w-aug_s1-aug_s2-pseudolabels-pred_s1-pred_s2"
        )
        grid = make_grid(images, nrow=3, padding=8)
        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)

    def log_train_augmentations(
        self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    ):
        """Log the weak / strong training augmentations every `log_image_interval` steps."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step,
                "train_augmentations",
                x_u_w,
                x_u_s1,
                x_u_s2,
                pseudo_labels,
                pred_s1,
                pred_s2,
            )

    def log_validation_augmentations(
        self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
    ):
        """Log the weak / strong validation augmentations every `log_image_interval` steps."""
        if step % self.log_image_interval == 0:
            self._add_augmented_images(
                step,
                "validation_augmentations",
                x_u_w,
                x_u_s1,
                x_u_s2,
                pseudo_labels,
                pred_s1,
                pred_s2,
            )

Logger for self-training via torch_em.self_training.UniMatchv2Trainer.

Arguments:
  • trainer: The instantiated trainer class.
  • save_root: The root directory for saving the checkpoints and logs.
UniMatchv2TensorboardLogger(trainer, save_root, **unused_kwargs)
def __init__(self, trainer, save_root, **unused_kwargs):
    """Prepare the tensorboard writer used by this logger.

    The log directory is `./logs/<trainer.name>` if `save_root` is None, else
    `<save_root>/logs/<trainer.name>`; it is created when it does not exist yet.
    """
    super().__init__(trainer, save_root)
    self.my_root = save_root
    run_name = trainer.name
    self.log_dir = (
        os.path.join(save_root, "logs", run_name)
        if save_root is not None
        else f"./logs/{run_name}"
    )
    os.makedirs(self.log_dir, exist_ok=True)
    self.tb = SummaryWriter(self.log_dir)
    self.log_image_interval = trainer.log_image_interval
my_root
log_dir
tb
log_image_interval
def log_ct(self, step, ct):
    """Write the confidence threshold `ct` to tensorboard at `step`."""
    tag_name = "train/confidence_threshold"
    self.tb.add_scalar(tag_name, ct, global_step=step)
def log_train_augmentations(
    self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
):
    """Log weak / strong training augmentations, but only every `log_image_interval` steps."""
    if step % self.log_image_interval:
        return
    self._add_augmented_images(
        step, "train_augmentations",
        x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2,
    )
def log_validation_augmentations(
    self, step, x_u_w, x_u_s1, x_u_s2, pseudo_labels, pred_s1, pred_s2
):
    """Log weak / strong validation augmentations whenever `step` hits the image interval."""
    on_interval = step % self.log_image_interval == 0
    if on_interval:
        self._add_augmented_images(
            step,
            "validation_augmentations",
            x_u_w=x_u_w,
            x_u_s1=x_u_s1,
            x_u_s2=x_u_s2,
            pseudo_labels=pseudo_labels,
            pred_s1=pred_s1,
            pred_s2=pred_s2,
        )