torch_em.util.debug
import warnings
import torch
from .util import ensure_array


def _check_plt(loader, n_samples, instance_labels, model=None, device=None, save_path=None):
    """Show the first `n_samples` batches of `loader` in a matplotlib grid.

    Each column shows the first sample of one batch: the input in the first row,
    then one row per target channel and, if `model` is given, one row per
    prediction channel. 3d data (with channel axis) is reduced to its central
    z-slice, and multi-channel inputs are reduced to channel 0.

    Args:
        loader: iterable yielding `(x, y)` batches.
        n_samples: number of batches to visualize.
        instance_labels: if True, targets are cast to uint32 and shown with
            matplotlib's default colormap instead of greyscale.
        model: optional model; its prediction for `x` is shown as well.
        device: optional device `x` is moved to before calling `model`.
        save_path: if given, the figure is saved to this path (and closed)
            instead of being shown.
    """
    import matplotlib.pyplot as plt
    img_size = 5

    fig = None
    n_rows = None

    def to_index(ns, rid, sid):
        # matplotlib subplot indices are 1-based and run row-major.
        index = 1 + rid * ns + sid
        return index

    for ii, (x, y) in enumerate(loader):
        if ii >= n_samples:
            break

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        # cast the data to array and remove the batch axis / choose first sample in batch
        x = ensure_array(x)[0]
        y = ensure_array(y)[0]
        assert x.ndim == y.ndim
        if x.ndim == 4:  # 3d data (with channel axis)
            z_slice = x.shape[1] // 2
            warnings.warn(f"3d input data is not yet supported, will only show slice {z_slice} / {x.shape[1]}")
            x, y = x[:, z_slice], y[:, z_slice]
            if pred is not None:
                pred = pred[:, z_slice]

        if x.shape[0] > 1:
            warnings.warn(f"Multi-channel input data is not yet supported, will only show channel 0 / {x.shape[0]}")
        # drop the channel axis so that imshow receives a 2d image
        x = x[0]

        if pred is None:
            n_target_channels = y.shape[0]
        else:
            # only show as many target channels as the model predicts
            n_target_channels = pred.shape[0]
            y = y[:n_target_channels]
        assert y.shape[0] == n_target_channels

        if fig is None:
            # the layout is fixed by the first sample
            n_rows = n_target_channels + 1 if pred is None else 2 * n_target_channels + 1
            fig = plt.figure(figsize=(n_samples*img_size, n_rows*img_size))

        ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 0, ii))
        ax.imshow(x, interpolation="nearest", cmap="Greys_r", aspect="auto")

        for chan in range(n_target_channels):
            ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + chan, ii))
            if instance_labels:
                ax.imshow(y[chan].astype("uint32"), interpolation="nearest", aspect="auto")
            else:
                ax.imshow(y[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

        if pred is not None:
            for chan in range(n_target_channels):
                ax = fig.add_subplot(n_rows, n_samples, to_index(n_samples, 1 + n_target_channels + chan, ii))
                ax.imshow(pred[chan], interpolation="nearest", cmap="Greys_r", aspect="auto")

    if fig is None:
        # Nothing was plotted; do not show or save whatever pyplot's current figure happens to be.
        warnings.warn("The loader did not yield any samples, nothing to plot.")
        return

    if save_path is None:
        plt.show()
    else:
        fig.savefig(save_path)
        plt.close(fig)


def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
    """Open one napari viewer per batch for the first `n_samples` batches of `loader`.

    For each batch, the first sample is shown: the input, the target (if the loader
    yields `(x, y)` pairs) and the model prediction (if `model` is given).
    `napari.run()` is called for every sample.

    Args:
        loader: iterable yielding either `(x, y)` pairs or bare input batches.
        n_samples: number of batches to visualize.
        instance_labels: if True, the target is added as a uint32 labels layer,
            otherwise as an image layer.
        model: optional model; its prediction for `x` is shown as well.
        device: optional device `x` is moved to before calling `model`.
        rgb: if True, the input must have 3 channels and is shown as an RGB image.
    """
    import napari

    for ii, sample in enumerate(loader):
        if ii >= n_samples:
            break

        # Decide on the container type instead of relying on a failed tuple-unpack:
        # unpacking a bare batch iterates over its batch axis, so a batch of exactly
        # two inputs would otherwise be silently split into "x" and "y".
        if isinstance(sample, (tuple, list)):
            x, y = sample
        else:
            x, y = sample, None

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        x = ensure_array(x)[0]
        if rgb:
            assert x.shape[0] == 3
            # napari expects the color channels last
            x = x.transpose((1, 2, 0))

        v = napari.Viewer()
        v.add_image(x)
        if y is not None:
            y = ensure_array(y)[0]
            if instance_labels:
                v.add_labels(y.astype("uint32"))
            else:
                v.add_image(y)
        if pred is not None:
            v.add_image(pred)
        napari.run()


def check_trainer(
    trainer, n_samples, instance_labels=False, split="val", loader=None, plt=False, rgb=False, save_path=None
):
    """Visualize samples from one of the trainer's loaders together with the model predictions.

    Args:
        trainer: trainer providing `model`, `device`, `train_loader` and `val_loader`.
        n_samples: number of batches to visualize (the first sample of each batch is shown).
        instance_labels: whether the targets are instance labels.
        split: "val" or "train"; selects the trainer loader used when `loader` is None.
        loader: optional loader to use instead of the trainer's loaders.
        plt: if True, visualize with matplotlib, otherwise with napari.
        rgb: (napari only) show the input as an RGB image; requires 3 input channels.
        save_path: (matplotlib only) save the figure to this path instead of showing it.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.train_loader if split == "train" else trainer.val_loader

    model = trainer.model
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(
                    loader, n_samples, instance_labels, model=model, device=trainer.device, save_path=save_path
                )
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device, rgb=rgb)
    finally:
        # Restore the model's previous mode so that training can continue after the check.
        model.train(was_training)


def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False, save_path=None):
    """Visualize the first sample of the first `n_samples` batches of `loader`.

    Args:
        loader: iterable yielding `(x, y)` batches (napari also accepts bare inputs).
        n_samples: number of batches to visualize.
        instance_labels: whether the targets are instance labels.
        plt: if True, visualize with matplotlib, otherwise with napari.
        rgb: (napari only) show the input as an RGB image; requires 3 input channels.
        save_path: (matplotlib only) save the figure to this path instead of showing it.
    """
    if plt:
        _check_plt(loader, n_samples, instance_labels, save_path=save_path)
    else:
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)
def check_trainer(
    trainer, n_samples, instance_labels=False, split="val", loader=None, plt=False, rgb=False, save_path=None
):
    """Visualize samples from one of the trainer's loaders together with the model predictions.

    Args:
        trainer: trainer providing `model`, `device`, `train_loader` and `val_loader`.
        n_samples: number of batches to visualize (the first sample of each batch is shown).
        instance_labels: whether the targets are instance labels.
        split: "val" or "train"; selects the trainer loader used when `loader` is None.
        loader: optional loader to use instead of the trainer's loaders.
        plt: if True, visualize with matplotlib, otherwise with napari.
        rgb: (napari only) show the input as an RGB image; requires 3 input channels.
        save_path: (matplotlib only) save the figure to this path instead of showing it.
    """
    if loader is None:
        assert split in ("val", "train")
        loader = trainer.train_loader if split == "train" else trainer.val_loader

    model = trainer.model
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if plt:
                _check_plt(
                    loader, n_samples, instance_labels, model=model, device=trainer.device, save_path=save_path
                )
            else:
                _check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device, rgb=rgb)
    finally:
        # Restore the model's previous mode so that training can continue after the check.
        model.train(was_training)
def
check_loader( loader, n_samples, instance_labels=False, plt=False, rgb=False, save_path=None):