torch_em.data.image_collection_dataset

  1import os
  2import numpy as np
  3from typing import List, Optional, Tuple, Union
  4
  5import torch
  6
  7from ..util import (ensure_spatial_array, ensure_tensor_with_channels,
  8                    load_image, supports_memmap)
  9
 10
 11class ImageCollectionDataset(torch.utils.data.Dataset):
 12    max_sampling_attempts = 500
 13    max_sampling_attempts_image = 50
 14
 15    def _check_inputs(self, raw_images, label_images, full_check):
 16        if len(raw_images) != len(label_images):
 17            raise ValueError(f"Expect same number of  and label images, got {len(raw_images)} and {len(label_images)}")
 18
 19        if not full_check:
 20            return
 21
 22        is_multichan = None
 23        for raw_im, label_im in zip(raw_images, label_images):
 24
 25            # we only check for compatible shapes if both images support memmap, because
 26            # we don't want to load everything into ram
 27            if supports_memmap(raw_im) and supports_memmap(label_im):
 28                shape = load_image(raw_im).shape
 29                assert len(shape) in (2, 3)
 30
 31                multichan = len(shape) == 3
 32                if is_multichan is None:
 33                    is_multichan = multichan
 34                else:
 35                    assert is_multichan == multichan
 36
 37                if is_multichan:
 38                    # use heuristic to decide whether the data is stored in channel last or channel first order:
 39                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
 40                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
 41                    if shape[-1] < 16:
 42                        shape = shape[:-1]
 43                    else:
 44                        shape = shape[1:]
 45
 46                label_shape = load_image(label_im).shape
 47                if shape != label_shape:
 48                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
 49                    raise ValueError(msg)
 50
 51    def __init__(
 52        self,
 53        raw_image_paths: List[Union[str, os.PathLike]],
 54        label_image_paths: List[Union[str, os.PathLike]],
 55        patch_shape: Tuple[int, ...],
 56        raw_transform=None,
 57        label_transform=None,
 58        label_transform2=None,
 59        transform=None,
 60        dtype: torch.dtype = torch.float32,
 61        label_dtype: torch.dtype = torch.float32,
 62        n_samples: Optional[int] = None,
 63        sampler=None,
 64        full_check: bool = False,
 65    ):
 66        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
 67        self.raw_images = raw_image_paths
 68        self.label_images = label_image_paths
 69        self._ndim = 2
 70
 71        if patch_shape is not None:
 72            assert len(patch_shape) == self._ndim
 73        self.patch_shape = patch_shape
 74
 75        self.raw_transform = raw_transform
 76        self.label_transform = label_transform
 77        self.label_transform2 = label_transform2
 78        self.transform = transform
 79        self.sampler = sampler
 80
 81        self.dtype = dtype
 82        self.label_dtype = label_dtype
 83
 84        if n_samples is None:
 85            self._len = len(self.raw_images)
 86            self.sample_random_index = False
 87        else:
 88            self._len = n_samples
 89            self.sample_random_index = True
 90
 91    def __len__(self):
 92        return self._len
 93
 94    @property
 95    def ndim(self):
 96        return self._ndim
 97
 98    def _sample_bounding_box(self, shape):
 99        if self.patch_shape is None:
100            patch_shape_for_bb = shape
101            bb_start = [0] * len(shape)
102        else:
103            patch_shape_for_bb = self.patch_shape
104            bb_start = [
105                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
106                for sh, psh in zip(shape, patch_shape_for_bb)
107            ]
108
109        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
110
111    def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first):
112        shape = raw.shape
113        if have_raw_channels and channel_first:
114            shape = shape[1:]
115        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
116            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
117
118            if have_raw_channels and channel_first:
119                pw_raw = [(0, 0), *pw]
120            elif have_raw_channels and not channel_first:
121                pw_raw = [*pw, (0, 0)]
122            else:
123                pw_raw = pw
124
125            # TODO: ensure padding for labels with channels, when supported (see `_get_sample` below)
126
127            raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw)
128        return raw, labels
129
130    def _load_data(self, raw_path, label_path):
131        raw = load_image(raw_path, memmap=False)
132        label = load_image(label_path, memmap=False)
133
134        have_raw_channels = raw.ndim == 3
135        have_label_channels = label.ndim == 3
136        if have_label_channels:
137            raise NotImplementedError("Multi-channel labels are not supported.")
138
139        # We determine if the image has channels as the first or last axis based on the array shape.
140        # This will work only for images with less than 16 channels!
141        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
142        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
143        channel_first = None
144        if have_raw_channels:
145            channel_first = raw.shape[-1] > 16
146
147        if self.patch_shape is not None:
148            raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first)
149        shape = raw.shape
150
151        prefix_box = tuple()
152        if have_raw_channels:
153            if channel_first:
154                shape = shape[1:]
155                prefix_box = (slice(None), )
156            else:
157                shape = shape[:-1]
158
159        return raw, label, shape, prefix_box, have_raw_channels
160
161    def _get_sample(self, index):
162        if self.sample_random_index:
163            index = np.random.randint(0, len(self.raw_images))
164
165        # The filepath corresponding to this image.
166        raw_path, label_path = self.raw_images[index], self.label_images[index]
167
168        # Load the corresponding data.
169        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
170
171        # Sample random bounding box for this image.
172        bb = self._sample_bounding_box(shape)
173        raw_patch = np.array(raw[prefix_box + bb])
174        label_patch = np.array(label[bb])
175
176        if self.sampler is not None:
177            sample_id = 0
178            while not self.sampler(raw_patch, label_patch):
179                bb = self._sample_bounding_box(shape)
180                raw_patch = np.array(raw[prefix_box + bb])
181                label_patch = np.array(label[bb])
182                sample_id += 1
183
184                # We need to avoid sampling from the same image over and over agagin,
185                # otherwise this will fail just because of one or a few empty images.
186                # Hence we update the image from which we sample sometimes.
187                if sample_id % self.max_sampling_attempts_image == 0:
188                    index = np.random.randint(0, len(self.raw_images))
189                    raw_path, label_path = self.raw_images[index], self.label_images[index]
190                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
191
192                if sample_id > self.max_sampling_attempts:
193                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
194
195        # to channel first
196        if have_raw_channels and len(prefix_box) == 0:
197            raw_patch = raw_patch.transpose((2, 0, 1))
198
199        return raw_patch, label_patch
200
201    def __getitem__(self, index):
202        raw, labels = self._get_sample(index)
203        initial_label_dtype = labels.dtype
204
205        if self.raw_transform is not None:
206            raw = self.raw_transform(raw)
207
208        if self.label_transform is not None:
209            labels = self.label_transform(labels)
210
211        if self.transform is not None:
212            raw, labels = self.transform(raw, labels)
213            # if self.trafo_halo is not None:
214            #     raw = self.crop(raw)
215            #     labels = self.crop(labels)
216
217        # support enlarging bounding box here as well (for affinity transform) ?
218        if self.label_transform2 is not None:
219            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
220            labels = self.label_transform2(labels)
221
222        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
223        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
224        return raw, labels
class ImageCollectionDataset(typing.Generic[+T_co]):
 12class ImageCollectionDataset(torch.utils.data.Dataset):
 13    max_sampling_attempts = 500
 14    max_sampling_attempts_image = 50
 15
 16    def _check_inputs(self, raw_images, label_images, full_check):
 17        if len(raw_images) != len(label_images):
 18            raise ValueError(f"Expect same number of  and label images, got {len(raw_images)} and {len(label_images)}")
 19
 20        if not full_check:
 21            return
 22
 23        is_multichan = None
 24        for raw_im, label_im in zip(raw_images, label_images):
 25
 26            # we only check for compatible shapes if both images support memmap, because
 27            # we don't want to load everything into ram
 28            if supports_memmap(raw_im) and supports_memmap(label_im):
 29                shape = load_image(raw_im).shape
 30                assert len(shape) in (2, 3)
 31
 32                multichan = len(shape) == 3
 33                if is_multichan is None:
 34                    is_multichan = multichan
 35                else:
 36                    assert is_multichan == multichan
 37
 38                if is_multichan:
 39                    # use heuristic to decide whether the data is stored in channel last or channel first order:
 40                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
 41                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
 42                    if shape[-1] < 16:
 43                        shape = shape[:-1]
 44                    else:
 45                        shape = shape[1:]
 46
 47                label_shape = load_image(label_im).shape
 48                if shape != label_shape:
 49                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
 50                    raise ValueError(msg)
 51
 52    def __init__(
 53        self,
 54        raw_image_paths: List[Union[str, os.PathLike]],
 55        label_image_paths: List[Union[str, os.PathLike]],
 56        patch_shape: Tuple[int, ...],
 57        raw_transform=None,
 58        label_transform=None,
 59        label_transform2=None,
 60        transform=None,
 61        dtype: torch.dtype = torch.float32,
 62        label_dtype: torch.dtype = torch.float32,
 63        n_samples: Optional[int] = None,
 64        sampler=None,
 65        full_check: bool = False,
 66    ):
 67        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
 68        self.raw_images = raw_image_paths
 69        self.label_images = label_image_paths
 70        self._ndim = 2
 71
 72        if patch_shape is not None:
 73            assert len(patch_shape) == self._ndim
 74        self.patch_shape = patch_shape
 75
 76        self.raw_transform = raw_transform
 77        self.label_transform = label_transform
 78        self.label_transform2 = label_transform2
 79        self.transform = transform
 80        self.sampler = sampler
 81
 82        self.dtype = dtype
 83        self.label_dtype = label_dtype
 84
 85        if n_samples is None:
 86            self._len = len(self.raw_images)
 87            self.sample_random_index = False
 88        else:
 89            self._len = n_samples
 90            self.sample_random_index = True
 91
 92    def __len__(self):
 93        return self._len
 94
 95    @property
 96    def ndim(self):
 97        return self._ndim
 98
 99    def _sample_bounding_box(self, shape):
100        if self.patch_shape is None:
101            patch_shape_for_bb = shape
102            bb_start = [0] * len(shape)
103        else:
104            patch_shape_for_bb = self.patch_shape
105            bb_start = [
106                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
107                for sh, psh in zip(shape, patch_shape_for_bb)
108            ]
109
110        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
111
112    def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first):
113        shape = raw.shape
114        if have_raw_channels and channel_first:
115            shape = shape[1:]
116        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
117            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
118
119            if have_raw_channels and channel_first:
120                pw_raw = [(0, 0), *pw]
121            elif have_raw_channels and not channel_first:
122                pw_raw = [*pw, (0, 0)]
123            else:
124                pw_raw = pw
125
126            # TODO: ensure padding for labels with channels, when supported (see `_get_sample` below)
127
128            raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw)
129        return raw, labels
130
131    def _load_data(self, raw_path, label_path):
132        raw = load_image(raw_path, memmap=False)
133        label = load_image(label_path, memmap=False)
134
135        have_raw_channels = raw.ndim == 3
136        have_label_channels = label.ndim == 3
137        if have_label_channels:
138            raise NotImplementedError("Multi-channel labels are not supported.")
139
140        # We determine if the image has channels as the first or last axis based on the array shape.
141        # This will work only for images with less than 16 channels!
142        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
143        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
144        channel_first = None
145        if have_raw_channels:
146            channel_first = raw.shape[-1] > 16
147
148        if self.patch_shape is not None:
149            raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first)
150        shape = raw.shape
151
152        prefix_box = tuple()
153        if have_raw_channels:
154            if channel_first:
155                shape = shape[1:]
156                prefix_box = (slice(None), )
157            else:
158                shape = shape[:-1]
159
160        return raw, label, shape, prefix_box, have_raw_channels
161
162    def _get_sample(self, index):
163        if self.sample_random_index:
164            index = np.random.randint(0, len(self.raw_images))
165
166        # The filepath corresponding to this image.
167        raw_path, label_path = self.raw_images[index], self.label_images[index]
168
169        # Load the corresponding data.
170        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
171
172        # Sample random bounding box for this image.
173        bb = self._sample_bounding_box(shape)
174        raw_patch = np.array(raw[prefix_box + bb])
175        label_patch = np.array(label[bb])
176
177        if self.sampler is not None:
178            sample_id = 0
179            while not self.sampler(raw_patch, label_patch):
180                bb = self._sample_bounding_box(shape)
181                raw_patch = np.array(raw[prefix_box + bb])
182                label_patch = np.array(label[bb])
183                sample_id += 1
184
185                # We need to avoid sampling from the same image over and over agagin,
186                # otherwise this will fail just because of one or a few empty images.
187                # Hence we update the image from which we sample sometimes.
188                if sample_id % self.max_sampling_attempts_image == 0:
189                    index = np.random.randint(0, len(self.raw_images))
190                    raw_path, label_path = self.raw_images[index], self.label_images[index]
191                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
192
193                if sample_id > self.max_sampling_attempts:
194                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
195
196        # to channel first
197        if have_raw_channels and len(prefix_box) == 0:
198            raw_patch = raw_patch.transpose((2, 0, 1))
199
200        return raw_patch, label_patch
201
202    def __getitem__(self, index):
203        raw, labels = self._get_sample(index)
204        initial_label_dtype = labels.dtype
205
206        if self.raw_transform is not None:
207            raw = self.raw_transform(raw)
208
209        if self.label_transform is not None:
210            labels = self.label_transform(labels)
211
212        if self.transform is not None:
213            raw, labels = self.transform(raw, labels)
214            # if self.trafo_halo is not None:
215            #     raw = self.crop(raw)
216            #     labels = self.crop(labels)
217
218        # support enlarging bounding box here as well (for affinity transform) ?
219        if self.label_transform2 is not None:
220            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
221            labels = self.label_transform2(labels)
222
223        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
224        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
225        return raw, labels

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite `__getitem__()`, which fetches a data sample for a given key. Subclasses may also optionally overwrite `__len__()`, which is expected to return the size of the dataset by many `torch.utils.data.Sampler` implementations and by the default options of `torch.utils.data.DataLoader`. Subclasses may also optionally implement `__getitems__()` to speed up batched sample loading; this method accepts a list of sample indices for a batch and returns a list of samples.

By default, `torch.utils.data.DataLoader` constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

ImageCollectionDataset( raw_image_paths: List[Union[str, os.PathLike]], label_image_paths: List[Union[str, os.PathLike]], patch_shape: Tuple[int, ...], raw_transform=None, label_transform=None, label_transform2=None, transform=None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler=None, full_check: bool = False)
52    def __init__(
53        self,
54        raw_image_paths: List[Union[str, os.PathLike]],
55        label_image_paths: List[Union[str, os.PathLike]],
56        patch_shape: Tuple[int, ...],
57        raw_transform=None,
58        label_transform=None,
59        label_transform2=None,
60        transform=None,
61        dtype: torch.dtype = torch.float32,
62        label_dtype: torch.dtype = torch.float32,
63        n_samples: Optional[int] = None,
64        sampler=None,
65        full_check: bool = False,
66    ):
67        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
68        self.raw_images = raw_image_paths
69        self.label_images = label_image_paths
70        self._ndim = 2
71
72        if patch_shape is not None:
73            assert len(patch_shape) == self._ndim
74        self.patch_shape = patch_shape
75
76        self.raw_transform = raw_transform
77        self.label_transform = label_transform
78        self.label_transform2 = label_transform2
79        self.transform = transform
80        self.sampler = sampler
81
82        self.dtype = dtype
83        self.label_dtype = label_dtype
84
85        if n_samples is None:
86            self._len = len(self.raw_images)
87            self.sample_random_index = False
88        else:
89            self._len = n_samples
90            self.sample_random_index = True
max_sampling_attempts = 500
max_sampling_attempts_image = 50
raw_images
label_images
patch_shape
raw_transform
label_transform
label_transform2
transform
sampler
dtype
label_dtype
ndim