torch_em.data.segmentation_dataset

import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data


class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that provides patches of raw data and corresponding labels,
    stored in container file formats (e.g. hdf5, zarr or n5), for training
    segmentation networks.
    """
    max_sampling_attempts = 500

    @staticmethod
    def compute_len(shape, patch_shape):
        if patch_shape is None:
            return 1
        else:
            n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
            return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        label_path: Union[List[Any], str, os.PathLike],
        label_key: str,
        patch_shape: Tuple[int, ...],
        raw_transform=None,
        label_transform=None,
        label_transform2=None,
        transform=None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)
            self.labels = RoiWrapper(self.labels, (slice(None),) + roi) if self._with_label_channels else\
                RoiWrapper(self.labels, roi)

        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for a trafo halo: ask for a bigger bounding box before applying the trafo,
        # which is then cut away again. See the code below; but this needs to be properly tested.

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        if self.sample_shape is None:
            bb_start = [0] * len(self.shape)
            patch_shape_for_bb = self.shape
        else:
            bb_start = [
                np.random.randint(0, sh - psh) if sh - psh > 0 else 0
                for sh, psh in zip(self.shape, self.sample_shape)
            ]
            patch_shape_for_bb = self.sample_shape

        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

    def _get_sample(self, index):
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        raw, labels = self.raw[bb_raw], self.labels[bb_labels]

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                bb = self._sample_bounding_box()
                bb_raw = (slice(None),) + bb if self._with_channels else bb
                bb_labels = (slice(None),) + bb if self._with_label_channels else bb
                raw, labels = self.raw[bb_raw], self.labels[bb_labels]
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        # NOTE: self.inner_bb is only defined by the (currently disabled) trafo halo code above;
        # crop is only called from __getitem__ when a trafo halo is set.
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # TODO support enlarging the bounding box here as well (for the affinity transform)?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to override pickling to support h5py (open file handles cannot be pickled)
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        label_path, label_key = state["label_path"], state["label_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None

        try:
            labels = load_data(label_path, label_key)
            if roi is not None:
                labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
                    RoiWrapper(labels, roi)
            state["labels"] = labels
        except Exception:
            msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["labels"] = None

        self.__dict__.update(state)
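
Example usage (a minimal sketch, not part of the module source): it creates a small HDF5 container with hypothetical file name and keys, then samples a random patch. This assumes h5py is installed and that load_data can open HDF5 containers via path and key.

import h5py
import numpy as np
from torch_em.data.segmentation_dataset import SegmentationDataset

# create a hypothetical example container for this sketch
with h5py.File("example_data.h5", "w") as f:
    f.create_dataset("raw", data=np.random.rand(64, 128, 128).astype("float32"))
    f.create_dataset("labels", data=np.random.randint(0, 5, size=(64, 128, 128), dtype="uint32"))

ds = SegmentationDataset(
    raw_path="example_data.h5", raw_key="raw",
    label_path="example_data.h5", label_key="labels",
    patch_shape=(32, 64, 64),
)

raw, labels = ds[0]  # a randomly sampled patch; ensure_tensor_with_channels adds a channel axis
print(len(ds), raw.shape, labels.shape)  # expected: 8 samples and patches of shape (1, 32, 64, 64)
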
class SegmentationDataset(torch.utils.data.Dataset):
SegmentationDataset(raw_path: Union[List[Any], str, os.PathLike], raw_key: str, label_path: Union[List[Any], str, os.PathLike], label_key: str, patch_shape: Tuple[int, ...], raw_transform=None, label_transform=None, label_transform2=None, transform=None, roi: Optional[Union[slice, Tuple[slice, ...]]] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler=None, ndim: Optional[int] = None, with_channels: bool = False, with_label_channels: bool = False)
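
Because __getstate__ drops the open data handles and __setstate__ reloads them from raw_path / label_path, the dataset can be pickled, which multi-process data loading relies on. A sketch, reusing the hypothetical example_data.h5 from the example above:

import pickle

import torch
from torch_em.data.segmentation_dataset import SegmentationDataset

ds = SegmentationDataset(
    raw_path="example_data.h5", raw_key="raw",
    label_path="example_data.h5", label_key="labels",
    patch_shape=(32, 64, 64),
)

# the round-trip goes through __getstate__ / __setstate__ and reloads the data handles
ds2 = pickle.loads(pickle.dumps(ds))

loader = torch.utils.data.DataLoader(ds2, batch_size=2)
raw_batch, label_batch = next(iter(loader))
print(raw_batch.shape, label_batch.shape)  # both torch.Size([2, 1, 32, 64, 64])
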
max_sampling_attempts = 500
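
The sampler passed to the constructor can be any callable that takes the sampled raw and label patch and returns a bool; patches are redrawn until one is accepted, and after max_sampling_attempts rejections a RuntimeError is raised. A sketch with a hypothetical custom sampler that rejects patches without foreground labels, reusing the example container from above:

from torch_em.data.segmentation_dataset import SegmentationDataset

def require_foreground(raw, labels):
    # accept a patch only if at least 1% of its label pixels are foreground
    return (labels > 0).mean() > 0.01

ds = SegmentationDataset(
    raw_path="example_data.h5", raw_key="raw",
    label_path="example_data.h5", label_key="labels",
    patch_shape=(32, 64, 64),
    sampler=require_foreground,
)
raw, labels = ds[0]  # contains foreground, or a RuntimeError is raised after 500 attempts
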
@staticmethod
def compute_len(shape, patch_shape):
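
A worked example of the length heuristic (values are illustrative): the reported number of samples per epoch is the product of how often the patch shape fits into the data shape along each axis.

import numpy as np

shape, patch_shape = (64, 128, 128), (32, 64, 64)
n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
print(n_samples)  # 2 * 2 * 2 = 8
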
raw_path
raw_key
raw
label_path
label_key
labels
shape
roi
patch_shape
raw_transform
label_transform
label_transform2
transform
sampler
dtype
label_dtype
sample_shape
trafo_halo
ndim
def crop(self, tensor):