torch_em.data.pseudo_label_dataset
1import os 2from typing import Union, Tuple, Optional, List, Any, Callable 3 4import torch 5 6from .raw_dataset import RawDataset 7from ..util import ensure_tensor_with_channels 8 9 10class PseudoLabelDataset(RawDataset): 11 """Dataset that uses a prediction function to provide raw data and pseudo labels for segmentation training. 12 13 The dataset loads a patch from the raw data and then applies the pseudo labeler to it to predict pseudo labels. 14 The raw data and pseudo labels are returned together as a sample for a batch. 15 The datataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5. 16 Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset. 17 It also supports regular image formats, such as .tif. For these cases set `raw_key=None`. 18 19 Args: 20 raw_path: The file path to the raw image data. May also be a list of file paths. 21 raw_key: The key to the internal dataset containing the raw data. 22 patch_shape: The patch shape for a training sample. 23 pseudo_labeler: The pseudo labeler. Must be a function that accepts the raw data as torch tensor 24 and that returns the predicted labels as torch tensor. 25 raw_transform: Transformation applied to the raw data of a sample. 26 label_transform: Transformation applied to the label data of a sample. 27 roi: Region of interest in the raw data. 28 If given, the raw data will only be loaded from the corresponding area. 29 dtype: The return data type of the raw data. 30 n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`. 31 sampler: Sampler for rejecting samples according to a defined criterion. 32 The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input. 33 ndim: The spatial dimensionality of the data. If None, will be derived from the raw data. 34 with_channels: Whether the raw data has channels. 35 labeler_device: The expected device for the pseudo labeler. 36 """ 37 def __init__( 38 self, 39 raw_path: Union[List[Any], str, os.PathLike], 40 raw_key: Optional[str], 41 patch_shape: Tuple[int, ...], 42 pseudo_labeler: Callable, 43 raw_transform: Optional[Callable] = None, 44 label_transform: Optional[Callable] = None, 45 transform: Optional[Callable] = None, 46 roi: Optional[Union[slice, Tuple[slice, ...]]] = None, 47 dtype: torch.dtype = torch.float32, 48 n_samples: Optional[int] = None, 49 sampler: Optional[Callable] = None, 50 ndim: Optional[Union[int]] = None, 51 with_channels: bool = False, 52 labeler_device: Optional[Union[str, torch.device]] = None, 53 ): 54 super().__init__( 55 raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform, roi=roi, 56 dtype=dtype, n_samples=n_samples, sampler=sampler, ndim=ndim, with_channels=with_channels 57 ) 58 self.pseudo_labeler = pseudo_labeler 59 self.label_transform = label_transform 60 self.labeler_device = next(pseudo_labeler.parameters()).device if labeler_device is None else labeler_device 61 62 def __getitem__(self, index): 63 raw = self._get_sample(index) 64 65 # Transform for augmentations. 66 # Applied to the raw data since, labels are generated on the fly by the pseudo_labeler. 67 if self.transform is not None: 68 raw = self.transform(raw)[0] 69 if self.trafo_halo is not None: 70 raw = self.crop(raw) 71 72 raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) 73 with torch.no_grad(): 74 # Ilastik needs uint as input, so normalize afterwards. 75 labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0] 76 77 # Normalize the raw data. 78 if self.raw_transform is not None: 79 raw = self.raw_transform(raw.cpu().detach().numpy()) 80 raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) 81 82 if self.label_transform is not None: 83 labels = self.label_transform(labels) 84 labels = ensure_tensor_with_channels(labels, ndim=self._ndim) 85 86 return raw, labels
class
PseudoLabelDataset(typing.Generic[+_T_co]):
11class PseudoLabelDataset(RawDataset): 12 """Dataset that uses a prediction function to provide raw data and pseudo labels for segmentation training. 13 14 The dataset loads a patch from the raw data and then applies the pseudo labeler to it to predict pseudo labels. 15 The raw data and pseudo labels are returned together as a sample for a batch. 16 The datataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5. 17 Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset. 18 It also supports regular image formats, such as .tif. For these cases set `raw_key=None`. 19 20 Args: 21 raw_path: The file path to the raw image data. May also be a list of file paths. 22 raw_key: The key to the internal dataset containing the raw data. 23 patch_shape: The patch shape for a training sample. 24 pseudo_labeler: The pseudo labeler. Must be a function that accepts the raw data as torch tensor 25 and that returns the predicted labels as torch tensor. 26 raw_transform: Transformation applied to the raw data of a sample. 27 label_transform: Transformation applied to the label data of a sample. 28 roi: Region of interest in the raw data. 29 If given, the raw data will only be loaded from the corresponding area. 30 dtype: The return data type of the raw data. 31 n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`. 32 sampler: Sampler for rejecting samples according to a defined criterion. 33 The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input. 34 ndim: The spatial dimensionality of the data. If None, will be derived from the raw data. 35 with_channels: Whether the raw data has channels. 36 labeler_device: The expected device for the pseudo labeler. 37 """ 38 def __init__( 39 self, 40 raw_path: Union[List[Any], str, os.PathLike], 41 raw_key: Optional[str], 42 patch_shape: Tuple[int, ...], 43 pseudo_labeler: Callable, 44 raw_transform: Optional[Callable] = None, 45 label_transform: Optional[Callable] = None, 46 transform: Optional[Callable] = None, 47 roi: Optional[Union[slice, Tuple[slice, ...]]] = None, 48 dtype: torch.dtype = torch.float32, 49 n_samples: Optional[int] = None, 50 sampler: Optional[Callable] = None, 51 ndim: Optional[Union[int]] = None, 52 with_channels: bool = False, 53 labeler_device: Optional[Union[str, torch.device]] = None, 54 ): 55 super().__init__( 56 raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform, roi=roi, 57 dtype=dtype, n_samples=n_samples, sampler=sampler, ndim=ndim, with_channels=with_channels 58 ) 59 self.pseudo_labeler = pseudo_labeler 60 self.label_transform = label_transform 61 self.labeler_device = next(pseudo_labeler.parameters()).device if labeler_device is None else labeler_device 62 63 def __getitem__(self, index): 64 raw = self._get_sample(index) 65 66 # Transform for augmentations. 67 # Applied to the raw data since, labels are generated on the fly by the pseudo_labeler. 68 if self.transform is not None: 69 raw = self.transform(raw)[0] 70 if self.trafo_halo is not None: 71 raw = self.crop(raw) 72 73 raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) 74 with torch.no_grad(): 75 # Ilastik needs uint as input, so normalize afterwards. 76 labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0] 77 78 # Normalize the raw data. 79 if self.raw_transform is not None: 80 raw = self.raw_transform(raw.cpu().detach().numpy()) 81 raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) 82 83 if self.label_transform is not None: 84 labels = self.label_transform(labels) 85 labels = ensure_tensor_with_channels(labels, ndim=self._ndim) 86 87 return raw, labels
Dataset that uses a prediction function to provide raw data and pseudo labels for segmentation training.
The dataset loads a patch from the raw data and then applies the pseudo labeler to it to predict pseudo labels.
The raw data and pseudo labels are returned together as a sample for a batch.
The datataset supports all file formats that can be opened with elf.io.open_file
, such as hdf5, zarr or n5.
Use raw_path
to specify the path to the file and raw_key
to specify the internal dataset.
It also supports regular image formats, such as .tif. For these cases set raw_key=None
.
Arguments:
- raw_path: The file path to the raw image data. May also be a list of file paths.
- raw_key: The key to the internal dataset containing the raw data.
- patch_shape: The patch shape for a training sample.
- pseudo_labeler: The pseudo labeler. Must be a function that accepts the raw data as torch tensor and that returns the predicted labels as torch tensor.
- raw_transform: Transformation applied to the raw data of a sample.
- label_transform: Transformation applied to the label data of a sample.
- roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
- dtype: The return data type of the raw data.
- n_samples: The length of this dataset. If None, the length will be set to
len(raw_image_paths)
. - sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- with_channels: Whether the raw data has channels.
- labeler_device: The expected device for the pseudo labeler.
PseudoLabelDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: Optional[str], patch_shape: Tuple[int, ...], pseudo_labeler: Callable, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, roi: Union[slice, Tuple[slice, ...], NoneType] = None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, with_channels: bool = False, labeler_device: Union[str, torch.device, NoneType] = None)
38 def __init__( 39 self, 40 raw_path: Union[List[Any], str, os.PathLike], 41 raw_key: Optional[str], 42 patch_shape: Tuple[int, ...], 43 pseudo_labeler: Callable, 44 raw_transform: Optional[Callable] = None, 45 label_transform: Optional[Callable] = None, 46 transform: Optional[Callable] = None, 47 roi: Optional[Union[slice, Tuple[slice, ...]]] = None, 48 dtype: torch.dtype = torch.float32, 49 n_samples: Optional[int] = None, 50 sampler: Optional[Callable] = None, 51 ndim: Optional[Union[int]] = None, 52 with_channels: bool = False, 53 labeler_device: Optional[Union[str, torch.device]] = None, 54 ): 55 super().__init__( 56 raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform, roi=roi, 57 dtype=dtype, n_samples=n_samples, sampler=sampler, ndim=ndim, with_channels=with_channels 58 ) 59 self.pseudo_labeler = pseudo_labeler 60 self.label_transform = label_transform 61 self.labeler_device = next(pseudo_labeler.parameters()).device if labeler_device is None else labeler_device