torch_em.util.image
# TODO this should be partially refactored into elf.io before the next elf release
# and then be used in image_stack_wrapper as well
import os
import numpy as np

from elf.io import open_file
try:
    import imageio.v3 as imageio
except ImportError:
    import imageio

try:
    import tifffile
except ImportError:
    tifffile = None

# Lower-case file extensions that are treated as tif files.
TIF_EXTS = (".tif", ".tiff")


def supports_memmap(image_path):
    """Check whether an image file can be opened as a read-only tifffile memmap.

    Args:
        image_path: Path to the image file.

    Returns:
        False if tifffile is not installed, if the (case-insensitive) extension is
        not in TIF_EXTS, or if ``tifffile.memmap`` raises a ValueError for the file;
        True otherwise.
    """
    if tifffile is None:
        return False
    ext = os.path.splitext(image_path)[1]
    if ext.lower() not in TIF_EXTS:
        return False
    try:
        tifffile.memmap(image_path, mode="r")
    except ValueError:
        return False
    return True


def load_image(image_path, memmap=True):
    """Load an image from disk.

    Tif files are returned as a read-only tifffile memmap if ``memmap`` is True and
    the file supports it; otherwise they are read with ``tifffile.imread`` (when
    tifffile is installed). Everything else is read with ``imageio.imread``.

    Args:
        image_path: Path to the image file.
        memmap: Whether to memory-map tif files when possible.

    Returns:
        The image data (a numpy memmap for memory-mapped tif files).
    """
    # Check the cheap flag first: supports_memmap opens the file, which is wasted
    # work when memory-mapping was not requested.
    if memmap and supports_memmap(image_path):
        return tifffile.memmap(image_path, mode="r")
    if tifffile is not None and os.path.splitext(image_path)[1].lower() in TIF_EXTS:
        return tifffile.imread(image_path)
    return imageio.imread(image_path)


class MultiDatasetWrapper:
    """Present several equally-shaped datasets as one array with a leading channel axis.

    The wrapped objects only need ``shape`` and ``__getitem__`` (e.g. the datasets
    returned by ``open_file(...)[key]``). Indexing reads the spatial selection from
    every dataset, stacks the results along a new first axis and then applies the
    channel selection to the stacked array.

    Args:
        *file_datasets: The datasets to combine; at least one is required.
    """

    def __init__(self, *file_datasets):
        # Make sure we have the same shapes.
        reference_shape = file_datasets[0].shape
        assert all(reference_shape == ds.shape for ds in file_datasets), (
            f"All datasets must have the same shape, got {[ds.shape for ds in file_datasets]}"
        )
        self.file_datasets = file_datasets

        self.shape = (len(self.file_datasets),) + reference_shape
        self.ndim = len(self.shape)

    def __getitem__(self, index):
        # A bare (non-tuple) index addresses only the channel axis, as in numpy;
        # wrap it so the channel/spatial split below works for int, slice and Ellipsis.
        if not isinstance(index, tuple):
            index = (index,)
        channel_index, spatial_index = index[:1], index[1:]
        data = np.stack([ds[spatial_index] for ds in self.file_datasets])
        return data[channel_index]


def load_data(path, key, mode="r"):
    """Load data from a single file or from several files.

    Args:
        path: A single path (str or os.PathLike) or a sequence of paths.
        key: Internal dataset key for container formats handled by ``open_file``;
            None to read the file(s) as plain images via ``load_image``.
        mode: File mode passed to ``open_file`` (only used when ``key`` is given).

    Returns:
        For ``key=None``: the image of a single file, or the images of several
        files stacked along a new first axis. Otherwise: the dataset
        ``open_file(path)[key]`` for a single file, or a MultiDatasetWrapper
        over that dataset in each file.
    """
    # Normalize os.PathLike (e.g. pathlib.Path) to str so that a single path is not
    # mistaken for a sequence of paths.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    have_single_file = isinstance(path, str)

    if key is None:
        if have_single_file:
            return load_image(path)
        return np.stack([load_image(p) for p in path])

    if have_single_file:
        return open_file(path, mode=mode)[key]
    return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])
# Lower-case file extensions that are treated as tif files.
TIF_EXTS = (".tif", ".tiff")
def
supports_memmap(image_path):
def load_image(image_path, memmap=True):
    """Load an image from disk.

    Tif files are returned as a read-only tifffile memmap if ``memmap`` is True and
    the file supports it; otherwise they are read with ``tifffile.imread`` (when
    tifffile is installed). Everything else is read with ``imageio.imread``.

    Args:
        image_path: Path to the image file.
        memmap: Whether to memory-map tif files when possible.

    Returns:
        The image data (a numpy memmap for memory-mapped tif files).
    """
    # Check the cheap flag first: supports_memmap opens the file, which is wasted
    # work when memory-mapping was not requested.
    if memmap and supports_memmap(image_path):
        return tifffile.memmap(image_path, mode="r")
    if tifffile is not None and os.path.splitext(image_path)[1].lower() in TIF_EXTS:
        return tifffile.imread(image_path)
    return imageio.imread(image_path)
class MultiDatasetWrapper:
    """Present several equally-shaped datasets as one array with a leading channel axis.

    The wrapped objects only need ``shape`` and ``__getitem__`` (e.g. the datasets
    returned by ``open_file(...)[key]``). Indexing reads the spatial selection from
    every dataset, stacks the results along a new first axis and then applies the
    channel selection to the stacked array.

    Args:
        *file_datasets: The datasets to combine; at least one is required.
    """

    def __init__(self, *file_datasets):
        # Make sure we have the same shapes.
        reference_shape = file_datasets[0].shape
        assert all(reference_shape == ds.shape for ds in file_datasets), (
            f"All datasets must have the same shape, got {[ds.shape for ds in file_datasets]}"
        )
        self.file_datasets = file_datasets

        self.shape = (len(self.file_datasets),) + reference_shape
        self.ndim = len(self.shape)

    def __getitem__(self, index):
        # A bare (non-tuple) index addresses only the channel axis, as in numpy;
        # wrap it so the channel/spatial split below works for int, slice and Ellipsis.
        if not isinstance(index, tuple):
            index = (index,)
        channel_index, spatial_index = index[:1], index[1:]
        data = np.stack([ds[spatial_index] for ds in self.file_datasets])
        return data[channel_index]
def load_data(path, key, mode="r"):
    """Load data from a single file or from several files.

    Args:
        path: A single path (str or os.PathLike) or a sequence of paths.
        key: Internal dataset key for container formats handled by ``open_file``;
            None to read the file(s) as plain images via ``load_image``.
        mode: File mode passed to ``open_file`` (only used when ``key`` is given).

    Returns:
        For ``key=None``: the image of a single file, or the images of several
        files stacked along a new first axis. Otherwise: the dataset
        ``open_file(path)[key]`` for a single file, or a MultiDatasetWrapper
        over that dataset in each file.
    """
    # Normalize os.PathLike (e.g. pathlib.Path) to str so that a single path is not
    # mistaken for a sequence of paths.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    have_single_file = isinstance(path, str)

    if key is None:
        if have_single_file:
            return load_image(path)
        return np.stack([load_image(p) for p in path])

    if have_single_file:
        return open_file(path, mode=mode)[key]
    return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])