torch_em.data.raw_image_collection_dataset

import os
import numpy as np
from typing import List, Union, Tuple, Optional, Any

import torch

from ..util import ensure_tensor_with_channels, load_image, supports_memmap


class RawImageCollectionDataset(torch.utils.data.Dataset):
    max_sampling_attempts = 500

    def _check_inputs(self, raw_images, full_check):
        if not full_check:
            return

        is_multichan = None
        for raw_im in raw_images:

            # we only check for compatible shapes if images support memmap, because
            # we don't want to load everything into ram
            if supports_memmap(raw_im):
                shape = load_image(raw_im).shape
                assert len(shape) in (2, 3)

                multichan = len(shape) == 3
                if is_multichan is None:
                    is_multichan = multichan
                else:
                    assert is_multichan == multichan

                # we assume axis last
                if is_multichan:
                    shape = shape[:-1]

    def __init__(
        self,
        raw_image_paths: Union[List[Any], str, os.PathLike],
        patch_shape: Tuple[int, ...],
        raw_transform=None,
        transform=None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        augmentations=None,
        full_check: bool = False,
    ):
        self._check_inputs(raw_image_paths, full_check)
        self.raw_images = raw_image_paths
        self._ndim = 2

        assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.dtype = dtype
        self.sampler = sampler

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

        if augmentations is not None:
            assert len(augmentations) == 2
        self.augmentations = augmentations

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self, shape):
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _ensure_patch_shape(self, raw, have_raw_channels, channel_first):
        shape = raw.shape
        if have_raw_channels and channel_first:
            shape = shape[1:]
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]

            if have_raw_channels and channel_first:
                pw_raw = [(0, 0), *pw]
            elif have_raw_channels and not channel_first:
                pw_raw = [*pw, (0, 0)]
            else:
                pw_raw = pw

            raw = np.pad(raw, pw_raw)
        return raw

    def _get_sample(self, index):
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))
        raw = load_image(self.raw_images[index])
        have_raw_channels = raw.ndim == 3

        # We determine if the image has channels as the first or last axis based on the array shape.
        # This will work only for images with less than 16 channels!
        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
        channel_first = None
        if have_raw_channels:
            channel_first = raw.shape[-1] > 16

        raw = self._ensure_patch_shape(raw, have_raw_channels, channel_first)

        shape = raw.shape
        # we assume images are loaded with channel last!
        if have_raw_channels:
            shape = shape[:-1]

        # sample random bounding box for this image
        bb = self._sample_bounding_box(shape)
        patch = np.array(raw[bb])

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(patch):
                # re-sample the patch from the full image, not from the rejected patch
                bb = self._sample_bounding_box(shape)
                patch = np.array(raw[bb])
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid patch in {self.max_sampling_attempts} attempts")

        # to channel first
        if have_raw_channels:
            patch = patch.transpose((2, 0, 1))

        return patch

    def __getitem__(self, index):
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            assert len(raw) == 1
            raw = raw[0]
            # if self.trafo_halo is not None:
            #     raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw
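
Typical usage pairs the dataset with a torch.utils.data.DataLoader. The following is a minimal sketch, not part of the module above: the image folder, patch shape, batch size and n_samples are placeholder values, and it assumes a collection of 2d .tif images on disk.

import glob

from torch.utils.data import DataLoader
from torch_em.data.raw_image_collection_dataset import RawImageCollectionDataset

# collect the 2d images to sample patches from (placeholder path)
image_paths = sorted(glob.glob("/path/to/images/*.tif"))

dataset = RawImageCollectionDataset(
    raw_image_paths=image_paths,
    patch_shape=(256, 256),  # must be 2d, matching self._ndim
    n_samples=1000,          # draw 1000 random patches per epoch instead of one per image
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

for batch in loader:
    # each batch has shape (8, C, 256, 256) with the channel axis first
    pass

If a sampler callable is passed to the constructor, each candidate patch is re-drawn until the sampler accepts it or max_sampling_attempts is exceeded, in which case a RuntimeError is raised.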
class RawImageCollectionDataset(torch.utils.data.Dataset):

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), which fetches a data sample for a given key. Subclasses can also optionally overwrite __len__(), which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and by the default options of torch.utils.data.DataLoader. Subclasses can also optionally implement __getitems__() to speed up loading of batched samples. This method accepts a list of sample indices for a batch and returns the list of samples.

torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
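
As an illustration of the last point (a generic PyTorch sketch, not part of torch_em; DictDataset and KeySampler are hypothetical names), a map-style dataset with string keys can be iterated by providing a custom sampler that yields those keys:

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class DictDataset(Dataset):
    # map-style dataset whose keys are strings instead of integers
    def __init__(self, data):
        self.data = data

    def __getitem__(self, key):
        return self.data[key]

    def __len__(self):
        return len(self.data)


class KeySampler(Sampler):
    # yields the string keys so the DataLoader can index the dataset
    def __init__(self, keys):
        self.keys = list(keys)

    def __iter__(self):
        return iter(self.keys)

    def __len__(self):
        return len(self.keys)


data = {f"img_{i:03d}": torch.rand(1, 32, 32) for i in range(8)}
loader = DataLoader(DictDataset(data), batch_size=4, sampler=KeySampler(data.keys()))
for batch in loader:
    print(batch.shape)  # torch.Size([4, 1, 32, 32])

Note that this DataLoader-level sampler is unrelated to the sampler argument of RawImageCollectionDataset, which is a callable used to accept or reject individual patches.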

RawImageCollectionDataset(raw_image_paths: Union[List[Any], str, os.PathLike], patch_shape: Tuple[int, ...], raw_transform=None, transform=None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler=None, augmentations=None, full_check: bool = False)
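
When augmentations is given as a pair of callables, __getitem__ returns two differently augmented views of the same patch (see the source above), which is the typical setup for self-supervised training. A minimal sketch, assuming image_paths is a list of 2d image files; the two lambdas are placeholder augmentations acting on the (C, H, W) tensor produced by the dataset:

import torch
from torch_em.data.raw_image_collection_dataset import RawImageCollectionDataset

# placeholder augmentations; any callables mapping a tensor to a tensor work
aug1 = lambda x: x + 0.1 * torch.randn_like(x)  # additive gaussian noise
aug2 = lambda x: torch.flip(x, dims=(-1,))      # horizontal flip

dataset = RawImageCollectionDataset(
    raw_image_paths=image_paths,  # list of 2d image files (assumed to exist)
    patch_shape=(128, 128),
    n_samples=500,
    augmentations=(aug1, aug2),   # must have length 2
)

# with n_samples set, the index is ignored and a random image is drawn
view1, view2 = dataset[0]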
max_sampling_attempts = 500
raw_images
patch_shape
raw_transform
transform
dtype
sampler
augmentations
ndim