torch_em.util.debug

  1import os
  2import warnings
  3from typing import Union, Optional
  4
  5import torch
  6import torch.utils.data
  7
  8from .util import ensure_array, get_random_colors
  9
 10
 11def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None):
 12    import matplotlib.pyplot as plt
 13    img_size = 5
 14
 15    fig = None
 16    n_rows = None
 17
 18    def to_index(ns, rid, sid):
 19        index = 1 + rid * ns + sid
 20        return index
 21
 22    for ii, (x, y) in enumerate(loader):
 23        if ii >= n_samples:
 24            break
 25
 26        if model is None:
 27            pred = None
 28        else:
 29            pred = model(x if device is None else x.to(device))
 30            pred = ensure_array(pred)[0]
 31
 32        # cast the data to array and remove the batch axis / choose first sample in batch
 33        x = ensure_array(x)[0]
 34        y = ensure_array(y)[0]
 35        assert x.ndim == y.ndim
 36        if x.ndim == 4:  # 3d data (with channel axis)
 37            z_slice = x.shape[1] // 2
 38            warnings.warn(f"3d input data is not yet supported, will only show slice {z_slice} / {x.shape[1]}")
 39            x, y = x[:, z_slice], y[:, z_slice]
 40            if pred is not None:
 41                pred = pred[:, z_slice]
 42
 43        if x.shape[0] > 1:
 44            warnings.warn(f"Multi-channel input data is not yet supported, will only show channel 0 / {x.shape[0]}")
 45        x = x[0]
 46
 47        if pred is None:
 48            n_target_channels = y.shape[0]
 49        else:
 50            n_target_channels = pred.shape[0]
 51            y = y[:n_target_channels]
 52            assert y.shape[0] == n_target_channels
 53
 54        if fig is None:
 55            n_rows = n_target_channels + 1 if pred is None else 2 * n_target_channels + 1
 56            fig = plt.figure(figsize=(n_samples*img_size, n_rows*img_size))
 57
 58        ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 0, ii))
 59        ax.imshow(x, interpolation="nearest", cmap="Greys_r", aspect="auto")
 60
 61        for chan in range(n_target_channels):
 62            ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + chan, ii))
 63            if instance_labels:
 64                label_data = y[chan].astype("uint32")
 65                ax.imshow(label_data, interpolation="nearest", aspect="auto", cmap=get_random_colors(label_data))
 66            else:
 67                ax.imshow(y[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")
 68
 69        if pred is not None:
 70            for chan in range(n_target_channels):
 71                ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii))
 72                ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")
 73
 74    if save_path is None:
 75        plt.show()
 76    else:
 77        plt.savefig(save_path)
 78        plt.close()
 79
 80
 81def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
 82    import napari
 83
 84    for ii, sample in enumerate(loader):
 85        if ii >= n_samples:
 86            break
 87
 88        try:
 89            x, y = sample
 90        except ValueError:
 91            x = sample
 92            y = None
 93
 94        if model is None:
 95            pred = None
 96        else:
 97            pred = model(x if device is None else x.to(device))
 98            pred = ensure_array(pred)[0]
 99
100        x = ensure_array(x)[0]
101        if rgb:
102            assert x.shape[0] == 3
103            x = x.transpose((1, 2, 0))
104
105        v = napari.Viewer()
106        v.add_image(x)
107        if y is not None:
108            y = ensure_array(y)[0]
109            if instance_labels:
110                v.add_labels(y.astype("uint32"))
111            else:
112                v.add_image(y)
113        if pred is not None:
114            v.add_image(pred)
115
116        napari.run()
117
118
def check_trainer(
    trainer,
    n_samples: int,
    instance_labels: bool = False,
    split: str = "val",
    loader: Optional[torch.utils.data.DataLoader] = None,
    plt: bool = False,
):
    """Check a trainer visually.

    This function shows images and labels from the training or validation loader
    and predictions from the trainer's model for the images.
    The data will be plotted either with napari or matplotlib.
    The model is put into eval mode for the predictions. Its previous train / eval mode
    is restored afterwards, so this function can also be called during training.

    Args:
        trainer: The trainer.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        split: Which split to use. This will determine which data loader is used for plotting.
            Can be one of "val" or "train".
        loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
        plt: Whether to plot the data with matplotlib instead of napari.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.val_loader if split == "val" else trainer.train_loader

    model = trainer.model
    # Remember the current mode so the model is not left in eval mode afterwards,
    # e.g. when this is called in the middle of a training run.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
    finally:
        model.train(was_training)
152
153
def check_loader(
    loader: torch.utils.data.DataLoader,
    n_samples: int,
    instance_labels: bool = False,
    plt: bool = False,
    rgb: bool = False,
    save_path: Optional[Union[str, os.PathLike]] = None,
):
    """Visually inspect the samples produced by a data loader.

    Images and labels are displayed either in napari (the default) or with matplotlib.

    Args:
        loader: The data loader to draw the samples from.
        n_samples: How many samples to display.
        instance_labels: Whether the labels should be rendered as instance segmentations.
        plt: Use matplotlib for plotting instead of napari.
        rgb: Whether the images are rgb. Only forwarded to the napari display.
        save_path: Where to save the figure instead of displaying it.
            Only used for matplotlib plotting, i.e. if `plt=True`.
    """
    if not plt:
        # napari display: `save_path` is not used here.
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
        return
    # matplotlib display: `rgb` is not forwarded.
    _check_plt(loader, n_samples, instance_labels, save_path=save_path)
def check_trainer( trainer, n_samples: int, instance_labels: bool = False, split: str = 'val', loader: Optional[torch.utils.data.dataloader.DataLoader] = None, plt: bool = False):
def check_trainer(
    trainer,
    n_samples: int,
    instance_labels: bool = False,
    split: str = "val",
    loader: Optional[torch.utils.data.DataLoader] = None,
    plt: bool = False,
):
    """Check a trainer visually.

    This function shows images and labels from the training or validation loader
    and predictions from the trainer's model for the images.
    The data will be plotted either with napari or matplotlib.
    The model is put into eval mode for the predictions. Its previous train / eval mode
    is restored afterwards, so this function can also be called during training.

    Args:
        trainer: The trainer.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        split: Which split to use. This will determine which data loader is used for plotting.
            Can be one of "val" or "train".
        loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
        plt: Whether to plot the data with matplotlib instead of napari.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.val_loader if split == "val" else trainer.train_loader

    model = trainer.model
    # Remember the current mode so the model is not left in eval mode afterwards,
    # e.g. when this is called in the middle of a training run.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
    finally:
        model.train(was_training)

Check a trainer visually.

This function shows images and labels from the training or validation loader and predictions from the trainer's model for the images. The data will be plotted either with napari or matplotlib.

Arguments:
  • trainer: The trainer.
  • n_samples: The number of samples to plot.
  • instance_labels: Whether to visualize the label data as instances.
  • split: Which split to use. This will determine which data loader is used for plotting. Can be one of "val" or "train".
  • loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
  • plt: Whether to plot the data with matplotlib instead of napari.
def check_loader( loader: torch.utils.data.dataloader.DataLoader, n_samples: int, instance_labels: bool = False, plt: bool = False, rgb: bool = False, save_path: Union[str, os.PathLike, NoneType] = None):
def check_loader(
    loader: torch.utils.data.DataLoader,
    n_samples: int,
    instance_labels: bool = False,
    plt: bool = False,
    rgb: bool = False,
    save_path: Optional[Union[str, os.PathLike]] = None,
):
    """Visually inspect the samples produced by a data loader.

    Images and labels are displayed either in napari (the default) or with matplotlib.

    Args:
        loader: The data loader to draw the samples from.
        n_samples: How many samples to display.
        instance_labels: Whether the labels should be rendered as instance segmentations.
        plt: Use matplotlib for plotting instead of napari.
        rgb: Whether the images are rgb. Only forwarded to the napari display.
        save_path: Where to save the figure instead of displaying it.
            Only used for matplotlib plotting, i.e. if `plt=True`.
    """
    if not plt:
        # napari display: `save_path` is not used here.
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
        return
    # matplotlib display: `rgb` is not forwarded.
    _check_plt(loader, n_samples, instance_labels, save_path=save_path)

Check a loader visually.

This function shows images and labels from the loader with napari or matplotlib.

Arguments:
  • loader: The data loader.
  • n_samples: The number of samples to plot.
  • instance_labels: Whether to visualize the label data as instances.
  • plt: Whether to plot the data with matplotlib instead of napari.
  • rgb: Whether the image data is rgb.
  • save_path: Path for saving the images instead of showing them. This argument only has an effect if plt=True.