torch_em.data.segmentation_dataset
import os
import warnings
import numpy as np
from typing import List, Union, Tuple, Optional, Any

import torch

from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data


class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that samples random patches from a raw data volume and its labels.

    Raw data and labels are loaded with ``load_data(path, key)`` and must have the same
    spatial shape (ignoring a leading channel axis when ``with_channels`` /
    ``with_label_channels`` is set). Every item is a randomly positioned patch; the
    ``index`` passed to ``__getitem__`` does not determine the patch position.

    Args:
        raw_path: Path(s) to the raw data.
        raw_key: Key of the raw data inside the file(s).
        label_path: Path(s) to the label data.
        label_key: Key of the label data inside the file(s).
        patch_shape: Spatial shape of the sampled patches. It may have one more axis than
            ``ndim``; in that case the leading spatial axis of each patch is squeezed out,
            so it is expected to have size 1. ``None`` returns the whole (ROI) volume.
        raw_transform: Callable applied to the raw patch.
        label_transform: Callable applied to the label patch.
        label_transform2: Callable applied to the labels after the joint ``transform``.
        transform: Callable applied jointly as ``transform(raw, labels)``.
        roi: Region of interest over the spatial axes, given as a slice or a tuple of
            slices; it restricts both raw data and labels.
        dtype: Dtype of the returned raw tensor.
        label_dtype: Dtype of the returned label tensor.
        n_samples: Length of the dataset. Defaults to ``compute_len(shape, patch_shape)``.
        sampler: Optional callable ``sampler(raw, labels)``; patches for which it returns
            a falsy value are re-drawn, up to ``max_sampling_attempts`` times.
        ndim: Number of spatial dimensions of the samples (2, 3 or 4). Defaults to the
            number of spatial axes of the data.
        with_channels: Whether the raw data has a leading channel axis.
        with_label_channels: Whether the labels have a leading channel axis.
    """
    # Number of re-draws allowed when `sampler` keeps rejecting patches.
    max_sampling_attempts = 500

    @staticmethod
    def compute_len(shape, patch_shape):
        """Return the default dataset length.

        This is the product of the per-axis ratios ``shape / patch_shape``, truncated to
        an int (roughly the number of patches that tile the data), or 1 if
        ``patch_shape`` is None.
        """
        if patch_shape is None:
            return 1
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        label_path: Union[List[Any], str, os.PathLike],
        label_key: str,
        patch_shape: Optional[Tuple[int, ...]],
        raw_transform=None,
        label_transform=None,
        label_transform2=None,
        transform=None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            # Normalize a single slice to a tuple so a channel slice can be prepended.
            if isinstance(roi, slice):
                roi = (roi,)
            self.raw = self._apply_roi(self.raw, roi, self._with_channels)
            self.labels = self._apply_roi(self.labels, roi, self._with_label_channels)

        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        # Stored in normalized (tuple) form so __setstate__ can re-apply it.
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    @staticmethod
    def _apply_roi(data, roi, with_channels):
        """Restrict ``data`` to ``roi`` via RoiWrapper, keeping a leading channel axis whole."""
        if roi is None:
            return data
        full_roi = (slice(None),) + roi if with_channels else roi
        return RoiWrapper(data, full_roi)

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """Number of spatial dimensions of the returned samples."""
        return self._ndim

    def _sample_bounding_box(self):
        """Draw a random bounding box, as a tuple of slices over the spatial axes.

        Without a sample shape the box covers the whole data. Otherwise each start is
        drawn uniformly from every position where the patch fits. Along axes where the
        data is not larger than the patch, the start is 0.
        """
        if self.sample_shape is None:
            return tuple(slice(0, sh) for sh in self.shape)
        bb_start = [
            # randint excludes its upper bound: +1 makes the last fitting start (sh - psh) reachable.
            np.random.randint(0, sh - psh + 1) if sh > psh else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _load_patch(self):
        """Sample a bounding box and read the matching raw and label patches."""
        bb = self._sample_bounding_box()
        # Channel axes (if present) are read in full.
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        return self.raw[bb_raw], self.labels[bb_labels]

    def _get_sample(self, index):
        """Return a random ``(raw, labels)`` patch; ``index`` is not used.

        Raises:
            RuntimeError: If the data could not be reloaded after unpickling, or if
                ``sampler`` keeps rejecting patches for more than
                ``max_sampling_attempts`` re-draws.
        """
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")
        raw, labels = self._load_patch()

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                raw, labels = self._load_patch()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        """Crop ``tensor`` to ``self.inner_bb``, leaving extra leading axes untouched.

        Only called when ``trafo_halo`` is set. ``inner_bb`` is currently assigned only
        in the commented-out halo code in ``__init__``.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return the transformed ``(raw, labels)`` sample.

        Processing order:
        1. ``raw_transform`` on the raw patch.
        2. ``label_transform`` on the labels.
        3. The joint ``transform``.
        4. ``label_transform2`` on the labels, after converting them with
           ``ensure_spatial_array`` using their original dtype.
        5. Conversion of both with ``ensure_tensor_with_channels`` to ``dtype`` /
           ``label_dtype``.
        """
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        """Drop the data handles from the pickled state; they are reloaded in ``__setstate__``."""
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        """Reload raw data and labels from their paths and keys, re-applying the ROI.

        If loading fails, a warning is emitted and the handle is set to None. The object
        can then still be unpickled (e.g. inside a checkpoint), but ``_get_sample``
        raises.
        """
        roi = state["roi"]
        sources = (
            ("raw", "raw_path", "raw_key", "_with_channels"),
            ("labels", "label_path", "label_key", "_with_label_channels"),
        )
        for name, path_field, key_field, channel_field in sources:
            path, key = state[path_field], state[key_field]
            try:
                data = self._apply_roi(load_data(path, key), roi, state[channel_field])
            except Exception:
                # Deliberately best-effort: checkpoints must stay loadable without the data.
                msg = f"SegmentationDataset could not be deserialized because of missing {path}, {key}.\n"
                msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
                msg += "But it cannot be used for further training and wil throw an error."
                warnings.warn(msg)
                data = None
            state[name] = data

        self.__dict__.update(state)
class
SegmentationDataset(typing.Generic[+T_co]):
class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset that samples random patches from a raw data volume and its labels.

    Raw data and labels are loaded with ``load_data(path, key)`` and must have the same
    spatial shape (ignoring a leading channel axis when ``with_channels`` /
    ``with_label_channels`` is set). Every item is a randomly positioned patch; the
    ``index`` passed to ``__getitem__`` does not determine the patch position.

    Args:
        raw_path: Path(s) to the raw data.
        raw_key: Key of the raw data inside the file(s).
        label_path: Path(s) to the label data.
        label_key: Key of the label data inside the file(s).
        patch_shape: Spatial shape of the sampled patches. It may have one more axis than
            ``ndim``; in that case the leading spatial axis of each patch is squeezed out,
            so it is expected to have size 1. ``None`` returns the whole (ROI) volume.
        raw_transform: Callable applied to the raw patch.
        label_transform: Callable applied to the label patch.
        label_transform2: Callable applied to the labels after the joint ``transform``.
        transform: Callable applied jointly as ``transform(raw, labels)``.
        roi: Region of interest over the spatial axes, given as a slice or a tuple of
            slices; it restricts both raw data and labels.
        dtype: Dtype of the returned raw tensor.
        label_dtype: Dtype of the returned label tensor.
        n_samples: Length of the dataset. Defaults to ``compute_len(shape, patch_shape)``.
        sampler: Optional callable ``sampler(raw, labels)``; patches for which it returns
            a falsy value are re-drawn, up to ``max_sampling_attempts`` times.
        ndim: Number of spatial dimensions of the samples (2, 3 or 4). Defaults to the
            number of spatial axes of the data.
        with_channels: Whether the raw data has a leading channel axis.
        with_label_channels: Whether the labels have a leading channel axis.
    """
    # Number of re-draws allowed when `sampler` keeps rejecting patches.
    max_sampling_attempts = 500

    @staticmethod
    def compute_len(shape, patch_shape):
        """Return the default dataset length.

        This is the product of the per-axis ratios ``shape / patch_shape``, truncated to
        an int (roughly the number of patches that tile the data), or 1 if
        ``patch_shape`` is None.
        """
        if patch_shape is None:
            return 1
        n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
        return n_samples

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        label_path: Union[List[Any], str, os.PathLike],
        label_key: str,
        patch_shape: Optional[Tuple[int, ...]],
        raw_transform=None,
        label_transform=None,
        label_transform2=None,
        transform=None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        with_label_channels: bool = False,
    ):
        self.raw_path = raw_path
        self.raw_key = raw_key
        self.raw = load_data(raw_path, raw_key)

        self.label_path = label_path
        self.label_key = label_key
        self.labels = load_data(label_path, label_key)

        self._with_channels = with_channels
        self._with_label_channels = with_label_channels

        if roi is not None:
            # Normalize a single slice to a tuple so a channel slice can be prepended.
            if isinstance(roi, slice):
                roi = (roi,)
            self.raw = self._apply_roi(self.raw, roi, self._with_channels)
            self.labels = self._apply_roi(self.labels, roi, self._with_label_channels)

        shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
        shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
        assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

        self.shape = shape_raw
        # Stored in normalized (tuple) form so __setstate__ can re-apply it.
        self.roi = roi

        self._ndim = len(shape_raw) if ndim is None else ndim
        assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

        if patch_shape is not None:
            assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.label_transform2 = label_transform2
        self.transform = transform
        self.sampler = sampler

        self.dtype = dtype
        self.label_dtype = label_dtype

        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

        self.sample_shape = patch_shape
        self.trafo_halo = None
        # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
        # which is then cut. See code below; but this needs to be properly tested

        # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
        # if self.trafo_halo is not None:
        #     if len(self.trafo_halo) == 2 and self._ndim == 3:
        #         self.trafo_halo = (0,) + self.trafo_halo
        #     assert len(self.trafo_halo) == self._ndim
        #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
        #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))

    @staticmethod
    def _apply_roi(data, roi, with_channels):
        """Restrict ``data`` to ``roi`` via RoiWrapper, keeping a leading channel axis whole."""
        if roi is None:
            return data
        full_roi = (slice(None),) + roi if with_channels else roi
        return RoiWrapper(data, full_roi)

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        """Number of spatial dimensions of the returned samples."""
        return self._ndim

    def _sample_bounding_box(self):
        """Draw a random bounding box, as a tuple of slices over the spatial axes.

        Without a sample shape the box covers the whole data. Otherwise each start is
        drawn uniformly from every position where the patch fits. Along axes where the
        data is not larger than the patch, the start is 0.
        """
        if self.sample_shape is None:
            return tuple(slice(0, sh) for sh in self.shape)
        bb_start = [
            # randint excludes its upper bound: +1 makes the last fitting start (sh - psh) reachable.
            np.random.randint(0, sh - psh + 1) if sh > psh else 0
            for sh, psh in zip(self.shape, self.sample_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape))

    def _load_patch(self):
        """Sample a bounding box and read the matching raw and label patches."""
        bb = self._sample_bounding_box()
        # Channel axes (if present) are read in full.
        bb_raw = (slice(None),) + bb if self._with_channels else bb
        bb_labels = (slice(None),) + bb if self._with_label_channels else bb
        return self.raw[bb_raw], self.labels[bb_labels]

    def _get_sample(self, index):
        """Return a random ``(raw, labels)`` patch; ``index`` is not used.

        Raises:
            RuntimeError: If the data could not be reloaded after unpickling, or if
                ``sampler`` keeps rejecting patches for more than
                ``max_sampling_attempts`` re-draws.
        """
        if self.raw is None or self.labels is None:
            raise RuntimeError("SegmentationDataset has not been properly deserialized.")
        raw, labels = self._load_patch()

        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw, labels):
                raw, labels = self._load_patch()
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
            labels = labels.squeeze(1 if self._with_label_channels else 0)

        return raw, labels

    def crop(self, tensor):
        """Crop ``tensor`` to ``self.inner_bb``, leaving extra leading axes untouched.

        Only called when ``trafo_halo`` is set. ``inner_bb`` is currently assigned only
        in the commented-out halo code in ``__init__``.
        """
        bb = self.inner_bb
        if tensor.ndim > len(bb):
            bb = (tensor.ndim - len(bb)) * (slice(None),) + bb
        return tensor[bb]

    def __getitem__(self, index):
        """Return the transformed ``(raw, labels)`` sample.

        Processing order:
        1. ``raw_transform`` on the raw patch.
        2. ``label_transform`` on the labels.
        3. The joint ``transform``.
        4. ``label_transform2`` on the labels, after converting them with
           ``ensure_spatial_array`` using their original dtype.
        5. Conversion of both with ``ensure_tensor_with_channels`` to ``dtype`` /
           ``label_dtype``.
        """
        raw, labels = self._get_sample(index)
        initial_label_dtype = labels.dtype

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)
            if self.trafo_halo is not None:
                raw = self.crop(raw)
                labels = self.crop(labels)

        # support enlarging bounding box here as well (for affinity transform) ?
        if self.label_transform2 is not None:
            labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype)
            labels = self.label_transform2(labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

    # need to overwrite pickle to support h5py
    def __getstate__(self):
        """Drop the data handles from the pickled state; they are reloaded in ``__setstate__``."""
        state = self.__dict__.copy()
        del state["raw"]
        del state["labels"]
        return state

    def __setstate__(self, state):
        """Reload raw data and labels from their paths and keys, re-applying the ROI.

        If loading fails, a warning is emitted and the handle is set to None. The object
        can then still be unpickled (e.g. inside a checkpoint), but ``_get_sample``
        raises.
        """
        roi = state["roi"]
        sources = (
            ("raw", "raw_path", "raw_key", "_with_channels"),
            ("labels", "label_path", "label_key", "_with_label_channels"),
        )
        for name, path_field, key_field, channel_field in sources:
            path, key = state[path_field], state[key_field]
            try:
                data = self._apply_roi(load_data(path, key), roi, state[channel_field])
            except Exception:
                # Deliberately best-effort: checkpoints must stay loadable without the data.
                msg = f"SegmentationDataset could not be deserialized because of missing {path}, {key}.\n"
                msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n"
                msg += "But it cannot be used for further training and wil throw an error."
                warnings.warn(msg)
                data = None
            state[name] = data

        self.__dict__.update(state)
SegmentationDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: str, label_path: Union[List[Any], str, os.PathLike], label_key: str, patch_shape: Tuple[int, ...], raw_transform=None, label_transform=None, label_transform2=None, transform=None, roi: Optional[dict] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler=None, ndim: Optional[int] = None, with_channels: bool = False, with_label_channels: bool = False)
def __init__(
    self,
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: str,
    label_path: Union[List[Any], str, os.PathLike],
    label_key: str,
    patch_shape: Optional[Tuple[int, ...]],
    raw_transform=None,
    label_transform=None,
    label_transform2=None,
    transform=None,
    roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler=None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
):
    """Load raw data and labels, apply the optional ROI and configure patch sampling.

    Args:
        raw_path: Path(s) to the raw data, passed to ``load_data``.
        raw_key: Key of the raw data inside the file(s).
        label_path: Path(s) to the label data, passed to ``load_data``.
        label_key: Key of the label data inside the file(s).
        patch_shape: Spatial patch shape with ``ndim`` or ``ndim + 1`` entries, or None.
        raw_transform: Callable stored for use on raw patches.
        label_transform: Callable stored for use on label patches.
        label_transform2: Second label callable, stored.
        transform: Joint raw/label callable, stored.
        roi: A slice or a tuple of slices over the spatial axes; it restricts both
            raw data and labels. A single slice is normalized to a 1-tuple.
        dtype: Dtype stored for the raw output.
        label_dtype: Dtype stored for the label output.
        n_samples: Dataset length; defaults to ``self.compute_len(shape, patch_shape)``.
        sampler: Optional patch-acceptance callable, stored.
        ndim: Number of spatial dimensions (2, 3 or 4); defaults to the data's
            spatial rank.
        with_channels: Whether the raw data has a leading channel axis.
        with_label_channels: Whether the labels have a leading channel axis.

    Raises:
        AssertionError: If raw and label spatial shapes differ, ``ndim`` is not
            2, 3 or 4, or ``patch_shape`` has an incompatible length.
    """
    self.raw_path = raw_path
    self.raw_key = raw_key
    self.raw = load_data(raw_path, raw_key)

    self.label_path = label_path
    self.label_key = label_key
    self.labels = load_data(label_path, label_key)

    self._with_channels = with_channels
    self._with_label_channels = with_label_channels

    if roi is not None:
        # Normalize a single slice to a tuple so a channel slice can be prepended.
        if isinstance(roi, slice):
            roi = (roi,)
        # A leading channel axis is kept whole; the ROI only restricts the spatial axes.
        raw_roi = (slice(None),) + roi if self._with_channels else roi
        label_roi = (slice(None),) + roi if self._with_label_channels else roi
        self.raw = RoiWrapper(self.raw, raw_roi)
        self.labels = RoiWrapper(self.labels, label_roi)

    shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
    shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
    assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"

    self.shape = shape_raw
    # Stored in normalized (tuple) form so the ROI can be re-applied later.
    self.roi = roi

    self._ndim = len(shape_raw) if ndim is None else ndim
    assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported"

    if patch_shape is not None:
        assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}"

    self.patch_shape = patch_shape

    self.raw_transform = raw_transform
    self.label_transform = label_transform
    self.label_transform2 = label_transform2
    self.transform = transform
    self.sampler = sampler

    self.dtype = dtype
    self.label_dtype = label_dtype

    self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples

    self.sample_shape = patch_shape
    self.trafo_halo = None
    # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo,
    # which is then cut. See code below; but this needs to be properly tested

    # self.trafo_halo = None if self.transform is None else self.transform.halo(self.patch_shape)
    # if self.trafo_halo is not None:
    #     if len(self.trafo_halo) == 2 and self._ndim == 3:
    #         self.trafo_halo = (0,) + self.trafo_halo
    #     assert len(self.trafo_halo) == self._ndim
    #     self.sample_shape = tuple(sh + ha for sh, ha in zip(self.patch_shape, self.trafo_halo))
    #     self.inner_bb = tuple(slice(ha, sh - ha) for sh, ha in zip(self.patch_shape, self.trafo_halo))