torch_em.util.image

# TODO this should be partially refactored into elf.io before the next elf release
# and then be used in image_stack_wrapper as well
import os
import numpy as np

from elf.io import open_file
try:
    import imageio.v3 as imageio
except ImportError:
    import imageio

try:
    import tifffile
except ImportError:
    tifffile = None

TIF_EXTS = (".tif", ".tiff")
18
19
def supports_memmap(image_path):
    """Return True if ``image_path`` can be read as a tifffile memory-map.

    This requires tifffile to be installed, the file to carry a tif/tiff
    extension, and tifffile to accept it for memory-mapped reading.
    """
    if tifffile is None:
        return False
    extension = os.path.splitext(image_path)[1].lower()
    if extension in TIF_EXTS:
        try:
            # tifffile raises ValueError for tifs that cannot be memmapped.
            tifffile.memmap(image_path, mode="r")
            return True
        except ValueError:
            return False
    return False
31
32
def load_image(image_path, memmap=True):
    """Load an image from file.

    Prefers a tifffile memory-map when supported (and requested), falls back
    to ``tifffile.imread`` for regular tif files and to ``imageio.imread``
    for all other formats.

    Args:
        image_path: Path to the image file.
        memmap: Whether to return a memory-mapped array when supported.

    Returns:
        The image data as a numpy array (a numpy memmap when memory-mapped).
    """
    # Check `memmap` first so we don't open the file in supports_memmap
    # when memory-mapping was not requested anyway.
    if memmap and supports_memmap(image_path):
        return tifffile.memmap(image_path, mode="r")
    # Use the shared TIF_EXTS constant for the extension check,
    # consistent with supports_memmap.
    elif tifffile is not None and os.path.splitext(image_path)[1].lower() in TIF_EXTS:
        return tifffile.imread(image_path)
    else:
        return imageio.imread(image_path)
40
41
class MultiDatasetWrapper:
    """Present several same-shaped datasets as one dataset with a leading
    channel axis, one channel per wrapped dataset.
    """
    def __init__(self, *file_datasets):
        """Initialize the wrapper.

        Args:
            *file_datasets: Datasets to wrap; all must have the same shape.

        Raises:
            ValueError: If no dataset is given or the shapes disagree.
        """
        if not file_datasets:
            raise ValueError("At least one dataset is required.")
        # Make sure we have the same shapes. Use an explicit error instead of
        # an assert so validation survives `python -O`.
        reference_shape = file_datasets[0].shape
        if any(ds.shape != reference_shape for ds in file_datasets):
            raise ValueError("All datasets must have the same shape.")
        self.file_datasets = file_datasets

        # Leading axis enumerates the wrapped datasets (channels).
        self.shape = (len(self.file_datasets),) + reference_shape

    def __getitem__(self, index):
        """Index with (channel_selection, *spatial_selection)."""
        # Split the index into the channel part and the spatial part.
        channel_index, spatial_index = index[:1], index[1:]
        # Apply the spatial selection to every dataset, then select channels.
        data = np.stack([ds[spatial_index] for ds in self.file_datasets])
        return data[channel_index]
60
61
def load_data(path, key, mode="r"):
    """Load data from a single file or from a collection of files.

    Args:
        path: Path to the file, or a sequence of paths. Accepts ``str`` or
            ``os.PathLike`` for single paths (PathLike support generalizes
            the previous str-only check, backward-compatibly).
        key: Internal dataset key for container formats (hdf5/zarr/n5);
            pass None to load image files via ``load_image``.
        mode: File mode for ``open_file`` (default read-only).

    Returns:
        The loaded data: an array for a single file, a stacked array for
        multiple image files, or a ``MultiDatasetWrapper`` for multiple
        container datasets.
    """
    have_single_file = isinstance(path, (str, os.PathLike))
    if key is None:
        if have_single_file:
            return load_image(path)
        # Multiple image files: stack them along a new leading axis.
        return np.stack([load_image(p) for p in path])
    if have_single_file:
        return open_file(path, mode=mode)[key]
    # Multiple container datasets: expose them with a leading channel axis.
    return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])
TIF_EXTS = ('.tif', '.tiff')

def supports_memmap(image_path):
    if tifffile is None:
        return False
    ext = os.path.splitext(image_path)[1]
    if ext.lower() not in TIF_EXTS:
        return False
    try:
        tifffile.memmap(image_path, mode="r")
    except ValueError:
        return False
    return True
def load_image(image_path, memmap=True):
    if supports_memmap(image_path) and memmap:
        return tifffile.memmap(image_path, mode="r")
    elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"):
        return tifffile.imread(image_path)
    else:
        return imageio.imread(image_path)
class MultiDatasetWrapper:
    def __init__(self, *file_datasets):
        # Make sure we have the same shapes.
        reference_shape = file_datasets[0].shape
        assert all(reference_shape == ds.shape for ds in file_datasets)
        self.file_datasets = file_datasets

        self.shape = (len(self.file_datasets),) + reference_shape

    def __getitem__(self, index):
        channel_index, spatial_index = index[:1], index[1:]
        data = []
        for ds in self.file_datasets:
            ds_data = ds[spatial_index]
            data.append(ds_data)
        data = np.stack(data)
        data = data[channel_index]
        return data

MultiDatasetWrapper(*file_datasets)

Attributes: file_datasets, shape
def load_data(path, key, mode='r'):
    have_single_file = isinstance(path, str)
    if key is None and have_single_file:
        return load_image(path)
    elif key is None and not have_single_file:
        return np.stack([load_image(p) for p in path])
    elif key is not None and have_single_file:
        return open_file(path, mode=mode)[key]
    elif key is not None and not have_single_file:
        return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])