torch_em.util.test
import os
import imageio
import h5py
import numpy as np
import torch

from scipy.ndimage import distance_transform_edt
from skimage.measure import label
from skimage.segmentation import watershed


def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation and return it as a torch tensor.

    The segmentation comes from a seeded watershed. The seeds are the connected
    components of pixels whose uniform random value exceeds 0.99. The height map
    is the euclidean distance of every pixel to the nearest seed pixel.

    Args:
        spatial_shape: Shape of a single segmentation image / volume.
        n_batches: If given, stack this many independently sampled segmentations
            along a new leading batch axis.
        with_channels: If True, insert a singleton channel axis in front of the
            spatial axes (after the batch axis, if there is one).
        with_background: If True, restrict the watershed to a random mask that
            covers roughly half of the pixels. Pixels outside the mask stay 0.
        dtype: Optional numpy dtype the segmentation is cast to before conversion.

    Returns:
        A torch.Tensor whose shape depends on ``n_batches`` and ``with_channels``:
        ``spatial_shape``, ``(1, *spatial_shape)``, ``(n_batches, *spatial_shape)``
        or ``(n_batches, 1, *spatial_shape)``.
    """
    def _make_gt():
        # Sparse random seeds (about 1% of pixels); one label per connected component.
        seeds = label(np.random.rand(*spatial_shape) > 0.99)
        # Height map: distance of each pixel to the closest seed pixel.
        hmap = distance_transform_edt(seeds == 0)
        mask = np.random.rand(*spatial_shape) > 0.5 if with_background else None
        return watershed(hmap, markers=seeds, mask=mask)

    def _sample():
        # Call the generator first, then (optionally) add the channel axis.
        seg = _make_gt()
        return seg[None] if with_channels else seg

    if n_batches is None:
        seg = _sample()
    else:
        seg = np.stack([_sample() for _ in range(n_batches)], axis=0)

    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)


def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
    """Write random raw data and random labels as two datasets into an HDF5 file.

    Args:
        data_path: Path of the HDF5 file. It is opened in append mode ("a").
        raw_key: Dataset name for the raw data (uniform random floats in [0, 1)).
        label_key: Dataset name for the labels (random integers in [0, 4)).
        shape: Shape of both datasets.
        chunks: Chunk shape passed to ``create_dataset`` for both datasets.
    """
    with h5py.File(data_path, "a") as f:
        f.create_dataset(raw_key, data=np.random.rand(*shape), chunks=chunks)
        f.create_dataset(label_key, data=np.random.randint(0, 4, size=shape), chunks=chunks)


def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write ``n_images`` random raw / label tif pairs below ``folder``.

    Raw images go to ``folder/images`` and labels to ``folder/labels``, both
    named ``im_{i}.tif``. Each pair gets its own random shape. Along axis ``d``
    the size is drawn from ``[min_shape[d], max_shape[d])``.

    Args:
        folder: Root folder. The ``images`` and ``labels`` subfolders are created
            if they do not exist.
        n_images: Number of image pairs to write.
        min_shape: Per-axis inclusive lower bound of the image shape.
        max_shape: Per-axis exclusive upper bound of the image shape.
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    for i in range(n_images):
        shape = tuple(np.random.randint(lo, hi) for lo, hi in zip(min_shape, max_shape))
        # Draw integer intensities directly. Casting uniform floats in [0, 1)
        # to int16 would truncate every pixel to 0.
        raw = np.random.randint(0, 256, size=shape, dtype="int16")
        labels = np.random.randint(0, 4, size=shape)
        imageio.imwrite(os.path.join(im_folder, f"im_{i}.tif"), raw)
        imageio.imwrite(os.path.join(label_folder, f"im_{i}.tif"), labels)
def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation and return it as a torch tensor.

    The segmentation comes from a seeded watershed. The seeds are the connected
    components of pixels whose uniform random value exceeds 0.99. The height map
    is the euclidean distance of every pixel to the nearest seed pixel.

    Args:
        spatial_shape: Shape of a single segmentation image / volume.
        n_batches: If given, stack this many independently sampled segmentations
            along a new leading batch axis.
        with_channels: If True, insert a singleton channel axis in front of the
            spatial axes (after the batch axis, if there is one).
        with_background: If True, restrict the watershed to a random mask that
            covers roughly half of the pixels. Pixels outside the mask stay 0.
        dtype: Optional numpy dtype the segmentation is cast to before conversion.

    Returns:
        A torch.Tensor whose shape depends on ``n_batches`` and ``with_channels``:
        ``spatial_shape``, ``(1, *spatial_shape)``, ``(n_batches, *spatial_shape)``
        or ``(n_batches, 1, *spatial_shape)``.
    """
    def _make_gt():
        # Sparse random seeds (about 1% of pixels); one label per connected component.
        seeds = label(np.random.rand(*spatial_shape) > 0.99)
        # Height map: distance of each pixel to the closest seed pixel.
        hmap = distance_transform_edt(seeds == 0)
        mask = np.random.rand(*spatial_shape) > 0.5 if with_background else None
        return watershed(hmap, markers=seeds, mask=mask)

    def _sample():
        # Call the generator first, then (optionally) add the channel axis.
        seg = _make_gt()
        return seg[None] if with_channels else seg

    if n_batches is None:
        seg = _sample()
    else:
        seg = np.stack([_sample() for _ in range(n_batches)], axis=0)

    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)
def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write ``n_images`` random raw / label tif pairs below ``folder``.

    Raw images go to ``folder/images`` and labels to ``folder/labels``, both
    named ``im_{i}.tif``. Each pair gets its own random shape. Along axis ``d``
    the size is drawn from ``[min_shape[d], max_shape[d])``.

    Args:
        folder: Root folder. The ``images`` and ``labels`` subfolders are created
            if they do not exist.
        n_images: Number of image pairs to write.
        min_shape: Per-axis inclusive lower bound of the image shape.
        max_shape: Per-axis exclusive upper bound of the image shape.
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    for i in range(n_images):
        shape = tuple(np.random.randint(lo, hi) for lo, hi in zip(min_shape, max_shape))
        # Draw integer intensities directly. Casting uniform floats in [0, 1)
        # to int16 would truncate every pixel to 0.
        raw = np.random.randint(0, 256, size=shape, dtype="int16")
        # Named `labels` to avoid shadowing the module-level `label` function.
        labels = np.random.randint(0, 4, size=shape)
        imageio.imwrite(os.path.join(im_folder, f"im_{i}.tif"), raw)
        imageio.imwrite(os.path.join(label_folder, f"im_{i}.tif"), labels)