1import os
  2import numpy as np
  3from typing import List, Optional, Tuple, Union, Callable
  5import torch
  7from ..util import (
  8    ensure_spatial_array, ensure_tensor_with_channels, load_image, supports_memmap, ensure_patch_shape
 12class ImageCollectionDataset(
 13    """Dataset that provides raw data and labels stored in a regular image data format for segmentation training.
    The dataset returns patches loaded from the images and labels as samples for a batch.
 16    The raw data and labels are expected to be images of the same shape, except for possible channels.
    It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.
 19    Args:
 20        raw_image_paths: The file paths to the raw data.
        label_image_paths: The file paths to the label data.
 22        patch_shape: The patch shape for a training sample.
 23        raw_transform: Transformation applied to the raw data of a sample.
 24        label_transform: Transformation applied to the label data of a sample,
 25            before applying augmentations via `transform`.
 26        label_transform2: Transformation applied to the label data of a sample,
 27            after applying augmentations via `transform`.
 28        transform: Transformation applied to both the raw data and label data of a sample.
 29            This can be used to implement data augmentations.
 30        dtype: The return data type of the raw data.
 31        label_dtype: The return data type of the label data.
 32        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 33        sampler: Sampler for rejecting samples according to a defined criterion.
 34            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
 35        full_check: Whether to check that the input data is valid for all image paths.
 36            This will ensure that the data is valid, but will take longer for creating the dataset.
 37        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
 38    """
 39    max_sampling_attempts = 500
 40    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 41    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 42    """
 43    max_sampling_attempts_image = 50
 44    """The maximal number of sampling attempts for a single image, for loading a sample via `__getitem__`.
 45    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 46    """
 48    def _check_inputs(self, raw_images, label_images, full_check):
 49        if len(raw_images) != len(label_images):
 50            raise ValueError(f"Expect same number of  and label images, got {len(raw_images)} and {len(label_images)}")
 52        if not full_check:
 53            return
 55        is_multichan = None
 56        for raw_im, label_im in zip(raw_images, label_images):
 58            # we only check for compatible shapes if both images support memmap, because
 59            # we don't want to load everything into ram
 60            if supports_memmap(raw_im) and supports_memmap(label_im):
 61                shape = load_image(raw_im).shape
 62                assert len(shape) in (2, 3)
 64                multichan = len(shape) == 3
 65                if is_multichan is None:
 66                    is_multichan = multichan
 67                else:
 68                    assert is_multichan == multichan
 70                if is_multichan:
 71                    # use heuristic to decide whether the data is stored in channel last or channel first order:
 72                    # if the last axis has a length smaller than 16 we assume that it's the channel axis,
 73                    # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
 74                    if shape[-1] < 16:
 75                        shape = shape[:-1]
 76                    else:
 77                        shape = shape[1:]
 79                label_shape = load_image(label_im).shape
 80                if shape != label_shape:
 81                    msg = f"Expect raw and labels of same shape, got {shape}, {label_shape} for {raw_im}, {label_im}"
 82                    raise ValueError(msg)
 84    def __init__(
 85        self,
 86        raw_image_paths: List[Union[str, os.PathLike]],
 87        label_image_paths: List[Union[str, os.PathLike]],
 88        patch_shape: Tuple[int, ...],
 89        raw_transform: Optional[Callable] = None,
 90        label_transform: Optional[Callable] = None,
 91        label_transform2: Optional[Callable] = None,
 92        transform: Optional[Callable] = None,
 93        dtype: torch.dtype = torch.float32,
 94        label_dtype: torch.dtype = torch.float32,
 95        n_samples: Optional[int] = None,
 96        sampler: Optional[Callable] = None,
 97        full_check: bool = False,
 98        with_padding: bool = True,
 99    ):
100        self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check)
101        self.raw_images = raw_image_paths
102        self.label_images = label_image_paths
103        self._ndim = 2
105        if patch_shape is not None:
106            assert len(patch_shape) == self._ndim
107        self.patch_shape = patch_shape
109        self.raw_transform = raw_transform
110        self.label_transform = label_transform
111        self.label_transform2 = label_transform2
112        self.transform = transform
113        self.sampler = sampler
114        self.with_padding = with_padding
116        self.dtype = dtype
117        self.label_dtype = label_dtype
119        if n_samples is None:
120            self._len = len(self.raw_images)
121            self.sample_random_index = False
122        else:
123            self._len = n_samples
124            self.sample_random_index = True
    def __len__(self):
        """Return the length of the dataset: `n_samples` if it was given, otherwise the number of images."""
        return self._len
    @property
    def ndim(self):
        """The number of spatial dimensions of a sample; this is set to 2 in `__init__`."""
        return self._ndim
133    def _sample_bounding_box(self, shape):
134        if self.patch_shape is None:
135            patch_shape_for_bb = shape
136            bb_start = [0] * len(shape)
137        else:
138            patch_shape_for_bb = self.patch_shape
139            bb_start = [
140                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
141                for sh, psh in zip(shape, patch_shape_for_bb)
142            ]
144        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))
146    def _load_data(self, raw_path, label_path):
147        raw = load_image(raw_path, memmap=False)
148        label = load_image(label_path, memmap=False)
150        have_raw_channels = raw.ndim == 3
151        have_label_channels = label.ndim == 3
152        if have_label_channels:
153            raise NotImplementedError("Multi-channel labels are not supported.")
155        # We determine if the image has channels as the first or last axis based on the array shape.
156        # This will work only for images with less than 16 channels!
157        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
158        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
159        channel_first = None
160        if have_raw_channels:
161            channel_first = raw.shape[-1] > 16
163        if self.patch_shape is not None and self.with_padding:
164            raw, label = ensure_patch_shape(
165                raw=raw,
166                labels=label,
167                patch_shape=self.patch_shape,
168                have_raw_channels=have_raw_channels,
169                have_label_channels=have_label_channels,
170                channel_first=channel_first
171            )
173        shape = raw.shape
175        prefix_box = tuple()
176        if have_raw_channels:
177            if channel_first:
178                shape = shape[1:]
179                prefix_box = (slice(None), )
180            else:
181                shape = shape[:-1]
183        return raw, label, shape, prefix_box, have_raw_channels
185    def _get_sample(self, index):
186        if self.sample_random_index:
187            index = np.random.randint(0, len(self.raw_images))
189        # The filepath corresponding to this image.
190        raw_path, label_path = self.raw_images[index], self.label_images[index]
192        # Load the corresponding data.
193        raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
195        # Sample random bounding box for this image.
196        bb = self._sample_bounding_box(shape)
197        raw_patch = np.array(raw[prefix_box + bb])
198        label_patch = np.array(label[bb])
200        if self.sampler is not None:
201            sample_id = 0
202            while not self.sampler(raw_patch, label_patch):
203                bb = self._sample_bounding_box(shape)
204                raw_patch = np.array(raw[prefix_box + bb])
205                label_patch = np.array(label[bb])
206                sample_id += 1
208                # We need to avoid sampling from the same image over and over again,
209                # otherwise this will fail just because of one or a few empty images.
210                # Hence we update the image from which we sample sometimes.
211                if sample_id % self.max_sampling_attempts_image == 0:
212                    index = np.random.randint(0, len(self.raw_images))
213                    raw_path, label_path = self.raw_images[index], self.label_images[index]
214                    raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
216                if sample_id > self.max_sampling_attempts:
217                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
219        # to channel first
220        if have_raw_channels and len(prefix_box) == 0:
221            raw_patch = raw_patch.transpose((2, 0, 1))
223        return raw_patch, label_patch
225    def __getitem__(self, index):
226        raw, labels = self._get_sample(index)
227        initial_label_dtype = labels.dtype
229        if self.raw_transform is not None:
230            raw = self.raw_transform(raw)
232        if self.label_transform is not None:
233            labels = self.label_transform(labels)
235        if self.transform is not None:
236            raw, labels = self.transform(raw, labels)
237            # if self.trafo_halo is not None:
238            #     raw = self.crop(raw)
239            #     labels = self.crop(labels)
241        # support enlarging bounding box here as well (for affinity transform) ?
242        if self.label_transform2 is not None:
243            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
244            labels = self.label_transform2(labels)
246        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
247        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
248        return raw, labels
