torch_em.data.pseudo_label_dataset
import os
from typing import Union, Tuple, Optional, List, Any

import torch

from .raw_dataset import RawDataset
from ..util import ensure_tensor_with_channels


class PseudoLabelDataset(RawDataset):
    """Raw-data dataset whose labels are generated on the fly by a pseudo-labeler.

    Raw patches are sampled through ``RawDataset``. For every sample, ``pseudo_labeler``
    is called on the (augmented) raw patch, and its output is used as the label.

    Args:
        raw_path: Path(s) to the raw data, forwarded to ``RawDataset``.
        raw_key: Key of the raw data, forwarded to ``RawDataset``.
        patch_shape: Shape of the sampled patches.
        pseudo_labeler: Callable that maps a batched raw tensor to labels, e.g. a torch model.
            It is called under ``torch.no_grad()``.
        raw_transform: Applied to raw *after* pseudo-labeling; it receives a numpy array.
        label_transform: Applied to the generated labels.
        transform: Augmentation applied to raw *before* pseudo-labeling.
        roi, dtype, n_samples, sampler, ndim, with_channels: Forwarded to ``RawDataset``.
        labeler_device: Device the raw patch is moved to before calling the labeler.
            If None, it is taken from the labeler's first parameter, falling back
            to CPU when the labeler exposes no parameters.
    """

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        patch_shape: Tuple[int, ...],
        pseudo_labeler,
        raw_transform=None,
        label_transform=None,
        transform=None,
        roi=None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        labeler_device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__(
            raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform,
            roi=roi, dtype=dtype, n_samples=n_samples, sampler=sampler,
            ndim=ndim, with_channels=with_channels,
        )
        self.pseudo_labeler = pseudo_labeler
        self.label_transform = label_transform

        if labeler_device is None:
            # A bare `next(pseudo_labeler.parameters())` raises StopIteration for a
            # parameter-free module and AttributeError for a plain callable; use CPU then.
            parameters = getattr(pseudo_labeler, "parameters", None)
            first_param = next(iter(parameters()), None) if callable(parameters) else None
            labeler_device = torch.device("cpu") if first_param is None else first_param.device
        self.labeler_device = labeler_device

    def __getitem__(self, index):
        """Return the ``(raw, labels)`` pair for ``index``.

        The labels are computed by the pseudo-labeler from the augmented raw patch.
        Both outputs pass through ``ensure_tensor_with_channels`` before being returned.
        """
        raw = self._get_sample(index)

        # Augmentations are applied to raw only: the labels are generated afterwards
        # from the augmented patch by the pseudo-labeler.
        if self.transform is not None:
            raw = self.transform(raw)[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        # The labeler sees un-normalized raw (ilastik needs uint input); normalize afterwards.
        with torch.no_grad():
            labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0]
        # Return labels on the CPU like raw; CUDA tensors break DataLoader pin_memory
        # and label transforms that operate on numpy data.
        if torch.is_tensor(labels):
            labels = labels.cpu()

        if self.raw_transform is not None:
            # Normalization functions need numpy arrays; raw is a torch tensor at this point.
            raw = self.raw_transform(raw.cpu().detach().numpy())
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)

        if self.label_transform is not None:
            labels = self.label_transform(labels)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim)

        return raw, labels
class
PseudoLabelDataset(typing.Generic[+T_co]):
class PseudoLabelDataset(RawDataset):
    """Raw-data dataset whose labels are generated on the fly by a pseudo-labeler.

    Raw patches are sampled through ``RawDataset``. For every sample, ``pseudo_labeler``
    is called on the (augmented) raw patch, and its output is used as the label.

    Args:
        raw_path: Path(s) to the raw data, forwarded to ``RawDataset``.
        raw_key: Key of the raw data, forwarded to ``RawDataset``.
        patch_shape: Shape of the sampled patches.
        pseudo_labeler: Callable that maps a batched raw tensor to labels, e.g. a torch model.
            It is called under ``torch.no_grad()``.
        raw_transform: Applied to raw *after* pseudo-labeling; it receives a numpy array.
        label_transform: Applied to the generated labels.
        transform: Augmentation applied to raw *before* pseudo-labeling.
        roi, dtype, n_samples, sampler, ndim, with_channels: Forwarded to ``RawDataset``.
        labeler_device: Device the raw patch is moved to before calling the labeler.
            If None, it is taken from the labeler's first parameter, falling back
            to CPU when the labeler exposes no parameters.
    """

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        patch_shape: Tuple[int, ...],
        pseudo_labeler,
        raw_transform=None,
        label_transform=None,
        transform=None,
        roi=None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        labeler_device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__(
            raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform,
            roi=roi, dtype=dtype, n_samples=n_samples, sampler=sampler,
            ndim=ndim, with_channels=with_channels,
        )
        self.pseudo_labeler = pseudo_labeler
        self.label_transform = label_transform

        if labeler_device is None:
            # A bare `next(pseudo_labeler.parameters())` raises StopIteration for a
            # parameter-free module and AttributeError for a plain callable; use CPU then.
            parameters = getattr(pseudo_labeler, "parameters", None)
            first_param = next(iter(parameters()), None) if callable(parameters) else None
            labeler_device = torch.device("cpu") if first_param is None else first_param.device
        self.labeler_device = labeler_device

    def __getitem__(self, index):
        """Return the ``(raw, labels)`` pair for ``index``.

        The labels are computed by the pseudo-labeler from the augmented raw patch.
        Both outputs pass through ``ensure_tensor_with_channels`` before being returned.
        """
        raw = self._get_sample(index)

        # Augmentations are applied to raw only: the labels are generated afterwards
        # from the augmented patch by the pseudo-labeler.
        if self.transform is not None:
            raw = self.transform(raw)[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        # The labeler sees un-normalized raw (ilastik needs uint input); normalize afterwards.
        with torch.no_grad():
            labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0]
        # Return labels on the CPU like raw; CUDA tensors break DataLoader pin_memory
        # and label transforms that operate on numpy data.
        if torch.is_tensor(labels):
            labels = labels.cpu()

        if self.raw_transform is not None:
            # Normalization functions need numpy arrays; raw is a torch tensor at this point.
            raw = self.raw_transform(raw.cpu().detach().numpy())
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)

        if self.label_transform is not None:
            labels = self.label_transform(labels)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim)

        return raw, labels
PseudoLabelDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: str, patch_shape: Tuple[int, ...], pseudo_labeler, raw_transform=None, label_transform=None, transform=None, roi=None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler=None, ndim: Optional[int] = None, with_channels: bool = False, labeler_device: Union[str, torch.device, NoneType] = None)
def __init__(
    self,
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: str,
    patch_shape: Tuple[int, ...],
    pseudo_labeler,
    raw_transform=None,
    label_transform=None,
    transform=None,
    roi=None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler=None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    labeler_device: Optional[Union[str, torch.device]] = None,
):
    """Set up the underlying raw dataset and store the pseudo-labeler.

    Args:
        raw_path, raw_key, patch_shape, raw_transform, transform, roi, dtype,
            n_samples, sampler, ndim, with_channels: Forwarded to ``RawDataset``.
        pseudo_labeler: Callable used to generate labels from raw patches.
        label_transform: Transform applied to the generated labels.
        labeler_device: Device the labeler input is moved to. If None, it is taken
            from the labeler's first parameter, falling back to CPU when the labeler
            exposes no parameters.
    """
    super().__init__(
        raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform,
        roi=roi, dtype=dtype, n_samples=n_samples, sampler=sampler,
        ndim=ndim, with_channels=with_channels,
    )
    self.pseudo_labeler = pseudo_labeler
    self.label_transform = label_transform

    if labeler_device is None:
        # A bare `next(pseudo_labeler.parameters())` raises StopIteration for a
        # parameter-free module and AttributeError for a plain callable; use CPU then.
        parameters = getattr(pseudo_labeler, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        labeler_device = torch.device("cpu") if first_param is None else first_param.device
    self.labeler_device = labeler_device