torch_em.data.image_collection_dataset

  1import os
  2import numpy as np
  3from typing import List, Optional, Tuple, Union, Callable
  4
  5import torch
  6
  7from ..util import (
  8    ensure_spatial_array, ensure_tensor_with_channels, load_image, supports_memmap, ensure_patch_shape
  9)
 10
 11
 12class ImageCollectionDataset(torch.utils.data.Dataset):
 13    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
 14
 15    The dataset returns patches loaded from the images and labels as sample for a batch.
 16    The raw data and labels are expected to be images of the same shape, except for possible channels.
 17    It supports all file formats that can be loaded with the imageio or tiffile library, such as tif, png or jpeg files.
 18
 19    Args:
 20        raw_image_paths: The file paths to the raw data.
 21        label_image_paths: The file path to the label data.
 22        patch_shape: The patch shape for a training sample.
 23        raw_transform: Transformation applied to the raw data of a sample.
 24        label_transform: Transformation applied to the label data of a sample,
 25            before applying augmentations via `transform`.
 26        label_transform2: Transformation applied to the label data of a sample,
 27            after applying augmentations via `transform`.
 28        transform: Transformation applied to both the raw data and label data of a sample.
 29            This can be used to implement data augmentations.
 30        dtype: The return data type of the raw data.
 31        label_dtype: The return data type of the label data.
 32        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 33        sampler: Sampler for rejecting samples according to a defined criterion.
 34            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
 35        full_check: Whether to check that the input data is valid for all image paths.
 36            This will ensure that the data is valid, but will take longer for creating the dataset.
 37        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 38    """
 39    max_sampling_attempts = 500
 40    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 41    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 42    """
 43    max_sampling_attempts_image = 50
 44    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
 45    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 46    """
 47
 48    def _check_inputs(self, raw_images, label_images, full_check):
 49        if len(raw_images) != len(label_images):
 50            raise ValueError(f"Expect same number of  and label images, got {len(raw_images)} and {len(label_images)}")
 51
 52        if not full_check:
 53            return
 54
 55        is_multichan = None
 56        for raw_im, label_im in zip(raw_images, label_images):
 57
 58            # we only check for compatible shapes if both images support memmap, because
 59            # we don't want to load everything into ram
 60            if supports_memmap(raw_im) and supports_memmap(label_im):
 61                shape = load_image(raw_im).shape
 62                assert len(shape) in (2, 3)
 63
 64                multichan = len(shape) == 3
 65                if is_multichan is None:
 66                    is_multichan = multichan
 67                else:
 68                    assert is_multichan == multichan
 69
 70                if is_multichan:
 71                    # use heuristic to decide whether the data is stored in channel last or channel first order:
 72                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
 73                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
 74                    if shape[-1] < 16:
 75                        shape = shape[:-1]
 76                    else:
 77                        shape = shape[1:]
 78
 79                label_shape = load_image(label_im).shape
 80                if shape != label_shape:
 81                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
 82                    raise ValueError(msg)
 83
 84    def __init__(
 85        self,
 86        raw_image_paths: List[Union[str, os.PathLike]],
 87        label_image_paths: List[Union[str, os.PathLike]],
 88        patch_shape: Tuple[int, ...],
 89        raw_transform: Optional[Callable] = None,
 90        label_transform: Optional[Callable] = None,
 91        label_transform2: Optional[Callable] = None,
 92        transform: Optional[Callable] = None,
 93        dtype: torch.dtype = torch.float32,
 94        label_dtype: torch.dtype = torch.float32,
 95        n_samples: Optional[int] = None,
 96        sampler: Optional[Callable] = None,
 97        full_check: bool = False,
 98        with_padding: bool = True,
 99    ):
100        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
101        self.raw_images = raw_image_paths
102        self.label_images = label_image_paths
103        self._ndim = 2
104
105        if patch_shape is not None:
106            assert len(patch_shape) == self._ndim
107        self.patch_shape = patch_shape
108
109        self.raw_transform = raw_transform
110        self.label_transform = label_transform
111        self.label_transform2 = label_transform2
112        self.transform = transform
113        self.sampler = sampler
114        self.with_padding = with_padding
115
116        self.dtype = dtype
117        self.label_dtype = label_dtype
118
119        if n_samples is None:
120            self._len = len(self.raw_images)
121            self.sample_random_index = False
122        else:
123            self._len = n_samples
124            self.sample_random_index = True
125
126    def __len__(self):
127        return self._len
128
129    @property
130    def ndim(self):
131        return self._ndim
132
133    def _sample_bounding_box(self, shape):
134        if self.patch_shape is None:
135            patch_shape_for_bb = shape
136            bb_start = [0] * len(shape)
137        else:
138            patch_shape_for_bb = self.patch_shape
139            bb_start = [
140                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
141                for sh, psh in zip(shape, patch_shape_for_bb)
142            ]
143
144        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
145
146    def _load_data(self, raw_path, label_path):
147        raw = load_image(raw_path, memmap=False)
148        label = load_image(label_path, memmap=False)
149
150        have_raw_channels = raw.ndim == 3
151        have_label_channels = label.ndim == 3
152        if have_label_channels:
153            raise NotImplementedError("Multi-channel labels are not supported.")
154
155        # We determine if the image has channels as the first or last axis based on the array shape.
156        # This will work only for images with less than 16 channels!
157        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
158        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
159        channel_first = None
160        if have_raw_channels:
161            channel_first = raw.shape[-1] > 16
162
163        if self.patch_shape is not None and self.with_padding:
164            raw, label = ensure_patch_shape(
165                raw=raw,
166                labels=label,
167                patch_shape=self.patch_shape,
168                have_raw_channels=have_raw_channels,
169                have_label_channels=have_label_channels,
170                channel_first=channel_first
171            )
172
173        shape = raw.shape
174
175        prefix_box = tuple()
176        if have_raw_channels:
177            if channel_first:
178                shape = shape[1:]
179                prefix_box = (slice(None), )
180            else:
181                shape = shape[:-1]
182
183        return raw, label, shape, prefix_box, have_raw_channels
184
185    def _get_sample(self, index):
186        if self.sample_random_index:
187            index = np.random.randint(0, len(self.raw_images))
188
189        # The filepath corresponding to this image.
190        raw_path, label_path = self.raw_images[index], self.label_images[index]
191
192        # Load the corresponding data.
193        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
194
195        # Sample random bounding box for this image.
196        bb = self._sample_bounding_box(shape)
197        raw_patch = np.array(raw[prefix_box + bb])
198        label_patch = np.array(label[bb])
199
200        if self.sampler is not None:
201            sample_id = 0
202            while not self.sampler(raw_patch, label_patch):
203                bb = self._sample_bounding_box(shape)
204                raw_patch = np.array(raw[prefix_box + bb])
205                label_patch = np.array(label[bb])
206                sample_id += 1
207
208                # We need to avoid sampling from the same image over and over again,
209                # otherwise this will fail just because of one or a few empty images.
210                # Hence we update the image from which we sample sometimes.
211                if sample_id % self.max_sampling_attempts_image == 0:
212                    index = np.random.randint(0, len(self.raw_images))
213                    raw_path, label_path = self.raw_images[index], self.label_images[index]
214                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
215
216                if sample_id > self.max_sampling_attempts:
217                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
218
219        # to channel first
220        if have_raw_channels and len(prefix_box) == 0:
221            raw_patch = raw_patch.transpose((2, 0, 1))
222
223        return raw_patch, label_patch
224
225    def __getitem__(self, index):
226        raw, labels = self._get_sample(index)
227        initial_label_dtype = labels.dtype
228
229        if self.raw_transform is not None:
230            raw = self.raw_transform(raw)
231
232        if self.label_transform is not None:
233            labels = self.label_transform(labels)
234
235        if self.transform is not None:
236            raw, labels = self.transform(raw, labels)
237            # if self.trafo_halo is not None:
238            #     raw = self.crop(raw)
239            #     labels = self.crop(labels)
240
241        # support enlarging bounding box here as well (for affinity transform) ?
242        if self.label_transform2 is not None:
243            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
244            labels = self.label_transform2(labels)
245
246        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
247        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
248        return raw, labels
class ImageCollectionDataset(typing.Generic[+_T_co]):
 13class ImageCollectionDataset(torch.utils.data.Dataset):
 14    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
 15
 16    The dataset returns patches loaded from the images and labels as sample for a batch.
 17    The raw data and labels are expected to be images of the same shape, except for possible channels.
 18    It supports all file formats that can be loaded with the imageio or tiffile library, such as tif, png or jpeg files.
 19
 20    Args:
 21        raw_image_paths: The file paths to the raw data.
 22        label_image_paths: The file path to the label data.
 23        patch_shape: The patch shape for a training sample.
 24        raw_transform: Transformation applied to the raw data of a sample.
 25        label_transform: Transformation applied to the label data of a sample,
 26            before applying augmentations via `transform`.
 27        label_transform2: Transformation applied to the label data of a sample,
 28            after applying augmentations via `transform`.
 29        transform: Transformation applied to both the raw data and label data of a sample.
 30            This can be used to implement data augmentations.
 31        dtype: The return data type of the raw data.
 32        label_dtype: The return data type of the label data.
 33        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 34        sampler: Sampler for rejecting samples according to a defined criterion.
 35            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
 36        full_check: Whether to check that the input data is valid for all image paths.
 37            This will ensure that the data is valid, but will take longer for creating the dataset.
 38        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 39    """
 40    max_sampling_attempts = 500
 41    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 42    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 43    """
 44    max_sampling_attempts_image = 50
 45    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
 46    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 47    """
 48
 49    def _check_inputs(self, raw_images, label_images, full_check):
 50        if len(raw_images) != len(label_images):
 51            raise ValueError(f"Expect same number of  and label images, got {len(raw_images)} and {len(label_images)}")
 52
 53        if not full_check:
 54            return
 55
 56        is_multichan = None
 57        for raw_im, label_im in zip(raw_images, label_images):
 58
 59            # we only check for compatible shapes if both images support memmap, because
 60            # we don't want to load everything into ram
 61            if supports_memmap(raw_im) and supports_memmap(label_im):
 62                shape = load_image(raw_im).shape
 63                assert len(shape) in (2, 3)
 64
 65                multichan = len(shape) == 3
 66                if is_multichan is None:
 67                    is_multichan = multichan
 68                else:
 69                    assert is_multichan == multichan
 70
 71                if is_multichan:
 72                    # use heuristic to decide whether the data is stored in channel last or channel first order:
 73                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
 74                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
 75                    if shape[-1] < 16:
 76                        shape = shape[:-1]
 77                    else:
 78                        shape = shape[1:]
 79
 80                label_shape = load_image(label_im).shape
 81                if shape != label_shape:
 82                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
 83                    raise ValueError(msg)
 84
 85    def __init__(
 86        self,
 87        raw_image_paths: List[Union[str, os.PathLike]],
 88        label_image_paths: List[Union[str, os.PathLike]],
 89        patch_shape: Tuple[int, ...],
 90        raw_transform: Optional[Callable] = None,
 91        label_transform: Optional[Callable] = None,
 92        label_transform2: Optional[Callable] = None,
 93        transform: Optional[Callable] = None,
 94        dtype: torch.dtype = torch.float32,
 95        label_dtype: torch.dtype = torch.float32,
 96        n_samples: Optional[int] = None,
 97        sampler: Optional[Callable] = None,
 98        full_check: bool = False,
 99        with_padding: bool = True,
100    ):
101        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
102        self.raw_images = raw_image_paths
103        self.label_images = label_image_paths
104        self._ndim = 2
105
106        if patch_shape is not None:
107            assert len(patch_shape) == self._ndim
108        self.patch_shape = patch_shape
109
110        self.raw_transform = raw_transform
111        self.label_transform = label_transform
112        self.label_transform2 = label_transform2
113        self.transform = transform
114        self.sampler = sampler
115        self.with_padding = with_padding
116
117        self.dtype = dtype
118        self.label_dtype = label_dtype
119
120        if n_samples is None:
121            self._len = len(self.raw_images)
122            self.sample_random_index = False
123        else:
124            self._len = n_samples
125            self.sample_random_index = True
126
127    def __len__(self):
128        return self._len
129
130    @property
131    def ndim(self):
132        return self._ndim
133
134    def _sample_bounding_box(self, shape):
135        if self.patch_shape is None:
136            patch_shape_for_bb = shape
137            bb_start = [0] * len(shape)
138        else:
139            patch_shape_for_bb = self.patch_shape
140            bb_start = [
141                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
142                for sh, psh in zip(shape, patch_shape_for_bb)
143            ]
144
145        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
146
147    def _load_data(self, raw_path, label_path):
148        raw = load_image(raw_path, memmap=False)
149        label = load_image(label_path, memmap=False)
150
151        have_raw_channels = raw.ndim == 3
152        have_label_channels = label.ndim == 3
153        if have_label_channels:
154            raise NotImplementedError("Multi-channel labels are not supported.")
155
156        # We determine if the image has channels as the first or last axis based on the array shape.
157        # This will work only for images with less than 16 channels!
158        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
159        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
160        channel_first = None
161        if have_raw_channels:
162            channel_first = raw.shape[-1] > 16
163
164        if self.patch_shape is not None and self.with_padding:
165            raw, label = ensure_patch_shape(
166                raw=raw,
167                labels=label,
168                patch_shape=self.patch_shape,
169                have_raw_channels=have_raw_channels,
170                have_label_channels=have_label_channels,
171                channel_first=channel_first
172            )
173
174        shape = raw.shape
175
176        prefix_box = tuple()
177        if have_raw_channels:
178            if channel_first:
179                shape = shape[1:]
180                prefix_box = (slice(None), )
181            else:
182                shape = shape[:-1]
183
184        return raw, label, shape, prefix_box, have_raw_channels
185
186    def _get_sample(self, index):
187        if self.sample_random_index:
188            index = np.random.randint(0, len(self.raw_images))
189
190        # The filepath corresponding to this image.
191        raw_path, label_path = self.raw_images[index], self.label_images[index]
192
193        # Load the corresponding data.
194        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
195
196        # Sample random bounding box for this image.
197        bb = self._sample_bounding_box(shape)
198        raw_patch = np.array(raw[prefix_box + bb])
199        label_patch = np.array(label[bb])
200
201        if self.sampler is not None:
202            sample_id = 0
203            while not self.sampler(raw_patch, label_patch):
204                bb = self._sample_bounding_box(shape)
205                raw_patch = np.array(raw[prefix_box + bb])
206                label_patch = np.array(label[bb])
207                sample_id += 1
208
209                # We need to avoid sampling from the same image over and over again,
210                # otherwise this will fail just because of one or a few empty images.
211                # Hence we update the image from which we sample sometimes.
212                if sample_id % self.max_sampling_attempts_image == 0:
213                    index = np.random.randint(0, len(self.raw_images))
214                    raw_path, label_path = self.raw_images[index], self.label_images[index]
215                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
216
217                if sample_id > self.max_sampling_attempts:
218                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
219
220        # to channel first
221        if have_raw_channels and len(prefix_box) == 0:
222            raw_patch = raw_patch.transpose((2, 0, 1))
223
224        return raw_patch, label_patch
225
226    def __getitem__(self, index):
227        raw, labels = self._get_sample(index)
228        initial_label_dtype = labels.dtype
229
230        if self.raw_transform is not None:
231            raw = self.raw_transform(raw)
232
233        if self.label_transform is not None:
234            labels = self.label_transform(labels)
235
236        if self.transform is not None:
237            raw, labels = self.transform(raw, labels)
238            # if self.trafo_halo is not None:
239            #     raw = self.crop(raw)
240            #     labels = self.crop(labels)
241
242        # support enlarging bounding box here as well (for affinity transform) ?
243        if self.label_transform2 is not None:
244            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
245            labels = self.label_transform2(labels)
246
247        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
248        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
249        return raw, labels

Dataset that provides raw data and labels stored in a regular image data format for segmentation training.

The dataset returns patches loaded from the images and labels as samples for a batch. The raw data and labels are expected to be images of the same shape, except for possible channels. It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

Arguments:
  • raw_image_paths: The file paths to the raw data.
  • label_image_paths: The file paths to the label data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • n_samples: The length of this dataset. If None, the length will be set to len(raw_image_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
  • full_check: Whether to check that the input data is valid for all image paths. This will ensure that the data is valid, but will take longer for creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
ImageCollectionDataset( raw_image_paths: List[Union[str, os.PathLike]], label_image_paths: List[Union[str, os.PathLike]], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, full_check: bool = False, with_padding: bool = True)
    def __init__(
        self,
        raw_image_paths: List[Union[str, os.PathLike]],
        label_image_paths: List[Union[str, os.PathLike]],
        patch_shape: Tuple[int, ...],
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        label_transform2: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        full_check: bool = False,
        with_padding: bool = True,
    ) -> None:
        """Validate the image paths and store the dataset configuration."""
        # Raises a ValueError if the number of raw and label paths differs;
        # with `full_check` the shapes of the image pairs are compared as well.
        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
        self.raw_images = raw_image_paths
        self.label_images = label_image_paths
        # The dimensionality is fixed to 2; `patch_shape` must have this many entries.
        self._ndim = 2

        # `patch_shape` may be None, in which case the length check is skipped.
        if patch_shape is not None:
            assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler
        self.with_padding = with_padding

        self.dtype = dtype
        self.label_dtype = label_dtype

        # Without `n_samples` the dataset length is the number of images and the requested
        # index is used directly. With `n_samples` the length is set explicitly and
        # `sample_random_index` makes `_get_sample` draw a random image index instead.
        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

max_sampling_attempts_image = 50

The maximal number of sampling attempts for a single image, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

raw_images
label_images
patch_shape
raw_transform
label_transform
label_transform2
transform
sampler
with_padding
dtype
label_dtype
ndim
    @property
    def ndim(self):
        """The dimensionality stored in `_ndim`, which `__init__` sets to 2."""
        return self._ndim