torch_em.data.raw_image_collection_dataset

  1import os
  2import numpy as np
  3from typing import List, Union, Tuple, Optional, Any, Callable
  4
  5import torch
  6
  7from ..util import ensure_tensor_with_channels, load_image, supports_memmap
  8
  9
 10class RawImageCollectionDataset(torch.utils.data.Dataset):
 11    """Dataset that provides raw data stored in a regular image data format for unsupervised training.
 12
 13    The dataset loads a patch the raw data and returns a sample for a batch.
 14    It supports all file formats that can be loaded with the imageio or tiffile library, such as tif, png or jpeg files.
 15
 16    The dataset can also be used for contrastive learning that relies on two different views of the same data.
 17    You can use the `augmentations` argument for this.
 18
 19    Args:
 20        raw_image_paths: The file paths to the raw data.
 21        patch_shape: The patch shape for a training sample.
 22        raw_transform: Transformation applied to the raw data of a sample.
 23        transform: Transformation to the raw data. This can be used to implement data augmentations.
 24        dtype: The return data type of the raw data.
 25        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 26        sampler: Sampler for rejecting samples according to a defined criterion.
 27            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 28        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
 29            They will be applied to the sampled raw data to return two independent views of the raw data.
 30        full_check: Whether to check that the input data is valid for all image paths.
 31            This will ensure that the data is valid, but will take longer for creating the dataset.
 32    """
 33    max_sampling_attempts = 500
 34    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 35    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 36    """
 37
 38    def _check_inputs(self, raw_images, full_check):
 39        if not full_check:
 40            return
 41
 42        is_multichan = None
 43        for raw_im in raw_images:
 44
 45            # we only check for compatible shapes if images support memmap, because
 46            # we don't want to load everything into ram
 47            if supports_memmap(raw_im):
 48                shape = load_image(raw_im).shape
 49                assert len(shape) in (2, 3)
 50
 51                multichan = len(shape) == 3
 52                if is_multichan is None:
 53                    is_multichan = multichan
 54                else:
 55                    assert is_multichan == multichan
 56
 57                # we assume axis last
 58                if is_multichan:
 59                    shape = shape[:-1]
 60
 61    def __init__(
 62        self,
 63        raw_image_paths: Union[List[Any], str, os.PathLike],
 64        patch_shape: Tuple[int, ...],
 65        raw_transform: Optional[Callable] = None,
 66        transform: Optional[Callable] = None,
 67        dtype: torch.dtype = torch.float32,
 68        n_samples: Optional[int] = None,
 69        sampler: Optional[Callable] = None,
 70        augmentations: Optional[Callable] = None,
 71        full_check: bool = False,
 72    ):
 73        self._check_inputs(raw_image_paths, full_check)
 74        self.raw_images = raw_image_paths
 75        self._ndim = 2
 76
 77        assert len(patch_shape) == self._ndim
 78        self.patch_shape = patch_shape
 79
 80        self.raw_transform = raw_transform
 81        self.transform = transform
 82        self.dtype = dtype
 83        self.sampler = sampler
 84
 85        if n_samples is None:
 86            self._len = len(self.raw_images)
 87            self.sample_random_index = False
 88        else:
 89            self._len = n_samples
 90            self.sample_random_index = True
 91
 92        if augmentations is not None:
 93            assert len(augmentations) == 2
 94        self.augmentations = augmentations
 95
 96    def __len__(self):
 97        return self._len
 98
 99    @property
100    def ndim(self):
101        return self._ndim
102
103    def _sample_bounding_box(self, shape):
104        bb_start = [
105            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, self.patch_shape)
106        ]
107        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
108
109    def _ensure_patch_shape(self, raw, have_raw_channels, channel_first):
110        shape = raw.shape
111        if have_raw_channels and channel_first:
112            shape = shape[1:]
113
114        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
115            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
116
117            if have_raw_channels and channel_first:
118                pw_raw = [(0, 0), *pw]
119            elif have_raw_channels and not channel_first:
120                pw_raw = [*pw, (0, 0)]
121            else:
122                pw_raw = pw
123
124            raw = np.pad(raw, pw_raw)
125        return raw
126
127    def _get_sample(self, index):
128        if self.sample_random_index:
129            index = np.random.randint(0, len(self.raw_images))
130
131        raw = load_image(self.raw_images[index])
132        have_raw_channels = raw.ndim == 3
133
134        # We determine if the image has channels as the first or last axis based on the array shape.
135        # This will work only for images with less than 16 channels!
136        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
137        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
138        channel_first = None
139        if have_raw_channels:
140            channel_first = raw.shape[-1] > 16
141
142        raw = self._ensure_patch_shape(raw, have_raw_channels, channel_first)
143
144        shape = raw.shape
145        # we assume images are loaded with channel last!
146        if have_raw_channels:
147            shape = shape[:-1]
148
149        # sample random bounding box for this image
150        bb = self._sample_bounding_box(shape)
151        raw = np.array(raw[bb])
152
153        if self.sampler is not None:
154            sample_id = 0
155            while not self.sampler(raw):
156                bb = self._sample_bounding_box(shape)
157                raw = np.array(raw[bb])
158                sample_id += 1
159                if sample_id > self.max_sampling_attempts:
160                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
161
162        # to channel first
163        if have_raw_channels:
164            raw = raw.transpose((2, 0, 1))
165
166        return raw
167
168    def __getitem__(self, index):
169        raw = self._get_sample(index)
170
171        if self.raw_transform is not None:
172            raw = self.raw_transform(raw)
173
174        if self.transform is not None:
175            raw = self.transform(raw)
176            assert len(raw) == 1
177            raw = raw[0]
178            # if self.trafo_halo is not None:
179            #     raw = self.crop(raw)
180
181        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
182        if self.augmentations is not None:
183            aug1, aug2 = self.augmentations
184            raw1, raw2 = aug1(raw), aug2(raw)
185            return raw1, raw2
186
187        return raw
class RawImageCollectionDataset(typing.Generic[+_T_co]):
 11class RawImageCollectionDataset(torch.utils.data.Dataset):
 12    """Dataset that provides raw data stored in a regular image data format for unsupervised training.
 13
 14    The dataset loads a patch the raw data and returns a sample for a batch.
 15    It supports all file formats that can be loaded with the imageio or tiffile library, such as tif, png or jpeg files.
 16
 17    The dataset can also be used for contrastive learning that relies on two different views of the same data.
 18    You can use the `augmentations` argument for this.
 19
 20    Args:
 21        raw_image_paths: The file paths to the raw data.
 22        patch_shape: The patch shape for a training sample.
 23        raw_transform: Transformation applied to the raw data of a sample.
 24        transform: Transformation to the raw data. This can be used to implement data augmentations.
 25        dtype: The return data type of the raw data.
 26        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
 27        sampler: Sampler for rejecting samples according to a defined criterion.
 28            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
 29        augmentations: Augmentations for contrastive learning. If given, these need to be two different callables.
 30            They will be applied to the sampled raw data to return two independent views of the raw data.
 31        full_check: Whether to check that the input data is valid for all image paths.
 32            This will ensure that the data is valid, but will take longer for creating the dataset.
 33    """
 34    max_sampling_attempts = 500
 35    """The maximal number of sampling attempts, for loading a sample via `__getitem__`.
 36    This is used when `sampler` rejects a sample, to avoid an infinite loop if no valid sample can be found.
 37    """
 38
 39    def _check_inputs(self, raw_images, full_check):
 40        if not full_check:
 41            return
 42
 43        is_multichan = None
 44        for raw_im in raw_images:
 45
 46            # we only check for compatible shapes if images support memmap, because
 47            # we don't want to load everything into ram
 48            if supports_memmap(raw_im):
 49                shape = load_image(raw_im).shape
 50                assert len(shape) in (2, 3)
 51
 52                multichan = len(shape) == 3
 53                if is_multichan is None:
 54                    is_multichan = multichan
 55                else:
 56                    assert is_multichan == multichan
 57
 58                # we assume axis last
 59                if is_multichan:
 60                    shape = shape[:-1]
 61
 62    def __init__(
 63        self,
 64        raw_image_paths: Union[List[Any], str, os.PathLike],
 65        patch_shape: Tuple[int, ...],
 66        raw_transform: Optional[Callable] = None,
 67        transform: Optional[Callable] = None,
 68        dtype: torch.dtype = torch.float32,
 69        n_samples: Optional[int] = None,
 70        sampler: Optional[Callable] = None,
 71        augmentations: Optional[Callable] = None,
 72        full_check: bool = False,
 73    ):
 74        self._check_inputs(raw_image_paths, full_check)
 75        self.raw_images = raw_image_paths
 76        self._ndim = 2
 77
 78        assert len(patch_shape) == self._ndim
 79        self.patch_shape = patch_shape
 80
 81        self.raw_transform = raw_transform
 82        self.transform = transform
 83        self.dtype = dtype
 84        self.sampler = sampler
 85
 86        if n_samples is None:
 87            self._len = len(self.raw_images)
 88            self.sample_random_index = False
 89        else:
 90            self._len = n_samples
 91            self.sample_random_index = True
 92
 93        if augmentations is not None:
 94            assert len(augmentations) == 2
 95        self.augmentations = augmentations
 96
 97    def __len__(self):
 98        return self._len
 99
100    @property
101    def ndim(self):
102        return self._ndim
103
104    def _sample_bounding_box(self, shape):
105        bb_start = [
106            np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, self.patch_shape)
107        ]
108        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
109
110    def _ensure_patch_shape(self, raw, have_raw_channels, channel_first):
111        shape = raw.shape
112        if have_raw_channels and channel_first:
113            shape = shape[1:]
114
115        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
116            pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
117
118            if have_raw_channels and channel_first:
119                pw_raw = [(0, 0), *pw]
120            elif have_raw_channels and not channel_first:
121                pw_raw = [*pw, (0, 0)]
122            else:
123                pw_raw = pw
124
125            raw = np.pad(raw, pw_raw)
126        return raw
127
128    def _get_sample(self, index):
129        if self.sample_random_index:
130            index = np.random.randint(0, len(self.raw_images))
131
132        raw = load_image(self.raw_images[index])
133        have_raw_channels = raw.ndim == 3
134
135        # We determine if the image has channels as the first or last axis based on the array shape.
136        # This will work only for images with less than 16 channels!
137        # If the last axis has a length smaller than 16 we assume that it is the channel axis,
138        # otherwise we assume it is a spatial axis and that the first axis is the channel axis.
139        channel_first = None
140        if have_raw_channels:
141            channel_first = raw.shape[-1] > 16
142
143        raw = self._ensure_patch_shape(raw, have_raw_channels, channel_first)
144
145        shape = raw.shape
146        # we assume images are loaded with channel last!
147        if have_raw_channels:
148            shape = shape[:-1]
149
150        # sample random bounding box for this image
151        bb = self._sample_bounding_box(shape)
152        raw = np.array(raw[bb])
153
154        if self.sampler is not None:
155            sample_id = 0
156            while not self.sampler(raw):
157                bb = self._sample_bounding_box(shape)
158                raw = np.array(raw[bb])
159                sample_id += 1
160                if sample_id > self.max_sampling_attempts:
161                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
162
163        # to channel first
164        if have_raw_channels:
165            raw = raw.transpose((2, 0, 1))
166
167        return raw
168
169    def __getitem__(self, index):
170        raw = self._get_sample(index)
171
172        if self.raw_transform is not None:
173            raw = self.raw_transform(raw)
174
175        if self.transform is not None:
176            raw = self.transform(raw)
177            assert len(raw) == 1
178            raw = raw[0]
179            # if self.trafo_halo is not None:
180            #     raw = self.crop(raw)
181
182        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
183        if self.augmentations is not None:
184            aug1, aug2 = self.augmentations
185            raw1, raw2 = aug1(raw), aug2(raw)
186            return raw1, raw2
187
188        return raw

Dataset that provides raw data stored in a regular image data format for unsupervised training.

The dataset loads a patch from the raw data and returns a sample for a batch. It supports all file formats that can be loaded with the imageio or tifffile library, such as tif, png or jpeg files.

The dataset can also be used for contrastive learning that relies on two different views of the same data. You can use the augmentations argument for this.

Arguments:
  • raw_image_paths: The file paths to the raw data.
  • patch_shape: The patch shape for a training sample.
  • raw_transform: Transformation applied to the raw data of a sample.
  • transform: Transformation applied to the raw data. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • n_samples: The length of this dataset. If None, the length will be set to len(raw_image_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • augmentations: Augmentations for contrastive learning. If given, these need to be two different callables. They will be applied to the sampled raw data to return two independent views of the raw data.
  • full_check: Whether to check that the input data is valid for all image paths. This will ensure that the data is valid, but will take longer for creating the dataset.
RawImageCollectionDataset( raw_image_paths: Union[List[Any], str, os.PathLike], patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, augmentations: Optional[Callable] = None, full_check: bool = False)
62    def __init__(
63        self,
64        raw_image_paths: Union[List[Any], str, os.PathLike],
65        patch_shape: Tuple[int, ...],
66        raw_transform: Optional[Callable] = None,
67        transform: Optional[Callable] = None,
68        dtype: torch.dtype = torch.float32,
69        n_samples: Optional[int] = None,
70        sampler: Optional[Callable] = None,
71        augmentations: Optional[Callable] = None,
72        full_check: bool = False,
73    ):
74        self._check_inputs(raw_image_paths, full_check)
75        self.raw_images = raw_image_paths
76        self._ndim = 2
77
78        assert len(patch_shape) == self._ndim
79        self.patch_shape = patch_shape
80
81        self.raw_transform = raw_transform
82        self.transform = transform
83        self.dtype = dtype
84        self.sampler = sampler
85
86        if n_samples is None:
87            self._len = len(self.raw_images)
88            self.sample_random_index = False
89        else:
90            self._len = n_samples
91            self.sample_random_index = True
92
93        if augmentations is not None:
94            assert len(augmentations) == 2
95        self.augmentations = augmentations
max_sampling_attempts = 500

The maximal number of sampling attempts, for loading a sample via __getitem__. This is used when sampler rejects a sample, to avoid an infinite loop if no valid sample can be found.

raw_images
patch_shape
raw_transform
transform
dtype
sampler
augmentations
ndim
100    @property
101    def ndim(self):
102        return self._ndim