torch_em.util.test
@private
"""@private
"""
import os
import imageio
import h5py
import numpy as np
import torch

from scipy.ndimage import distance_transform_edt
from skimage.measure import label
from skimage.segmentation import watershed


def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation and return it as a torch tensor.

    Each segmentation is built by picking seed pixels where a uniform random value
    exceeds 0.99, labeling the seeds as connected components, and growing them with
    a watershed on the distance transform to the nearest seed.

    Args:
        spatial_shape: Spatial shape of a single segmentation.
        n_batches: If None, return a single segmentation. Otherwise stack this many
            independent segmentations along a new leading axis; must be >= 1.
        with_channels: If True, add a singleton channel axis in front of the spatial axes.
        with_background: If True, restrict the watershed to a random mask (~50% of pixels),
            so pixels outside the mask stay 0.
        dtype: Optional numpy dtype to cast the segmentation to before conversion.

    Returns:
        A torch tensor with shape `spatial_shape`, `(1, *spatial_shape)`,
        `(n_batches, *spatial_shape)` or `(n_batches, 1, *spatial_shape)`,
        depending on `n_batches` and `with_channels`.

    Raises:
        ValueError: If `n_batches` is given and smaller than 1.
    """
    def _make_gt():
        seeds = np.random.rand(*spatial_shape)
        seeds = label(seeds > 0.99)
        hmap = distance_transform_edt(seeds == 0)
        if with_background:
            mask = np.random.rand(*spatial_shape) > 0.5
            assert mask.shape == hmap.shape
        else:
            mask = None
        return watershed(hmap, markers=seeds, mask=mask)

    if n_batches is None:
        seg = _make_gt()
        if with_channels:
            # Call the helper before indexing; the original subscripted the function object.
            seg = seg[None]
    else:
        if n_batches < 1:
            raise ValueError(f"n_batches must be at least 1, got {n_batches}")
        seg = []
        for _ in range(n_batches):
            batch_seg = _make_gt()
            if with_channels:
                batch_seg = batch_seg[None]
            seg.append(batch_seg[None])
        seg = np.concatenate(seg, axis=0)
    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)


def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
    """Write random raw data and random labels to an HDF5 file.

    The file is opened in append mode, so existing datasets are kept. `h5py` raises
    if `raw_key` or `label_key` already exist in the file.

    Args:
        data_path: Path of the HDF5 file.
        raw_key: Dataset name for the raw data (float64, uniform in [0, 1)).
        label_key: Dataset name for the labels (integers in [0, 4)).
        shape: Shape of both datasets.
        chunks: HDF5 chunk shape for both datasets.
    """
    with h5py.File(data_path, "a") as f:
        f.create_dataset(raw_key, data=np.random.rand(*shape), chunks=chunks)
        f.create_dataset(label_key, data=np.random.randint(0, 4, size=shape), chunks=chunks)


def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write a collection of random tif images and label images of varying shape.

    Images go to `<folder>/images/im_<i>.tif` and labels to `<folder>/labels/im_<i>.tif`.

    Args:
        folder: Root folder; the `images` and `labels` subfolders are created if missing.
        n_images: Number of image/label pairs to write.
        min_shape: Per-axis minimum extent (inclusive).
        max_shape: Per-axis maximum extent (exclusive, as for `np.random.randint`).
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    for i in range(n_images):
        shape = tuple(np.random.randint(mins, maxs) for mins, maxs in zip(min_shape, max_shape))
        # Draw integers directly: casting rand() output in [0, 1) to int16 truncates every pixel to 0.
        raw = np.random.randint(0, 256, size=shape).astype("int16")
        # Named `labels` to avoid shadowing the module-level skimage `label` function.
        labels = np.random.randint(0, 4, size=shape)
        fname = f"im_{i}.tif"
        imageio.imwrite(os.path.join(im_folder, fname), raw)
        imageio.imwrite(os.path.join(label_folder, fname), labels)
def
make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation and return it as a torch tensor.

    Each segmentation is built by picking seed pixels where a uniform random value
    exceeds 0.99, labeling the seeds as connected components, and growing them with
    a watershed on the distance transform to the nearest seed.

    Args:
        spatial_shape: Spatial shape of a single segmentation.
        n_batches: If None, return a single segmentation. Otherwise stack this many
            independent segmentations along a new leading axis; must be >= 1.
        with_channels: If True, add a singleton channel axis in front of the spatial axes.
        with_background: If True, restrict the watershed to a random mask (~50% of pixels),
            so pixels outside the mask stay 0.
        dtype: Optional numpy dtype to cast the segmentation to before conversion.

    Returns:
        A torch tensor with shape `spatial_shape`, `(1, *spatial_shape)`,
        `(n_batches, *spatial_shape)` or `(n_batches, 1, *spatial_shape)`,
        depending on `n_batches` and `with_channels`.

    Raises:
        ValueError: If `n_batches` is given and smaller than 1.
    """
    def _make_gt():
        seeds = np.random.rand(*spatial_shape)
        seeds = label(seeds > 0.99)
        hmap = distance_transform_edt(seeds == 0)
        if with_background:
            mask = np.random.rand(*spatial_shape) > 0.5
            assert mask.shape == hmap.shape
        else:
            mask = None
        return watershed(hmap, markers=seeds, mask=mask)

    if n_batches is None:
        seg = _make_gt()
        if with_channels:
            # Call the helper before indexing; the original subscripted the function object.
            seg = seg[None]
    else:
        if n_batches < 1:
            raise ValueError(f"n_batches must be at least 1, got {n_batches}")
        seg = []
        for _ in range(n_batches):
            batch_seg = _make_gt()
            if with_channels:
                batch_seg = batch_seg[None]
            seg.append(batch_seg[None])
        seg = np.concatenate(seg, axis=0)
    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)
def
create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
def
create_image_collection_test_data(folder, n_images, min_shape, max_shape):
def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write a collection of random tif images and label images of varying shape.

    Images go to `<folder>/images/im_<i>.tif` and labels to `<folder>/labels/im_<i>.tif`.

    Args:
        folder: Root folder; the `images` and `labels` subfolders are created if missing.
        n_images: Number of image/label pairs to write.
        min_shape: Per-axis minimum extent (inclusive).
        max_shape: Per-axis maximum extent (exclusive, as for `np.random.randint`).
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    for i in range(n_images):
        shape = tuple(np.random.randint(mins, maxs) for mins, maxs in zip(min_shape, max_shape))
        # Draw integers directly: casting rand() output in [0, 1) to int16 truncates every pixel to 0.
        raw = np.random.randint(0, 256, size=shape).astype("int16")
        # Named `labels` to avoid shadowing the module-level skimage `label` function.
        labels = np.random.randint(0, 4, size=shape)
        fname = f"im_{i}.tif"
        imageio.imwrite(os.path.join(im_folder, fname), raw)
        imageio.imwrite(os.path.join(label_folder, fname), labels)