torch_em.data.image_collection_dataset

  1import os
  2import numpy as np
  3from typing import List, Optional, Tuple, Union, Callable
  4
  5import torch
  6
  7from ..util import (
  8    ensure_spatial_array, ensure_tensor_with_channels, load_image, supports_memmap, ensure_patch_shape
  9)
 10
 11
 12class ImageCollectionDataset(torch.utils.data.Dataset):
 13    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
 14
 15    The dataset returns patches loaded from the images and labels as sample for a batch.
 16    The raw data and labels are expected to be images of the same shape, except for possible channels.
 17    It supports all file formats that can be loaded with the imageio or tiffile library, such as tif, png or jpeg files.
 18
 19    Args:
 20        raw_image_paths: The file paths to the raw data.
 21        label_image_paths: The file path to the label data.
 22        patch_shape: The patch shape for a training sample.
 23        raw_transform: Transformation applied to the raw data of a sample.
 24        label_transform: Transformation applied to the label data of a sample,
 25            before applying augmentations via `transform`.
 26        label_transform2: Transformation applied to the label data of a sample,
 27            after applying augmentations via `transform`.
 28        transform: Transformation applied to both the raw data and label data of a sample.
 29            This can be used to implement data augmentations.
 30        dtype: The return data type of the raw data.
 31        label_dtype: The return data type of the label data.
 32        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 33        sampler: Sampler for rejecting samples according to a defined criterion.
 34            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
 35        full_check: Whether to check that the input data is valid for all image paths.
 36            This will ensure that the data is valid, but will take longer for creating the dataset.
 37        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 38    """
 39    max_sampling_attempts = 500
 40    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 41    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 42    """
 43    max_sampling_attempts_image = 50
 44    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
 45    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 46    """
 47
 48    def _check_inputs(self, raw_images, label_images, full_check):
 49        if len(raw_images) != len(label_images):
 50            raise ValueError(f"Expect same number of  and label images, got {len(raw_images)} and {len(label_images)}")
 51
 52        if not full_check:
 53            return
 54
 55        is_multichan = None
 56        for raw_im, label_im in zip(raw_images, label_images):
 57
 58            # we only check for compatible shapes if both images support memmap, because
 59            # we don't want to load everything into ram
 60            if supports_memmap(raw_im) and supports_memmap(label_im):
 61                shape = load_image(raw_im).shape
 62                assert len(shape) in (2, 3)
 63
 64                multichan = len(shape) == 3
 65                if is_multichan is None:
 66                    is_multichan = multichan
 67                else:
 68                    assert is_multichan == multichan
 69
 70                if is_multichan:
 71                    # use heuristic to decide whether the data is stored in channel last or channel first order:
 72                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
 73                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
 74                    if shape[-1] < 16:
 75                        shape = shape[:-1]
 76                    else:
 77                        shape = shape[1:]
 78
 79                label_shape = load_image(label_im).shape
 80                if shape != label_shape:
 81                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
 82                    raise ValueError(msg)
 83
 84    def __init__(
 85        self,
 86        raw_image_paths: List[Union[str, os.PathLike]],
 87        label_image_paths: List[Union[str, os.PathLike]],
 88        patch_shape: Tuple[int, ...],
 89        raw_transform: Optional[Callable] = None,
 90        label_transform: Optional[Callable] = None,
 91        label_transform2: Optional[Callable] = None,
 92        transform: Optional[Callable] = None,
 93        dtype: torch.dtype = torch.float32,
 94        label_dtype: torch.dtype = torch.float32,
 95        n_samples: Optional[int] = None,
 96        sampler: Optional[Callable] = None,
 97        full_check: bool = False,
 98        with_padding: bool = True,
 99    ) -> None:
100        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
101        self.raw_images = raw_image_paths
102        self.label_images = label_image_paths
103        self._ndim = 2
104
105        if patch_shape is not None:
106            assert len(patch_shape) == self._ndim
107        self.patch_shape = patch_shape
108
109        self.raw_transform = raw_transform
110        self.label_transform = label_transform
111        self.label_transform2 = label_transform2
112        self.transform = transform
113        self.sampler = sampler
114        self.with_padding = with_padding
115
116        self.dtype = dtype
117        self.label_dtype = label_dtype
118
119        if n_samples is None:
120            self._len = len(self.raw_images)
121            self.sample_random_index = False
122        else:
123            self._len = n_samples
124            self.sample_random_index = True
125
126    def __len__(self):
127        return self._len
128
129    @property
130    def ndim(self):
131        return self._ndim
132
133    def _sample_bounding_box(self, shape):
134        if self.patch_shape is None:
135            patch_shape_for_bb = shape
136            bb_start = [0] * len(shape)
137        else:
138            patch_shape_for_bb = self.patch_shape
139            bb_start = [
140                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
141                for sh, psh in zip(shape, patch_shape_for_bb)
142            ]
143
144        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
145
146    def _load_data(self, raw_path, label_path):
147        if getattr(self, "have_tensor_data", False):
148            raw, label = raw_path, label_path
149        else:
150            raw = load_image(raw_path, memmap=False)
151            label = load_image(label_path, memmap=False)
152
153        have_raw_channels = getattr(self, "with_channels", raw.ndim == 3)
154        have_label_channels = getattr(self, "with_label_channels", label.ndim == 3)
155        if have_label_channels:
156            raise NotImplementedError("Multi-channel labels are not supported.")
157
158        # We determine if the image has channels as the first or last axis based on the array shape.
159        # This will work only for images with less than 16 channels!
160        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
161        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
162        channel_first = None
163        if have_raw_channels:
164            channel_first = raw.shape[-1] > 16
165
166        if self.patch_shape is not None and self.with_padding:
167            raw, label = ensure_patch_shape(
168                raw=raw,
169                labels=label,
170                patch_shape=self.patch_shape,
171                have_raw_channels=have_raw_channels,
172                have_label_channels=have_label_channels,
173                channel_first=channel_first
174            )
175
176        shape = raw.shape
177
178        prefix_box = tuple()
179        if have_raw_channels:
180            if channel_first:
181                shape = shape[1:]
182                prefix_box = (slice(None), )
183            else:
184                shape = shape[:-1]
185
186        return raw, label, shape, prefix_box, have_raw_channels
187
188    def _get_sample(self, index):
189        if self.sample_random_index:
190            index = np.random.randint(0, len(self.raw_images))
191
192        # The filepath corresponding to this image.
193        raw_path, label_path = self.raw_images[index], self.label_images[index]
194
195        # Load the corresponding data.
196        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
197
198        # Sample random bounding box for this image.
199        bb = self._sample_bounding_box(shape)
200        raw_patch = np.array(raw[prefix_box + bb])
201        label_patch = np.array(label[bb])
202
203        if self.sampler is not None:
204            sample_id = 0
205            while not self.sampler(raw_patch, label_patch):
206                bb = self._sample_bounding_box(shape)
207                raw_patch = np.array(raw[prefix_box + bb])
208                label_patch = np.array(label[bb])
209                sample_id += 1
210
211                # We need to avoid sampling from the same image over and over again,
212                # otherwise this will fail just because of one or a few empty images.
213                # Hence we update the image from which we sample sometimes.
214                if sample_id % self.max_sampling_attempts_image == 0:
215                    index = np.random.randint(0, len(self.raw_images))
216                    raw_path, label_path = self.raw_images[index], self.label_images[index]
217                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
218
219                if sample_id > self.max_sampling_attempts:
220                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
221
222        # to channel first
223        if have_raw_channels and len(prefix_box) == 0:
224            raw_patch = raw_patch.transpose((2, 0, 1))
225
226        return raw_patch, label_patch
227
228    def __getitem__(self, index):
229        raw, labels = self._get_sample(index)
230        initial_label_dtype = labels.dtype
231
232        if self.raw_transform is not None:
233            raw = self.raw_transform(raw)
234
235        if self.label_transform is not None:
236            labels = self.label_transform(labels)
237
238        if self.transform is not None:
239            raw, labels = self.transform(raw, labels)
240            # if self.trafo_halo is not None:
241            #     raw = self.crop(raw)
242            #     labels = self.crop(labels)
243
244        # support enlarging bounding box here as well (for affinity transform) ?
245        if self.label_transform2 is not None:
246            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
247            labels = self.label_transform2(labels)
248
249        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
250        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
251        return raw, labels
class ImageCollectionDataset(typing.Generic[+_T_co]):
 13class ImageCollectionDataset(torch.utils.data.Dataset):
 14    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
 15
 16    The dataset returns patches loaded from the images and labels as sample for a batch.
 17    The raw data and labels are expected to be images of the same shape, except for possible channels.
 18    It supports all file formats that can be loaded with the imageio or tiffile library, such as tif, png or jpeg files.
 19
 20    Args:
 21        raw_image_paths: The file paths to the raw data.
 22        label_image_paths: The file path to the label data.
 23        patch_shape: The patch shape for a training sample.
 24        raw_transform: Transformation applied to the raw data of a sample.
 25        label_transform: Transformation applied to the label data of a sample,
 26            before applying augmentations via `transform`.
 27        label_transform2: Transformation applied to the label data of a sample,
 28            after applying augmentations via `transform`.
 29        transform: Transformation applied to both the raw data and label data of a sample.
 30            This can be used to implement data augmentations.
 31        dtype: The return data type of the raw data.
 32        label_dtype: The return data type of the label data.
 33        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 34        sampler: Sampler for rejecting samples according to a defined criterion.
 35            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
 36        full_check: Whether to check that the input data is valid for all image paths.
 37            This will ensure that the data is valid, but will take longer for creating the dataset.
 38        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 39    """
 40    max_sampling_attempts = 500
 41    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 42    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 43    """
 44    max_sampling_attempts_image = 50
 45    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
 46    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 47    """
 48
 49    def _check_inputs(self, raw_images, label_images, full_check):
 50        if len(raw_images) != len(label_images):
 51            raise ValueError(f"Expect same number of  and label images, got {len(raw_images)} and {len(label_images)}")
 52
 53        if not full_check:
 54            return
 55
 56        is_multichan = None
 57        for raw_im, label_im in zip(raw_images, label_images):
 58
 59            # we only check for compatible shapes if both images support memmap, because
 60            # we don't want to load everything into ram
 61            if supports_memmap(raw_im) and supports_memmap(label_im):
 62                shape = load_image(raw_im).shape
 63                assert len(shape) in (2, 3)
 64
 65                multichan = len(shape) == 3
 66                if is_multichan is None:
 67                    is_multichan = multichan
 68                else:
 69                    assert is_multichan == multichan
 70
 71                if is_multichan:
 72                    # use heuristic to decide whether the data is stored in channel last or channel first order:
 73                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
 74                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
 75                    if shape[-1] < 16:
 76                        shape = shape[:-1]
 77                    else:
 78                        shape = shape[1:]
 79
 80                label_shape = load_image(label_im).shape
 81                if shape != label_shape:
 82                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
 83                    raise ValueError(msg)
 84
 85    def __init__(
 86        self,
 87        raw_image_paths: List[Union[str, os.PathLike]],
 88        label_image_paths: List[Union[str, os.PathLike]],
 89        patch_shape: Tuple[int, ...],
 90        raw_transform: Optional[Callable] = None,
 91        label_transform: Optional[Callable] = None,
 92        label_transform2: Optional[Callable] = None,
 93        transform: Optional[Callable] = None,
 94        dtype: torch.dtype = torch.float32,
 95        label_dtype: torch.dtype = torch.float32,
 96        n_samples: Optional[int] = None,
 97        sampler: Optional[Callable] = None,
 98        full_check: bool = False,
 99        with_padding: bool = True,
100    ) -> None:
101        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
102        self.raw_images = raw_image_paths
103        self.label_images = label_image_paths
104        self._ndim = 2
105
106        if patch_shape is not None:
107            assert len(patch_shape) == self._ndim
108        self.patch_shape = patch_shape
109
110        self.raw_transform = raw_transform
111        self.label_transform = label_transform
112        self.label_transform2 = label_transform2
113        self.transform = transform
114        self.sampler = sampler
115        self.with_padding = with_padding
116
117        self.dtype = dtype
118        self.label_dtype = label_dtype
119
120        if n_samples is None:
121            self._len = len(self.raw_images)
122            self.sample_random_index = False
123        else:
124            self._len = n_samples
125            self.sample_random_index = True
126
127    def __len__(self):
128        return self._len
129
130    @property
131    def ndim(self):
132        return self._ndim
133
134    def _sample_bounding_box(self, shape):
135        if self.patch_shape is None:
136            patch_shape_for_bb = shape
137            bb_start = [0] * len(shape)
138        else:
139            patch_shape_for_bb = self.patch_shape
140            bb_start = [
141                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
142                for sh, psh in zip(shape, patch_shape_for_bb)
143            ]
144
145        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
146
147    def _load_data(self, raw_path, label_path):
148        if getattr(self, "have_tensor_data", False):
149            raw, label = raw_path, label_path
150        else:
151            raw = load_image(raw_path, memmap=False)
152            label = load_image(label_path, memmap=False)
153
154        have_raw_channels = getattr(self, "with_channels", raw.ndim == 3)
155        have_label_channels = getattr(self, "with_label_channels", label.ndim == 3)
156        if have_label_channels:
157            raise NotImplementedError("Multi-channel labels are not supported.")
158
159        # We determine if the image has channels as the first or last axis based on the array shape.
160        # This will work only for images with less than 16 channels!
161        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
162        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
163        channel_first = None
164        if have_raw_channels:
165            channel_first = raw.shape[-1] > 16
166
167        if self.patch_shape is not None and self.with_padding:
168            raw, label = ensure_patch_shape(
169                raw=raw,
170                labels=label,
171                patch_shape=self.patch_shape,
172                have_raw_channels=have_raw_channels,
173                have_label_channels=have_label_channels,
174                channel_first=channel_first
175            )
176
177        shape = raw.shape
178
179        prefix_box = tuple()
180        if have_raw_channels:
181            if channel_first:
182                shape = shape[1:]
183                prefix_box = (slice(None), )
184            else:
185                shape = shape[:-1]
186
187        return raw, label, shape, prefix_box, have_raw_channels
188
189    def _get_sample(self, index):
190        if self.sample_random_index:
191            index = np.random.randint(0, len(self.raw_images))
192
193        # The filepath corresponding to this image.
194        raw_path, label_path = self.raw_images[index], self.label_images[index]
195
196        # Load the corresponding data.
197        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
198
199        # Sample random bounding box for this image.
200        bb = self._sample_bounding_box(shape)
201        raw_patch = np.array(raw[prefix_box + bb])
202        label_patch = np.array(label[bb])
203
204        if self.sampler is not None:
205            sample_id = 0
206            while not self.sampler(raw_patch, label_patch):
207                bb = self._sample_bounding_box(shape)
208                raw_patch = np.array(raw[prefix_box + bb])
209                label_patch = np.array(label[bb])
210                sample_id += 1
211
212                # We need to avoid sampling from the same image over and over again,
213                # otherwise this will fail just because of one or a few empty images.
214                # Hence we update the image from which we sample sometimes.
215                if sample_id % self.max_sampling_attempts_image == 0:
216                    index = np.random.randint(0, len(self.raw_images))
217                    raw_path, label_path = self.raw_images[index], self.label_images[index]
218                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
219
220                if sample_id > self.max_sampling_attempts:
221                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
222
223        # to channel first
224        if have_raw_channels and len(prefix_box) == 0:
225            raw_patch = raw_patch.transpose((2, 0, 1))
226
227        return raw_patch, label_patch
228
229    def __getitem__(self, index):
230        raw, labels = self._get_sample(index)
231        initial_label_dtype = labels.dtype
232
233        if self.raw_transform is not None:
234            raw = self.raw_transform(raw)
235
236        if self.label_transform is not None:
237            labels = self.label_transform(labels)
238
239        if self.transform is not None:
240            raw, labels = self.transform(raw, labels)
241            # if self.trafo_halo is not None:
242            #     raw = self.crop(raw)
243            #     labels = self.crop(labels)
244
245        # support enlarging bounding box here as well (for affinity transform) ?
246        if self.label_transform2 is not None:
247            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
248            labels = self.label_transform2(labels)
249
250        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
251        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
252        return raw, labels

Dataset that provides raw data and labels stored in a regular image data format for segmentation training.

The dataset returns patches loaded from the images and labels as samples for a batch. The raw data and labels are expected to be images of the same shape, except for possible channels. It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

Arguments:
  • raw_image_paths: The file paths to the raw data.
  • label_image_paths: The file paths to the label data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • n_samples: The length of this dataset. If None, the length will be set to len(raw_image_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
  • full_check: Whether to check that the input data is valid for all image paths. This will ensure that the data is valid, but will take longer for creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
ImageCollectionDataset( raw_image_paths: List[Union[str, os.PathLike]], label_image_paths: List[Union[str, os.PathLike]], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, full_check: bool = False, with_padding: bool = True)
 85    def __init__(
 86        self,
 87        raw_image_paths: List[Union[str, os.PathLike]],
 88        label_image_paths: List[Union[str, os.PathLike]],
 89        patch_shape: Tuple[int, ...],
 90        raw_transform: Optional[Callable] = None,
 91        label_transform: Optional[Callable] = None,
 92        label_transform2: Optional[Callable] = None,
 93        transform: Optional[Callable] = None,
 94        dtype: torch.dtype = torch.float32,
 95        label_dtype: torch.dtype = torch.float32,
 96        n_samples: Optional[int] = None,
 97        sampler: Optional[Callable] = None,
 98        full_check: bool = False,
 99        with_padding: bool = True,
100    ) -> None:
101        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
102        self.raw_images = raw_image_paths
103        self.label_images = label_image_paths
104        self._ndim = 2
105
106        if patch_shape is not None:
107            assert len(patch_shape) == self._ndim
108        self.patch_shape = patch_shape
109
110        self.raw_transform = raw_transform
111        self.label_transform = label_transform
112        self.label_transform2 = label_transform2
113        self.transform = transform
114        self.sampler = sampler
115        self.with_padding = with_padding
116
117        self.dtype = dtype
118        self.label_dtype = label_dtype
119
120        if n_samples is None:
121            self._len = len(self.raw_images)
122            self.sample_random_index = False
123        else:
124            self._len = n_samples
125            self.sample_random_index = True
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

max_sampling_attempts_image = 50

The maximal number of sampling attempts for a single image, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

raw_images
label_images
patch_shape
raw_transform
label_transform
label_transform2
transform
sampler
with_padding
dtype
label_dtype
ndim
    @property
    def ndim(self) -> int:
        """Return the dimensionality stored in `_ndim`, which `__init__` sets to 2."""
        return self._ndim