torch_em.util.debug

import os
import warnings
from typing import Union, Optional

import torch
import torch.utils.data

from .util import ensure_array


def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None):
    import matplotlib.pyplot as plt
    img_size = 5

    fig = None
    n_rows = None

    def to_index(ns, rid, sid):
        index = 1 + rid * ns + sid
        return index

    for ii, (x, y) in enumerate(loader):
        if ii >= n_samples:
            break

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        # cast the data to array and remove the batch axis / choose first sample in batch
        x = ensure_array(x)[0]
        y = ensure_array(y)[0]
        assert x.ndim == y.ndim
        if x.ndim == 4:  # 3d data (with channel axis)
            z_slice = x.shape[1] // 2
            warnings.warn(f"3d input data is not yet supported, will only show slice {z_slice} / {x.shape[1]}")
            x, y = x[:, z_slice], y[:, z_slice]
            if pred is not None:
                pred = pred[:, z_slice]

        if x.shape[0] > 1:
            warnings.warn(f"Multi-channel input data is not yet supported, will only show channel 0 / {x.shape[0]}")
        x = x[0]

        if pred is None:
            n_target_channels = y.shape[0]
        else:
            n_target_channels = pred.shape[0]
            y = y[:n_target_channels]
            assert y.shape[0] == n_target_channels

        if fig is None:
            n_rows = n_target_channels + 1 if pred is None else 2 * n_target_channels + 1
            fig = plt.figure(figsize=(n_samples*img_size, n_rows*img_size))

        ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 0, ii))
        ax.imshow(x, interpolation="nearest", cmap="Greys_r", aspect="auto")

        for chan in range(n_target_channels):
            ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + chan, ii))
            if instance_labels:
                ax.imshow(y[chan].astype("uint32"), interpolation="nearest", aspect="auto")
            else:
                ax.imshow(y[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

        if pred is not None:
            for chan in range(n_target_channels):
                ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii))
                ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
        plt.close()


def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
    import napari

    for ii, sample in enumerate(loader):
        if ii >= n_samples:
            break

        try:
            x, y = sample
        except ValueError:
            x = sample
            y = None

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        x = ensure_array(x)[0]
        if rgb:
            assert x.shape[0] == 3
            x = x.transpose((1, 2, 0))

        v = napari.Viewer()
        v.add_image(x)
        if y is not None:
            y = ensure_array(y)[0]
            if instance_labels:
                v.add_labels(y.astype("uint32"))
            else:
                v.add_image(y)
        if pred is not None:
            v.add_image(pred)

        napari.run()


def check_trainer(
    trainer,
    n_samples: int,
    instance_labels: bool = False,
    split: str = "val",
    loader: Optional[torch.utils.data.DataLoader] = None,
    plt: bool = False,
):
    """Check a trainer visually.

    This function shows images and labels from the training or validation loader
    and predictions from the trainer's model for the images.
    The data will be plotted either with napari or matplotlib.

    Args:
        trainer: The trainer.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        split: Which split to use. This will determine which data loader is used for plotting.
            Can be one of "val" or "train".
        loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
        plt: Whether to plot the data with matplotlib instead of napari.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.val_loader if split == "val" else trainer.train_loader
    with torch.no_grad():
        model = trainer.model
        model.eval()
        if plt:
            _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
        else:
            _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)


def check_loader(
    loader: torch.utils.data.DataLoader,
    n_samples: int,
    instance_labels: bool = False,
    plt: bool = False,
    rgb: bool = False,
    save_path: Optional[Union[str, os.PathLike]] = None,
):
    """Check a loader visually.

    This function shows images and labels from the loader with napari or matplotlib.

    Args:
        loader: The data loader.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        plt: Whether to plot the data with matplotlib instead of napari.
        rgb: Whether the image data is rgb.
        save_path: Path for saving the images instead of showing them.
            This argument only has an effect if `plt=True`.
    """
    if plt:
        _check_plt(loader, n_samples, instance_labels, save_path=save_path)
    else:
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
def check_trainer(trainer, n_samples: int, instance_labels: bool = False, split: str = "val", loader: Optional[torch.utils.data.DataLoader] = None, plt: bool = False)

Check a trainer visually.

This function shows images and labels from the training or validation loader and predictions from the trainer's model for the images. The data will be plotted either with napari or matplotlib.

Arguments:
  • trainer: The trainer.
  • n_samples: The number of samples to plot.
  • instance_labels: Whether to visualize the label data as instances.
  • split: Which split to use. This will determine which data loader is used for plotting. Can be one of "val" or "train".
  • loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
  • plt: Whether to plot the data with matplotlib instead of napari.
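A minimal usage sketch follows. It assumes that `trainer` is an already constructed torch_em trainer object (for example the return value of torch_em.default_segmentation_trainer); the trainer setup itself is not shown here.

# Usage sketch: inspect two validation samples and the model's predictions with matplotlib.
# `trainer` is assumed to be an existing torch_em trainer and is not created in this snippet.
from torch_em.util.debug import check_trainer

check_trainer(trainer, n_samples=2, instance_labels=True, split="val", plt=True)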
def check_loader(loader: torch.utils.data.DataLoader, n_samples: int, instance_labels: bool = False, plt: bool = False, rgb: bool = False, save_path: Optional[Union[str, os.PathLike]] = None)

Check a loader visually.

This function shows images and labels from the loader with napari or matplotlib.

Arguments:
  • loader: The data loader.
  • n_samples: The number of samples to plot.
  • instance_labels: Whether to visualize the label data as instances.
  • plt: Whether to plot the data with matplotlib instead of napari.
  • rgb: Whether the image data is rgb.
  • save_path: Path for saving the images instead of showing them. This argument only has an effect if plt=True.
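A minimal, self-contained sketch: it builds a synthetic loader from random tensors and saves the matplotlib figure instead of showing it. The tensor shapes and the output file name are illustrative assumptions, not requirements of torch_em.

# Usage sketch with a synthetic loader; shapes and the file name are illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset

from torch_em.util.debug import check_loader

# Four single-channel 64x64 images with matching single-channel labels.
images = torch.rand(4, 1, 64, 64)
labels = (images > 0.5).float()
loader = DataLoader(TensorDataset(images, labels), batch_size=1)

# Plot two samples with matplotlib and write the figure to disk instead of showing it.
check_loader(loader, n_samples=2, plt=True, save_path="loader_check.png")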