torch_em.data.raw_dataset
import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_tensor_with_channels, load_data


class RawDataset(torch.utils.data.Dataset):
    """Dataset that returns randomly positioned patches cut from a single raw-data array.

    The data is opened with ``load_data(raw_path, raw_key)``. Every call to ``__getitem__`` draws a
    new random bounding box, so the index is not used to select a patch and ``len(dataset)`` only
    sets how many patches make up one epoch.

    Args:
        raw_path: The file path(s) of the raw data, as accepted by ``load_data``.
        raw_key: The key of the raw data inside the file, as accepted by ``load_data``.
        patch_shape: The shape of the patches cut out of the data. It may have one more axis than
            ``ndim``; that extra axis is expected to be a singleton and is squeezed away.
        raw_transform: Optional callable applied to each patch first.
        transform: Optional callable applied after ``raw_transform``. If it returns a list,
            the list must hold exactly one element.
        roi: Optional region of interest over the non-channel axes, given as a slice or a
            sequence of slices.
        dtype: The torch dtype of the returned tensors.
        n_samples: The number of samples per epoch. By default the data size divided by the
            patch size (product over all axes, rounded down).
        sampler: Optional predicate called on each candidate patch; patches are redrawn until it
            returns a truthy value, giving up after ``max_sampling_attempts`` re-draws.
        ndim: The number of spatial dimensions (must be 2, 3 or 4). By default the number of data
            axes, excluding the channel axis if ``with_channels`` is set.
        with_channels: Whether the first data axis is a channel axis that is never cropped.
        augmentations: Optional pair of callables ``(aug1, aug2)``. If given, each item is the
            tuple ``(aug1(patch), aug2(patch))`` instead of a single tensor.
    """
    # Maximal number of re-draws when `sampler` keeps rejecting patches before an error is raised.
    max_sampling_attempts = 500

    @staticmethod
    def compute_len(shape, patch_shape):
        """Return how many patches of ``patch_shape`` fit into ``shape`` by volume, rounded down.

        Axes are paired from the front; surplus axes of the longer argument are ignored.
        """
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        patch_shape: Tuple[int, ...],
        raw_transform=None,
        transform=None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations=None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self._with_channels = with_channels

        if roi is not None:
            # Accept a single slice or any sequence of slices. Store the roi as a tuple so that it
            # can be prefixed with the (uncropped) channel axis.
            roi = (roi,) if isinstance(roi, slice) else tuple(roi)
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

        # Shape of the data without the channel axis (if there is one).
        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
        self.roi = roi

        self._ndim = len(self.shape) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.sampler = sampler
        self.dtype = dtype

        if augmentations is not None:
            assert len(augmentations) == 2
        # Always set (also when None): __getitem__ reads this attribute unconditionally.
        self.augmentations = augmentations

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The number of spatial dimensions of the dataset."""
        return self._ndim

    def _sample_bounding_box(self):
        """Draw a random bounding box of ``self.sample_shape`` within ``self.shape``.

        Each start coordinate is drawn uniformly from ``0 .. sh - psh`` (inclusive), so patches can
        reach the upper border of the data. Axes where the patch is not smaller than the data start at 0.
        """
        bb_start = [
            # randint's upper bound is exclusive: +1 makes the last valid start `sh - psh` reachable.
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _read_patch(self, bb):
        """Read the bounding box ``bb`` from the data, keeping the full channel axis if there is one."""
        return self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]

    def _get_sample(self, index):
        """Read a random patch that is accepted by ``self.sampler`` (if set).

        ``index`` is ignored; every call draws a new random position.

        Raises:
            RuntimeError: If the dataset was not deserialized correctly, or if no accepted patch
                was found within ``max_sampling_attempts`` re-draws.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        raw = self._read_patch(bb)

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw):
                bb = self._sample_bounding_box()
                raw = self._read_patch(bb)
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)

        return raw

    def crop(self, tensor):
        """Cut ``self.inner_bb`` out of the trailing axes of ``tensor``.

        Only reached when ``self.trafo_halo`` is set. ``inner_bb`` is only assigned by the currently
        disabled trafo-halo code in ``__init__``.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            # Leave leading axes (e.g. channels) untouched.
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return one random patch as a tensor, or a pair of augmented views if ``augmentations`` is set.

        ``index`` is not used to select the patch; a new random position is drawn on every call.
        """
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        """Return the pickle state without the opened data, which is reloaded in ``__setstate__``."""
        state = self.__dict__.copy()
        del state["raw"]
        return state

    def __setstate__(self, state):
        """Restore the dataset and reopen the data from ``raw_path`` / ``raw_key``.

        If reopening fails, a warning is emitted and ``raw`` is set to None. The object can still be
        unpickled (e.g. as part of a checkpoint), but sampling from it raises a RuntimeError.
        """
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            # Deliberate best effort: loading a checkpoint must not fail because the data is gone.
            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None
        self.__dict__.update(state)
class
RawDataset(typing.Generic[+T_co]):
class RawDataset(torch.utils.data.Dataset):
    """Dataset that returns randomly positioned patches cut from a single raw-data array.

    The data is opened with ``load_data(raw_path, raw_key)``. Every call to ``__getitem__`` draws a
    new random bounding box, so the index is not used to select a patch and ``len(dataset)`` only
    sets how many patches make up one epoch.

    Args:
        raw_path: The file path(s) of the raw data, as accepted by ``load_data``.
        raw_key: The key of the raw data inside the file, as accepted by ``load_data``.
        patch_shape: The shape of the patches cut out of the data. It may have one more axis than
            ``ndim``; that extra axis is expected to be a singleton and is squeezed away.
        raw_transform: Optional callable applied to each patch first.
        transform: Optional callable applied after ``raw_transform``. If it returns a list,
            the list must hold exactly one element.
        roi: Optional region of interest over the non-channel axes, given as a slice or a
            sequence of slices.
        dtype: The torch dtype of the returned tensors.
        n_samples: The number of samples per epoch. By default the data size divided by the
            patch size (product over all axes, rounded down).
        sampler: Optional predicate called on each candidate patch; patches are redrawn until it
            returns a truthy value, giving up after ``max_sampling_attempts`` re-draws.
        ndim: The number of spatial dimensions (must be 2, 3 or 4). By default the number of data
            axes, excluding the channel axis if ``with_channels`` is set.
        with_channels: Whether the first data axis is a channel axis that is never cropped.
        augmentations: Optional pair of callables ``(aug1, aug2)``. If given, each item is the
            tuple ``(aug1(patch), aug2(patch))`` instead of a single tensor.
    """
    # Maximal number of re-draws when `sampler` keeps rejecting patches before an error is raised.
    max_sampling_attempts = 500

    @staticmethod
    def compute_len(shape, patch_shape):
        """Return how many patches of ``patch_shape`` fit into ``shape`` by volume, rounded down.

        Axes are paired from the front; surplus axes of the longer argument are ignored.
        """
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        patch_shape: Tuple[int, ...],
        raw_transform=None,
        transform=None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        augmentations=None,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self._with_channels = with_channels

        if roi is not None:
            # Accept a single slice or any sequence of slices. Store the roi as a tuple so that it
            # can be prefixed with the (uncropped) channel axis.
            roi = (roi,) if isinstance(roi, slice) else tuple(roi)
            self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

        # Shape of the data without the channel axis (if there is one).
        self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
        self.roi = roi

        self._ndim = len(self.shape) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.transform = transform
        self.sampler = sampler
        self.dtype = dtype

        if augmentations is not None:
            assert len(augmentations) == 2
        # Always set (also when None): __getitem__ reads this attribute unconditionally.
        self.augmentations = augmentations

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """The number of spatial dimensions of the dataset."""
        return self._ndim

    def _sample_bounding_box(self):
        """Draw a random bounding box of ``self.sample_shape`` within ``self.shape``.

        Each start coordinate is drawn uniformly from ``0 .. sh - psh`` (inclusive), so patches can
        reach the upper border of the data. Axes where the patch is not smaller than the data start at 0.
        """
        bb_start = [
            # randint's upper bound is exclusive: +1 makes the last valid start `sh - psh` reachable.
            np.random.randint(0, sh - psh + 1) if sh - psh > 0 else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _read_patch(self, bb):
        """Read the bounding box ``bb`` from the data, keeping the full channel axis if there is one."""
        return self.raw[(slice(None),) + bb] if self._with_channels else self.raw[bb]

    def _get_sample(self, index):
        """Read a random patch that is accepted by ``self.sampler`` (if set).

        ``index`` is ignored; every call draws a new random position.

        Raises:
            RuntimeError: If the dataset was not deserialized correctly, or if no accepted patch
                was found within ``max_sampling_attempts`` re-draws.
        """
        if self.raw is None:
            raise RuntimeError("RawDataset has not been properly deserialized.")
        bb = self._sample_bounding_box()
        raw = self._read_patch(bb)

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw):
                bb = self._sample_bounding_box()
                raw = self._read_patch(bb)
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)

        return raw

    def crop(self, tensor):
        """Cut ``self.inner_bb`` out of the trailing axes of ``tensor``.

        Only reached when ``self.trafo_halo`` is set. ``inner_bb`` is only assigned by the currently
        disabled trafo-halo code in ``__init__``.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            # Leave leading axes (e.g. channels) untouched.
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return one random patch as a tensor, or a pair of augmented views if ``augmentations`` is set.

        ``index`` is not used to select the patch; a new random position is drawn on every call.
        """
        raw = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.transform is not None:
            raw = self.transform(raw)
            if isinstance(raw, list):
                assert len(raw) == 1
                raw = raw[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        if self.augmentations is not None:
            aug1, aug2 = self.augmentations
            raw1, raw2 = aug1(raw), aug2(raw)
            return raw1, raw2

        return raw

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        """Return the pickle state without the opened data, which is reloaded in ``__setstate__``."""
        state = self.__dict__.copy()
        del state["raw"]
        return state

    def __setstate__(self, state):
        """Restore the dataset and reopen the data from ``raw_path`` / ``raw_key``.

        If reopening fails, a warning is emitted and ``raw`` is set to None. The object can still be
        unpickled (e.g. as part of a checkpoint), but sampling from it raises a RuntimeError.
        """
        raw_path, raw_key = state["raw_path"], state["raw_key"]
        roi = state["roi"]
        try:
            raw = load_data(raw_path, raw_key)
            if roi is not None:
                raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
            state["raw"] = raw
        except Exception:
            # Deliberate best effort: loading a checkpoint must not fail because the data is gone.
            msg = f"RawDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n"
            msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
            msg += "But it cannot be used for further training and will throw an error."
            warnings.warn(msg)
            state["raw"] = None
        self.__dict__.update(state)
RawDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: str, patch_shape: Tuple[int, ...], raw_transform=None, transform=None, roi: Optional[dict] = None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler=None, ndim: Optional[int] = None, with_channels: bool = False, augmentations=None)
def __init__(
    self,
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: str,
    patch_shape: Tuple[int, ...],
    raw_transform=None,
    transform=None,
    roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler=None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    augmentations=None,
):
    """Open the raw data, apply the optional roi and validate the patch configuration.

    ``roi`` may be a single slice or a sequence of slices over the non-channel axes. ``ndim`` must
    resolve to 2, 3 or 4, and ``patch_shape`` must have ``ndim`` or ``ndim + 1`` entries.
    ``augmentations``, if given, must be a pair of callables.
    """
    self.raw_path = raw_path
    self.raw_key = raw_key
    self.raw = load_data(raw_path, raw_key)

    self._with_channels = with_channels

    if roi is not None:
        # Accept a single slice or any sequence of slices. Store the roi as a tuple so that it
        # can be prefixed with the (uncropped) channel axis.
        roi = (roi,) if isinstance(roi, slice) else tuple(roi)
        self.raw = RoiWrapper(self.raw, (slice(None),) + roi) if self._with_channels else RoiWrapper(self.raw, roi)

    # Shape of the data without the channel axis (if there is one).
    self.shape = self.raw.shape[1:] if self._with_channels else self.raw.shape
    self.roi = roi

    self._ndim = len(self.shape) if ndim is None else ndim
    assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

    assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"
    self.patch_shape = patch_shape

    self.raw_transform = raw_transform
    self.transform = transform
    self.sampler = sampler
    self.dtype = dtype

    if augmentations is not None:
        assert len(augmentations) == 2
    # Always set (also when None): __getitem__ reads this attribute unconditionally.
    self.augmentations = augmentations

    self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

    self.sample_shape = patch_shape
    self.trafo_halo = None
    # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
    # which is then cut. See code below; but this needs to be properly tested

    # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
    # if self.trafo_halo is not None:
    #     if len(self.trafo_halo) == 2 and self._ndim == 3:
    #         self.trafo_halo = (0,) + self.trafo_halo
    #     assert len(self.trafo_halo) == self._ndim
    #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
    #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))