torch_em.util.debug
import os
import warnings
from typing import Union, Optional

import torch
import torch.utils.data

from .util import ensure_array


def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None):
    """Plot inputs, targets and (optionally) predictions from a loader with matplotlib.

    Each loaded batch fills one column of the figure. The rows hold the input image,
    the target channels and, if a model is given, the predicted channels.
    Only the first sample of each batch is shown. For 3d data only the central z-slice
    is shown, and for multi-channel inputs only channel 0.

    Args:
        loader: Iterable yielding `(x, y)` batches.
        n_samples: The maximal number of batches to plot.
        instance_labels: Whether to show the targets as instance label images.
        model: Optional model that is applied to the inputs to get predictions.
        device: Optional device the inputs are moved to before calling the model.
        save_path: Optional path for saving the figure instead of showing it.
    """
    import matplotlib.pyplot as plt
    img_size = 5

    fig = None
    n_rows = None

    def to_index(ns, rid, sid):
        # 1-based subplot index for row `rid` and column `sid` in a grid with `ns` columns.
        return 1 + rid * ns + sid

    for ii, (x, y) in enumerate(loader):
        if ii >= n_samples:
            break

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        # cast the data to array and remove the batch axis / choose first sample in batch
        x = ensure_array(x)[0]
        y = ensure_array(y)[0]
        assert x.ndim == y.ndim
        if x.ndim == 4:  # 3d data (with channel axis)
            z_slice = x.shape[1] // 2
            warnings.warn(f"3d input data is not yet supported, will only show slice {z_slice} / {x.shape[1]}")
            x, y = x[:, z_slice], y[:, z_slice]
            if pred is not None:
                pred = pred[:, z_slice]

        if x.shape[0] > 1:
            warnings.warn(f"Multi-channel input data is not yet supported, will only show channel 0 / {x.shape[0]}")
        x = x[0]

        if pred is None:
            n_target_channels = y.shape[0]
        else:
            # Only show as many target channels as the model predicts.
            n_target_channels = pred.shape[0]
            y = y[:n_target_channels]
            assert y.shape[0] == n_target_channels

        if fig is None:
            # The figure layout is determined by the first sample.
            n_rows = n_target_channels + 1 if pred is None else 2 * n_target_channels + 1
            fig = plt.figure(figsize=(n_samples * img_size, n_rows * img_size))

        ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 0, ii))
        ax.imshow(x, interpolation="nearest", cmap="Greys_r", aspect="auto")

        for chan in range(n_target_channels):
            ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + chan, ii))
            if instance_labels:
                ax.imshow(y[chan].astype("uint32"), interpolation="nearest", aspect="auto")
            else:
                ax.imshow(y[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

        if pred is not None:
            for chan in range(n_target_channels):
                ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii))
                ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

    if fig is None:
        # Nothing was loaded: avoid showing / saving an empty figure.
        warnings.warn("The loader did not yield any samples, nothing to plot.")
        return

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
        plt.close()


def _split_sample(sample):
    """Split a loader sample into `(x, y)`, with `y = None` if the loader only yields inputs."""
    # Only unpack real containers. Unpacking a batch tensor directly (`x, y = sample`)
    # succeeds for a batch size of 2 and would split the batch into a bogus (input, target) pair.
    if isinstance(sample, (tuple, list)):
        if len(sample) == 2:
            return sample[0], sample[1]
        if len(sample) == 1:
            return sample[0], None
    return sample, None


def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
    """Show inputs, targets (if available) and predictions (if a model is given) in napari.

    One viewer is opened per loaded batch, showing the first sample of the batch.

    Args:
        loader: Iterable yielding either `(x, y)` batches or input batches only.
        n_samples: The maximal number of batches to show.
        instance_labels: Whether to show the targets as a napari labels layer.
        model: Optional model that is applied to the inputs to get predictions.
        device: Optional device the inputs are moved to before calling the model.
        rgb: Whether the inputs have 3 channels that should be shown as an RGB image.
    """
    import napari

    for ii, sample in enumerate(loader):
        if ii >= n_samples:
            break

        x, y = _split_sample(sample)

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        x = ensure_array(x)[0]
        if rgb:
            assert x.shape[0] == 3
            # Move the channel axis last, as expected for RGB display.
            x = x.transpose((1, 2, 0))

        v = napari.Viewer()
        v.add_image(x)
        if y is not None:
            y = ensure_array(y)[0]
            if instance_labels:
                v.add_labels(y.astype("uint32"))
            else:
                v.add_image(y)
        if pred is not None:
            v.add_image(pred)

    napari.run()


def check_trainer(
    trainer,
    n_samples: int,
    instance_labels: bool = False,
    split: str = "val",
    loader: Optional[torch.utils.data.DataLoader] = None,
    plt: bool = False,
):
    """Check a trainer visually.

    This function shows images and labels from the training or validation loader
    and predictions from the trainer's model for the images.
    The data will be plotted either with napari or matplotlib.
    The model is put into eval mode for the check; its previous train / eval mode is restored afterwards.

    Args:
        trainer: The trainer.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        split: Which split to use. This will determine which data loader is used for plotting.
            Can be one of "val" or "train".
        loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
        plt: Whether to plot the data with matplotlib instead of napari.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.val_loader if split == "val" else trainer.train_loader

    model = trainer.model
    # Remember the mode so that checking does not leave a model that is being trained in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
    finally:
        model.train(was_training)


def check_loader(
    loader: torch.utils.data.DataLoader,
    n_samples: int,
    instance_labels: bool = False,
    plt: bool = False,
    rgb: bool = False,
    save_path: Optional[Union[str, os.PathLike]] = None,
):
    """Check a loader visually.

    This function shows images and labels from the loader with napari or matplotlib.

    Args:
        loader: The data loader.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        plt: Whether to plot the data with matplotlib instead of napari.
        rgb: Whether the image data is rgb.
            This argument only has an effect if `plt=False`.
        save_path: Path for saving the images instead of showing them.
            This argument only has an effect if `plt=True`.
    """
    if plt:
        _check_plt(loader, n_samples, instance_labels, save_path=save_path)
    else:
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
def
check_trainer( trainer, n_samples: int, instance_labels: bool = False, split: str = 'val', loader: Optional[torch.utils.data.dataloader.DataLoader] = None, plt: bool = False):
def check_trainer(
    trainer,
    n_samples: int,
    instance_labels: bool = False,
    split: str = "val",
    loader: Optional[torch.utils.data.DataLoader] = None,
    plt: bool = False,
):
    """Check a trainer visually.

    This function shows images and labels from the training or validation loader
    and predictions from the trainer's model for the images.
    The data will be plotted either with napari or matplotlib.
    The model is put into eval mode for the check; its previous train / eval mode is restored afterwards.

    Args:
        trainer: The trainer.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        split: Which split to use. This will determine which data loader is used for plotting.
            Can be one of "val" or "train".
        loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
        plt: Whether to plot the data with matplotlib instead of napari.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.val_loader if split == "val" else trainer.train_loader

    model = trainer.model
    # Remember the mode so that checking does not leave a model that is being trained in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
    finally:
        model.train(was_training)
Check a trainer visually.
This function shows images and labels from the training or validation loader and predictions from the trainer's model for the images. The data will be plotted either with napari or matplotlib.
Arguments:
- trainer: The trainer.
- n_samples: The number of samples to plot.
- instance_labels: Whether to visualize the label data as instances.
- split: Which split to use. This will determine which data loader is used for plotting. Can be one of "val" or "train".
- loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
- plt: Whether to plot the data with matplotlib instead of napari.
def
check_loader( loader: torch.utils.data.dataloader.DataLoader, n_samples: int, instance_labels: bool = False, plt: bool = False, rgb: bool = False, save_path: Union[str, os.PathLike, NoneType] = None):
def check_loader(
    loader: torch.utils.data.DataLoader,
    n_samples: int,
    instance_labels: bool = False,
    plt: bool = False,
    rgb: bool = False,
    save_path: Optional[Union[str, os.PathLike]] = None,
):
    """Visually inspect the samples produced by a data loader.

    The images and labels are displayed in napari by default, or with matplotlib if requested.

    Args:
        loader: The data loader to inspect.
        n_samples: How many samples to display.
        instance_labels: Whether the labels should be rendered as instance segmentations.
        plt: Whether to use matplotlib instead of napari for displaying the data.
        rgb: Whether the images are RGB. Only used by the napari display, i.e. if `plt=False`.
        save_path: Where to save the plot instead of showing it.
            Only used by the matplotlib display, i.e. if `plt=True`.
    """
    if not plt:
        # napari display: handles rgb data, does not support saving.
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
        return
    # matplotlib display: supports saving, does not handle rgb data.
    _check_plt(loader, n_samples, instance_labels, save_path=save_path)
Check a loader visually.
This function shows images and labels from the loader with napari or matplotlib.
Arguments:
- loader: The data loader.
- n_samples: The number of samples to plot.
- instance_labels: Whether to visualize the label data as instances.
- plt: Whether to plot the data with matplotlib instead of napari.
- rgb: Whether the image data is RGB. This argument only has an effect if `plt=False`.
- save_path: Path for saving the images instead of showing them. This argument only has an effect if `plt=True`.