torch_em.util.test

 1import os
 2import imageio
 3import h5py
 4import numpy as np
 5import torch
 6
 7from scipy.ndimage import distance_transform_edt
 8from skimage.measure import label
 9from skimage.segmentation import watershed
10
11
def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation as a torch tensor, for use in tests.

    Seeds are placed where uniform noise exceeds 0.99 (about 1% of the pixels),
    labeled as connected components, and grown by a watershed on the Euclidean
    distance to the nearest seed pixel.

    Args:
        spatial_shape: shape of a single segmentation.
        n_batches: if not None, stack this many independently drawn segmentations
            along a new leading batch axis.
        with_channels: if True, prepend a singleton channel axis to each segmentation.
        with_background: if True, restrict the watershed to a random mask that keeps
            about half of the pixels.
        dtype: if not None, cast the segmentation to this dtype before conversion.

    Returns:
        A ``torch.Tensor`` of shape ``spatial_shape``. With ``with_channels`` a
        leading ``1`` is added, and with ``n_batches`` a leading ``n_batches`` axis
        is added in front of that.
    """
    def _make_gt():
        # Roughly 1% of the pixels become seeds; connected seed pixels share one id.
        seeds = label(np.random.rand(*spatial_shape) > 0.99)
        # Height map: distance of each pixel to the nearest seed pixel.
        hmap = distance_transform_edt(seeds == 0)
        mask = np.random.rand(*spatial_shape) > 0.5 if with_background else None
        return watershed(hmap, markers=seeds, mask=mask)

    def _make_sample():
        # One segmentation, with the singleton channel axis prepended if requested.
        sample = _make_gt()
        return sample[None] if with_channels else sample

    if n_batches is None:
        seg = _make_sample()
    else:
        seg = np.stack([_make_sample() for _ in range(n_batches)], axis=0)

    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)
39
40
def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
    """Write a random raw dataset and a random label dataset into an HDF5 file.

    The file is opened with mode ``"a"``. The raw data is drawn first, then the
    labels, and each is written immediately after it is drawn.

    Args:
        data_path: path of the HDF5 file.
        raw_key: dataset name for the raw data (uniform floats in [0, 1)).
        label_key: dataset name for the labels (random integers in [0, 4)).
        shape: shape of both datasets.
        chunks: chunk shape used for both datasets.
    """
    # Each entry pairs a dataset name with a factory for its random content.
    dataset_factories = (
        (raw_key, lambda: np.random.rand(*shape)),
        (label_key, lambda: np.random.randint(0, 4, size=shape)),
    )
    with h5py.File(data_path, "a") as h5_file:
        for key, make_data in dataset_factories:
            h5_file.create_dataset(key, data=make_data(), chunks=chunks)
45
46
def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write ``n_images`` random image/label pairs as tif files.

    Images are written to ``<folder>/images/im_<i>.tif`` and labels to
    ``<folder>/labels/im_<i>.tif``. Each pair gets its own shape, drawn per axis
    with ``np.random.randint(min_shape[d], max_shape[d])``, so the upper bound
    is exclusive.

    Args:
        folder: root directory; the ``images`` and ``labels`` subfolders are
            created if they do not exist.
        n_images: number of image/label pairs to write.
        min_shape: per-axis lower bound (inclusive) of the image shape.
        max_shape: per-axis upper bound (exclusive) of the image shape.
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    for i in range(n_images):
        shape = tuple(np.random.randint(mins, maxs) for mins, maxs in zip(min_shape, max_shape))
        # Draw integer intensities directly. Casting np.random.rand output (values
        # in [0, 1)) to int16 truncates every pixel to 0 and yields blank images.
        raw = np.random.randint(0, 256, size=shape).astype("int16")
        labels = np.random.randint(0, 4, size=shape)
        fname = f"im_{i}.tif"
        imageio.imwrite(os.path.join(im_folder, fname), raw)
        imageio.imwrite(os.path.join(label_folder, fname), labels)
def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation as a torch tensor, for use in tests.

    Seeds are placed where uniform noise exceeds 0.99 (about 1% of the pixels),
    labeled as connected components, and grown by a watershed on the Euclidean
    distance to the nearest seed pixel.

    Args:
        spatial_shape: shape of a single segmentation.
        n_batches: if not None, stack this many independently drawn segmentations
            along a new leading batch axis.
        with_channels: if True, prepend a singleton channel axis to each segmentation.
        with_background: if True, restrict the watershed to a random mask that keeps
            about half of the pixels.
        dtype: if not None, cast the segmentation to this dtype before conversion.

    Returns:
        A ``torch.Tensor`` of shape ``spatial_shape``. With ``with_channels`` a
        leading ``1`` is added, and with ``n_batches`` a leading ``n_batches`` axis
        is added in front of that.
    """
    def _make_gt():
        # Roughly 1% of the pixels become seeds; connected seed pixels share one id.
        seeds = label(np.random.rand(*spatial_shape) > 0.99)
        # Height map: distance of each pixel to the nearest seed pixel.
        hmap = distance_transform_edt(seeds == 0)
        mask = np.random.rand(*spatial_shape) > 0.5 if with_background else None
        return watershed(hmap, markers=seeds, mask=mask)

    def _make_sample():
        # One segmentation, with the singleton channel axis prepended if requested.
        sample = _make_gt()
        return sample[None] if with_channels else sample

    if n_batches is None:
        seg = _make_sample()
    else:
        seg = np.stack([_make_sample() for _ in range(n_batches)], axis=0)

    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)
def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
    """Write a random raw dataset and a random label dataset into an HDF5 file.

    The file is opened with mode ``"a"``. The raw data is drawn first, then the
    labels, and each is written immediately after it is drawn.

    Args:
        data_path: path of the HDF5 file.
        raw_key: dataset name for the raw data (uniform floats in [0, 1)).
        label_key: dataset name for the labels (random integers in [0, 4)).
        shape: shape of both datasets.
        chunks: chunk shape used for both datasets.
    """
    # Each entry pairs a dataset name with a factory for its random content.
    dataset_factories = (
        (raw_key, lambda: np.random.rand(*shape)),
        (label_key, lambda: np.random.randint(0, 4, size=shape)),
    )
    with h5py.File(data_path, "a") as h5_file:
        for key, make_data in dataset_factories:
            h5_file.create_dataset(key, data=make_data(), chunks=chunks)
def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write ``n_images`` random image/label pairs as tif files.

    Images are written to ``<folder>/images/im_<i>.tif`` and labels to
    ``<folder>/labels/im_<i>.tif``. Each pair gets its own shape, drawn per axis
    with ``np.random.randint(min_shape[d], max_shape[d])``, so the upper bound
    is exclusive.

    Args:
        folder: root directory; the ``images`` and ``labels`` subfolders are
            created if they do not exist.
        n_images: number of image/label pairs to write.
        min_shape: per-axis lower bound (inclusive) of the image shape.
        max_shape: per-axis upper bound (exclusive) of the image shape.
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    for i in range(n_images):
        shape = tuple(np.random.randint(mins, maxs) for mins, maxs in zip(min_shape, max_shape))
        # Draw integer intensities directly. Casting np.random.rand output (values
        # in [0, 1)) to int16 truncates every pixel to 0 and yields blank images.
        raw = np.random.randint(0, 256, size=shape).astype("int16")
        labels = np.random.randint(0, 4, size=shape)
        fname = f"im_{i}.tif"
        imageio.imwrite(os.path.join(im_folder, fname), raw)
        imageio.imwrite(os.path.join(label_folder, fname), labels)