torch_em.util.test

@private

 1"""@private
 2"""
 3import os
 4import imageio
 5import h5py
 6import numpy as np
 7import torch
 8
 9from scipy.ndimage import distance_transform_edt
10from skimage.measure import label
11from skimage.segmentation import watershed
12
13
def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation to be used as ground-truth in tests.

    Random seed points are drawn in an array of the given shape and grown into segments
    by a watershed on the distance to the nearest seed.

    Args:
        spatial_shape: The spatial shape of the segmentation.
        n_batches: The number of samples, stacked along a new leading axis.
            If None, a single sample without batch axis is returned.
        with_channels: Whether to add a singleton channel axis in front of the spatial axes.
        with_background: Whether to restrict the watershed to a random binary mask.
        dtype: The data type to cast the segmentation to. If None, no cast is applied.

    Returns:
        The segmentation as a torch tensor.
    """
    def _make_gt():
        seeds = label(np.random.rand(*spatial_shape) > 0.99)
        hmap = distance_transform_edt(seeds == 0)
        mask = np.random.rand(*spatial_shape) > 0.5 if with_background else None
        return watershed(hmap, markers=seeds, mask=mask)

    def _make_sample():
        sample = _make_gt()
        return sample[None] if with_channels else sample

    if n_batches is None:
        # The helper has to be called before indexing; subscripting the function
        # object itself raises a TypeError.
        seg = _make_sample()
    else:
        seg = np.concatenate([_make_sample()[None] for _ in range(n_batches)], axis=0)

    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)
41
42
def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
    """Write random raw data and random labels to a HDF5 file for tests.

    The file is opened in append mode, so datasets already stored in it are kept.

    Args:
        data_path: The path of the HDF5 file.
        raw_key: The name of the dataset for the raw data (random floats in [0, 1)).
        label_key: The name of the dataset for the labels (random integers in [0, 4)).
        shape: The shape of both datasets.
        chunks: The chunk setting passed to h5py for both datasets.
    """
    with h5py.File(data_path, "a") as h5_file:
        raw_data = np.random.rand(*shape)
        h5_file.create_dataset(raw_key, data=raw_data, chunks=chunks)

        label_data = np.random.randint(0, 4, size=shape)
        h5_file.create_dataset(label_key, data=label_data, chunks=chunks)
47
48
def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write a collection of random images and label images as tif files for tests.

    The images are saved as `im_<i>.tif` in the sub-folder `images` and the corresponding
    labels under the same name in the sub-folder `labels`. Both sub-folders are created if needed.

    Args:
        folder: The root folder for the image collection.
        n_images: The number of image / label pairs to create.
        min_shape: The inclusive lower bound for the randomly drawn image size, per axis.
        max_shape: The exclusive upper bound for the randomly drawn image size, per axis.
            An axis with identical lower and upper bound always gets that size.
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    int16_max = np.iinfo("int16").max
    for i in range(n_images):
        shape = tuple(
            lo if lo == hi else np.random.randint(lo, hi) for lo, hi in zip(min_shape, max_shape)
        )
        # Scale to the int16 range before casting: values in [0, 1) would all
        # be truncated to 0 and give a constant image.
        raw = (np.random.rand(*shape) * int16_max).astype("int16")
        labels = np.random.randint(0, 4, size=shape)
        imageio.imwrite(os.path.join(im_folder, f"im_{i}.tif"), raw)
        imageio.imwrite(os.path.join(label_folder, f"im_{i}.tif"), labels)
def make_gt(spatial_shape, n_batches=None, with_channels=False, with_background=False, dtype=None):
    """Create a random instance segmentation to be used as ground-truth in tests.

    Random seed points are drawn in an array of the given shape and grown into segments
    by a watershed on the distance to the nearest seed.

    Args:
        spatial_shape: The spatial shape of the segmentation.
        n_batches: The number of samples, stacked along a new leading axis.
            If None, a single sample without batch axis is returned.
        with_channels: Whether to add a singleton channel axis in front of the spatial axes.
        with_background: Whether to restrict the watershed to a random binary mask.
        dtype: The data type to cast the segmentation to. If None, no cast is applied.

    Returns:
        The segmentation as a torch tensor.
    """
    def _make_gt():
        seeds = label(np.random.rand(*spatial_shape) > 0.99)
        hmap = distance_transform_edt(seeds == 0)
        mask = np.random.rand(*spatial_shape) > 0.5 if with_background else None
        return watershed(hmap, markers=seeds, mask=mask)

    def _make_sample():
        sample = _make_gt()
        return sample[None] if with_channels else sample

    if n_batches is None:
        # The helper has to be called before indexing; subscripting the function
        # object itself raises a TypeError.
        seg = _make_sample()
    else:
        seg = np.concatenate([_make_sample()[None] for _ in range(n_batches)], axis=0)

    if dtype is not None:
        seg = seg.astype(dtype)
    return torch.from_numpy(seg)
def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
    """Write random raw data and random labels to a HDF5 file for tests.

    The file is opened in append mode, so datasets already stored in it are kept.

    Args:
        data_path: The path of the HDF5 file.
        raw_key: The name of the dataset for the raw data (random floats in [0, 1)).
        label_key: The name of the dataset for the labels (random integers in [0, 4)).
        shape: The shape of both datasets.
        chunks: The chunk setting passed to h5py for both datasets.
    """
    with h5py.File(data_path, "a") as h5_file:
        raw_data = np.random.rand(*shape)
        h5_file.create_dataset(raw_key, data=raw_data, chunks=chunks)

        label_data = np.random.randint(0, 4, size=shape)
        h5_file.create_dataset(label_key, data=label_data, chunks=chunks)
def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
    """Write a collection of random images and label images as tif files for tests.

    The images are saved as `im_<i>.tif` in the sub-folder `images` and the corresponding
    labels under the same name in the sub-folder `labels`. Both sub-folders are created if needed.

    Args:
        folder: The root folder for the image collection.
        n_images: The number of image / label pairs to create.
        min_shape: The inclusive lower bound for the randomly drawn image size, per axis.
        max_shape: The exclusive upper bound for the randomly drawn image size, per axis.
            An axis with identical lower and upper bound always gets that size.
    """
    im_folder = os.path.join(folder, "images")
    label_folder = os.path.join(folder, "labels")
    os.makedirs(im_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    int16_max = np.iinfo("int16").max
    for i in range(n_images):
        shape = tuple(
            lo if lo == hi else np.random.randint(lo, hi) for lo, hi in zip(min_shape, max_shape)
        )
        # Scale to the int16 range before casting: values in [0, 1) would all
        # be truncated to 0 and give a constant image.
        raw = (np.random.rand(*shape) * int16_max).astype("int16")
        labels = np.random.randint(0, 4, size=shape)
        imageio.imwrite(os.path.join(im_folder, f"im_{i}.tif"), raw)
        imageio.imwrite(os.path.join(label_folder, f"im_{i}.tif"), labels)