torch_em.data.raw_dataset

import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_tensor_with_channels, load_data


class RawDataset(torch.utils.data.Dataset):
    """Dataset that provides raw data loaded from a file (e.g. hdf5, zarr or n5) and yields
    randomly sampled patches of size patch_shape, for example for unsupervised training.
    """
    max_sampling_attempts = 500

    @staticmethod
    def compute_len(shape, patch_shape):
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        patch_shape: Tuple[int, ...],
        raw_transform=None,
        transform=None,
        roi: Optional[dict] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations=None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self._with_channels = with_channels

        if roi is not None:
            if isinstance(roi, slice):
                roi = (roi,)
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
        self.roi = roi

        self._ndim = len(self.shape) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.sampler = sampler
        self.dtype = dtype

        if augmentations is not None:
            assert len(augmentations) == 2
        self.augmentations = augmentations

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See the code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self):
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _get_sample(self, index):
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw):
                bb = self._sample_bounding_box()
                raw = self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)

        return raw

    def crop(self, tensor):
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw

    # need to override pickle to support h5py
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["raw"]
        return state

    def __setstate__(self, state):
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None
        self.__dict__.update(state)
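
A minimal usage sketch (the file data/raw.h5, the key "raw" and all shapes below are placeholder assumptions, not part of torch_em): construct the dataset on a 3d volume and wrap it in a standard PyTorch DataLoader.

# Minimal sketch; paths, keys and shapes are assumptions for illustration only.
import torch
from torch_em.data.raw_dataset import RawDataset

def normalize(raw):
    # simple stand-in for a proper raw transform (e.g. a standardization function)
    raw = raw.astype("float32")
    return (raw - raw.mean()) / (raw.std() + 1e-6)

dataset = RawDataset(
    raw_path="data/raw.h5",      # hypothetical hdf5 volume
    raw_key="raw",               # hypothetical internal dataset name
    patch_shape=(32, 256, 256),  # one random 3d patch per sample
    raw_transform=normalize,     # applied to every patch before it is converted to a tensor
    n_samples=100,               # overrides the length computed by compute_len
)

loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))       # float32 patches; a channel axis is added, typically shape (2, 1, 32, 256, 256)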
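
The sampler and augmentations arguments are plain callables: the sampler is called with the raw patch and must return a truthy value for the patch to be accepted (see _get_sample above), and augmentations must be a pair of callables, each mapping the output tensor to an augmented view, e.g. for self-supervised training. A sketch with hypothetical helpers (min_intensity_sampler and noisy_view are illustrative names, not part of torch_em):

# Sketch only; the helper functions, file name and threshold are assumptions.
import torch
from torch_em.data.raw_dataset import RawDataset

def min_intensity_sampler(raw, min_std=0.05):
    # reject nearly constant patches, e.g. empty background;
    # the dataset retries up to max_sampling_attempts times
    return raw.std() > min_std

def noisy_view(raw):
    # trivial augmentation: add Gaussian noise to the tensor
    return raw + 0.1 * torch.randn_like(raw)

dataset = RawDataset(
    raw_path="data/raw.h5", raw_key="raw",  # placeholder file and key
    patch_shape=(32, 256, 256),
    sampler=min_intensity_sampler,
    augmentations=(noisy_view, noisy_view),
)
view1, view2 = dataset[0]  # two augmented views of the same randomly sampled patch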