# torch_em.util.debug

  1import warnings
  2import torch
  3from .util import ensure_array
  4
  5
  6def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None):
  7    import matplotlib.pyplot as plt
  8    img_size = 5
  9
 10    fig = None
 11    n_rows = None
 12
 13    def to_index(ns, rid, sid):
 14        index = 1 + rid * ns + sid
 15        return index
 16
 17    for ii, (x, y) in enumerate(loader):
 18        if ii >= n_samples:
 19            break
 20
 21        if model is None:
 22            pred = None
 23        else:
 24            pred = model(x if device is None else x.to(device))
 25            pred = ensure_array(pred)[0]
 26
 27        # cast the data to array and remove the batch axis / choose first sample in batch
 28        x = ensure_array(x)[0]
 29        y = ensure_array(y)[0]
 30        assert x.ndim == y.ndim
 31        if x.ndim == 4:  # 3d data (with channel axis)
 32            z_slice = x.shape[1] // 2
 33            warnings.warn(f"3d input data is not yet supported, will only show slice {z_slice} / {x.shape[1]}")
 34            x, y = x[:, z_slice], y[:, z_slice]
 35            if pred is not None:
 36                pred = pred[:, z_slice]
 37
 38        if x.shape[0] > 1:
 39            warnings.warn(f"Multi-channel input data is not yet supported, will only show channel 0 / {x.shape[0]}")
 40        x = x[0]
 41
 42        if pred is None:
 43            n_target_channels = y.shape[0]
 44        else:
 45            n_target_channels = pred.shape[0]
 46            y = y[:n_target_channels]
 47            assert y.shape[0] == n_target_channels
 48
 49        if fig is None:
 50            n_rows = n_target_channels + 1 if pred is None else 2 * n_target_channels + 1
 51            fig = plt.figure(figsize=(n_samples*img_size, n_rows*img_size))
 52
 53        ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 0, ii))
 54        ax.imshow(x, interpolation="nearest", cmap="Greys_r", aspect="auto")
 55
 56        for chan in range(n_target_channels):
 57            ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + chan, ii))
 58            if instance_labels:
 59                ax.imshow(y[chan].astype("uint32"), interpolation="nearest", aspect="auto")
 60            else:
 61                ax.imshow(y[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")
 62
 63        if pred is not None:
 64            for chan in range(n_target_channels):
 65                ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii))
 66                ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")
 67
 68    if save_path is None:
 69        plt.show()
 70    else:
 71        plt.savefig(save_path)
 72        plt.close()
 73
 74
 75def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
 76    import napari
 77
 78    for ii, sample in enumerate(loader):
 79        if ii >= n_samples:
 80            break
 81
 82        try:
 83            x, y = sample
 84        except ValueError:
 85            x = sample
 86            y = None
 87
 88        if model is None:
 89            pred = None
 90        else:
 91            pred = model(x if device is None else x.to(device))
 92            pred = ensure_array(pred)[0]
 93
 94        x = ensure_array(x)[0]
 95        if rgb:
 96            assert x.shape[0] == 3
 97            x = x.transpose((1, 2, 0))
 98
 99        v = napari.Viewer()
100        v.add_image(x)
101        if y is not None:
102            y = ensure_array(y)[0]
103            if instance_labels:
104                v.add_labels(y.astype("uint32"))
105            else:
106                v.add_image(y)
107        if pred is not None:
108            v.add_image(pred)
109        napari.run()
110
111
def check_trainer(trainer, n_samples, instance_labels=False, split="val", loader=None, plt=False):
    """Visualize samples and model predictions for a trainer.

    Args:
        trainer: Trainer object providing ``model``, ``device`` and the data loaders.
        n_samples: Number of samples to visualize.
        instance_labels: Whether to render the targets as instance labels.
        split: Which loader to use when ``loader`` is not given ("val" or "train").
        loader: Optional loader to use instead of the trainer's loaders.
        plt: If True use matplotlib for visualization, otherwise napari.

    Raises:
        ValueError: If ``loader`` is None and ``split`` is neither "val" nor "train".
    """
    if loader is None:
        if split not in ("val", "train"):
            raise ValueError(f"Invalid split: {split!r}, expected 'val' or 'train'.")
        # BUGFIX: the validation loader was always used regardless of 'split';
        # now the requested split is respected
        loader = trainer.val_loader if split == "val" else trainer.train_loader
    with torch.no_grad():
        model = trainer.model
        model.eval()
        if plt:
            _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
        else:
            _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
123
124
def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False, save_path=None):
    """Visualize samples from a data loader, with matplotlib (``plt=True``) or napari."""
    if not plt:
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
    else:
        _check_plt(loader, n_samples, instance_labels, save_path=save_path)
def check_trainer( trainer, n_samples, instance_labels=False, split='val', loader=None, plt=False):
113def check_trainer(trainer, n_samples, instance_labels=False, split="val", loader=None, plt=False):
114    if loader is None:
115        assert split in ("val", "train")
116        loader = trainer.val_loader
117    with torch.no_grad():
118        model = trainer.model
119        model.eval()
120        if plt:
121            _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
122        else:
123            _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
def check_loader( loader, n_samples, instance_labels=False, plt=False, rgb=False, save_path=None):
126def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False, save_path=None):
127    if plt:
128        _check_plt(loader, n_samples, instance_labels, save_path=save_path)
129    else:
130        _check_napari(loader, n_samples, instance_labels, rgb=rgb)