torch_em.util.debug
import os
import warnings
from typing import Union, Optional

import torch
import torch.utils.data

from .util import ensure_array, get_random_colors


def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None):
    """Plot images, labels and (optionally) model predictions with matplotlib.

    Each of the first `n_samples` batches becomes one column of the figure. Only the first
    sample of each batch is shown:
    - the image goes in the first row;
    - each target channel gets one row after that;
    - each prediction channel gets one further row if a model is given.

    Args:
        loader: Iterable yielding `(x, y)` batches.
        n_samples: The number of batches (figure columns) to plot.
        instance_labels: Whether to show the labels with a colormap from `get_random_colors`.
        model: Optional model used to predict on `x`.
        device: Optional device that `x` is moved to before calling the model.
        save_path: If given, the figure is saved to this path and closed instead of shown.
    """
    import matplotlib.pyplot as plt
    img_size = 5

    fig = None
    n_rows = None

    def to_index(ns, rid, sid):
        # 1-based subplot index for row `rid` and column `sid` in a grid with `ns` columns.
        index = 1 + rid * ns + sid
        return index

    for ii, (x, y) in enumerate(loader):
        if ii >= n_samples:
            break

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        # cast the data to array and remove the batch axis / choose first sample in batch
        x = ensure_array(x)[0]
        y = ensure_array(y)[0]
        assert x.ndim == y.ndim
        if x.ndim == 4:  # 3d data (with channel axis)
            z_slice = x.shape[1] // 2
            warnings.warn(f"3d input data is not yet supported, will only show slice {z_slice} / {x.shape[1]}")
            x, y = x[:, z_slice], y[:, z_slice]
            if pred is not None:
                pred = pred[:, z_slice]

        if x.shape[0] > 1:
            warnings.warn(f"Multi-channel input data is not yet supported, will only show channel 0 / {x.shape[0]}")
        # Always drop the channel axis; imshow cannot display a (1, H, W) array.
        x = x[0]

        if pred is None:
            n_target_channels = y.shape[0]
        else:
            # Only show as many target channels as the model predicts.
            n_target_channels = pred.shape[0]
            y = y[:n_target_channels]
            assert y.shape[0] == n_target_channels

        if fig is None:
            # The layout is fixed by the first plotted sample.
            n_rows = n_target_channels + 1 if pred is None else 2 * n_target_channels + 1
            fig = plt.figure(figsize=(n_samples * img_size, n_rows * img_size))

        ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 0, ii))
        ax.imshow(x, interpolation="nearest", cmap="Greys_r", aspect="auto")

        for chan in range(n_target_channels):
            ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + chan, ii))
            if instance_labels:
                label_data = y[chan].astype("uint32")
                ax.imshow(label_data, interpolation="nearest", aspect="auto", cmap=get_random_colors(label_data))
            else:
                ax.imshow(y[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

        if pred is not None:
            for chan in range(n_target_channels):
                ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii))
                ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
        plt.close()


def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
    """Show images, labels and (optionally) model predictions in napari.

    For each of the first `n_samples` batches, the first sample of the batch is opened in a
    new napari viewer.

    Args:
        loader: Iterable yielding either `(x, y)` batches or image batches only.
        n_samples: The number of batches to show.
        instance_labels: Whether to add the labels as a napari labels layer.
        model: Optional model used to predict on `x`.
        device: Optional device that `x` is moved to before calling the model.
        rgb: Whether the image has 3 channels that should be shown as rgb.
    """
    import napari

    for ii, sample in enumerate(loader):
        if ii >= n_samples:
            break

        # A bare tensor is an image-only batch. It must not be unpacked: for a batch of
        # exactly two samples, `x, y = sample` would succeed and split the batch into x and y.
        if torch.is_tensor(sample):
            x, y = sample, None
        else:
            try:
                x, y = sample
            except ValueError:
                x, y = sample, None

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        x = ensure_array(x)[0]
        if rgb:
            assert x.shape[0] == 3
            # Move the channel axis last, as required for rgb display.
            x = x.transpose((1, 2, 0))

        v = napari.Viewer()
        v.add_image(x)
        if y is not None:
            y = ensure_array(y)[0]
            if instance_labels:
                v.add_labels(y.astype("uint32"))
            else:
                v.add_image(y)
        if pred is not None:
            v.add_image(pred)

        napari.run()


def check_trainer(
    trainer,
    n_samples: int,
    instance_labels: bool = False,
    split: str = "val",
    loader: Optional[torch.utils.data.DataLoader] = None,
    plt: bool = False,
):
    """Check a trainer visually.

    This function shows images and labels from the training or validation loader
    and predictions from the trainer's model for the images.
    The data will be plotted either with napari or matplotlib.
    The model is put into eval mode while predicting. Its previous training mode is
    restored afterwards, also if plotting fails.

    Args:
        trainer: The trainer.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        split: Which split to use. This will determine which data loader is used for plotting.
            Can be one of "val" or "train".
        loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
        plt: Whether to plot the data with matplotlib instead of napari.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.val_loader if split == "val" else trainer.train_loader

    model = trainer.model
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
    finally:
        # Do not leave the trainer's model in eval mode after a debugging call.
        model.train(was_training)


def check_loader(
    loader: torch.utils.data.DataLoader,
    n_samples: int,
    instance_labels: bool = False,
    plt: bool = False,
    rgb: bool = False,
    save_path: Optional[Union[str, os.PathLike]] = None,
):
    """Check a loader visually.

    This function shows images and labels from the loader with napari or matplotlib.

    Args:
        loader: The data loader.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        plt: Whether to plot the data with matplotlib instead of napari.
        rgb: Whether the image data is rgb. This argument only has an effect if `plt=False`.
        save_path: Path for saving the images instead of showing them.
            This argument only has an effect if `plt=True`.
    """
    if plt:
        _check_plt(loader, n_samples, instance_labels, save_path=save_path)
    else:
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
def
check_trainer( trainer, n_samples: int, instance_labels: bool = False, split: str = 'val', loader: Optional[torch.utils.data.dataloader.DataLoader] = None, plt: bool = False):
def check_trainer(
    trainer,
    n_samples: int,
    instance_labels: bool = False,
    split: str = "val",
    loader: Optional[torch.utils.data.DataLoader] = None,
    plt: bool = False,
):
    """Check a trainer visually.

    This function shows images and labels from the training or validation loader
    and predictions from the trainer's model for the images.
    The data will be plotted either with napari or matplotlib.
    The model is put into eval mode while predicting. Its previous training mode is
    restored afterwards, also if plotting fails.

    Args:
        trainer: The trainer.
        n_samples: The number of samples to plot.
        instance_labels: Whether to visualize the label data as instances.
        split: Which split to use. This will determine which data loader is used for plotting.
            Can be one of "val" or "train".
        loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
        plt: Whether to plot the data with matplotlib instead of napari.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.val_loader if split == "val" else trainer.train_loader

    model = trainer.model
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(loader, n_samples, instance_labels, model=model, device=trainer.device)
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)
    finally:
        # Do not leave the trainer's model in eval mode after a debugging call.
        model.train(was_training)
Check a trainer visually.
This function shows images and labels from the training or validation loader and predictions from the trainer's model for the images. The data will be plotted either with napari or matplotlib.
Arguments:
- trainer: The trainer.
- n_samples: The number of samples to plot.
- instance_labels: Whether to visualize the label data as instances.
- split: Which split to use. This will determine which data loader is used for plotting. Can be one of "val" or "train".
- loader: An optional loader to use for getting the data. Will be used instead of the loader from the trainer.
- plt: Whether to plot the data with matplotlib instead of napari.
def
check_loader( loader: torch.utils.data.dataloader.DataLoader, n_samples: int, instance_labels: bool = False, plt: bool = False, rgb: bool = False, save_path: Union[str, os.PathLike, NoneType] = None):
def check_loader(
    loader: torch.utils.data.DataLoader,
    n_samples: int,
    instance_labels: bool = False,
    plt: bool = False,
    rgb: bool = False,
    save_path: Optional[Union[str, os.PathLike]] = None,
):
    """Visually inspect the data produced by a loader.

    The images and labels of the loader are displayed either in napari (the default)
    or in a matplotlib figure.

    Args:
        loader: The data loader to inspect.
        n_samples: How many samples to display.
        instance_labels: Whether the labels should be displayed as instances.
        plt: Use matplotlib for plotting instead of napari.
        rgb: Whether the image data is rgb. Only forwarded to the napari display.
        save_path: Where to save the matplotlib figure instead of showing it.
            Only forwarded when `plt=True`.
    """
    if not plt:
        # napari display; `save_path` is not used on this path.
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
        return
    # matplotlib display; `rgb` is not used on this path.
    _check_plt(loader, n_samples, instance_labels, save_path=save_path)
Check a loader visually.
This function shows images and labels from the loader with napari or matplotlib.
Arguments:
- loader: The data loader.
- n_samples: The number of samples to plot.
- instance_labels: Whether to visualize the label data as instances.
- plt: Whether to plot the data with matplotlib instead of napari.
- rgb: Whether the image data is rgb.
- save_path: Path for saving the images instead of showing them.
This argument only has an effect if `plt=True`.