torch_em.data.pseudo_label_dataset

 1import os
 2from typing import Union, Tuple, Optional, List, Any, Callable
 3
 4import torch
 5
 6from .raw_dataset import RawDataset
 7from ..util import ensure_tensor_with_channels
 8
 9
class PseudoLabelDataset(RawDataset):
    """Dataset that uses a prediction function to provide raw data and pseudo labels for segmentation training.

    The dataset loads a patch from the raw data and then applies the pseudo labeler to it to predict pseudo labels.
    The raw data and pseudo labels are returned together as a sample for a batch.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        pseudo_labeler: The pseudo labeler. Must be a function that accepts the raw data as torch tensor
            and that returns the predicted labels as torch tensor.
        raw_transform: Transformation applied to the raw data of a sample.
            It is applied after the pseudo labels have been predicted.
        label_transform: Transformation applied to the label data of a sample.
        transform: Transformation applied to the raw data for augmentation.
            It is applied before the pseudo labels are predicted and only the
            first element of its return value is used.
        roi: Region of interest in the raw data.
            If given, the raw data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        labeler_device: The device the raw data is moved to before it is passed to the pseudo labeler.
            If None, the device of the first parameter of the pseudo labeler is used;
            a pseudo labeler without parameters (e.g. a plain function) falls back to the CPU.
    """
    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        pseudo_labeler: Callable,
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        labeler_device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__(
            raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform, roi=roi,
            dtype=dtype, n_samples=n_samples, sampler=sampler, ndim=ndim, with_channels=with_channels
        )
        self.pseudo_labeler = pseudo_labeler
        self.label_transform = label_transform

        if labeler_device is None:
            # Infer the device from the labeler's first parameter when it exposes
            # `parameters()` (as torch modules do). Plain callables and modules
            # without parameters fall back to the CPU instead of raising.
            get_parameters = getattr(pseudo_labeler, "parameters", None)
            first_parameter = next(iter(get_parameters()), None) if callable(get_parameters) else None
            labeler_device = torch.device("cpu") if first_parameter is None else first_parameter.device
        self.labeler_device = labeler_device

    def __getitem__(self, index):
        """Load a raw patch and predict its pseudo labels.

        Args:
            index: The sample index, forwarded to `_get_sample`.

        Returns:
            The raw data and the pseudo labels, each passed through `ensure_tensor_with_channels`.
        """
        raw = self._get_sample(index)

        # Transform for augmentations.
        # Applied only to the raw data, since the labels are generated on the fly by the pseudo_labeler.
        if self.transform is not None:
            raw = self.transform(raw)[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        with torch.no_grad():
            # The labeler sees the raw data before `raw_transform` is applied
            # (Ilastik needs uint as input), so the normalization happens afterwards.
            # `raw[None]` adds a leading batch axis and `[0]` takes the first entry of the prediction.
            labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0]

        # Normalize the raw data.
        if self.raw_transform is not None:
            raw = self.raw_transform(raw.cpu().detach().numpy())
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)

        if self.label_transform is not None:
            labels = self.label_transform(labels)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim)

        return raw, labels
class PseudoLabelDataset(typing.Generic[+_T_co]):
class PseudoLabelDataset(RawDataset):
    """Dataset that uses a prediction function to provide raw data and pseudo labels for segmentation training.

    The dataset loads a patch from the raw data and then applies the pseudo labeler to it to predict pseudo labels.
    The raw data and pseudo labels are returned together as a sample for a batch.
    The dataset supports all file formats that can be opened with `elf.io.open_file`, such as hdf5, zarr or n5.
    Use `raw_path` to specify the path to the file and `raw_key` to specify the internal dataset.
    It also supports regular image formats, such as .tif. For these cases set `raw_key=None`.

    Args:
        raw_path: The file path to the raw image data. May also be a list of file paths.
        raw_key: The key to the internal dataset containing the raw data.
        patch_shape: The patch shape for a training sample.
        pseudo_labeler: The pseudo labeler. Must be a function that accepts the raw data as torch tensor
            and that returns the predicted labels as torch tensor.
        raw_transform: Transformation applied to the raw data of a sample.
            It is applied after the pseudo labels have been predicted.
        label_transform: Transformation applied to the label data of a sample.
        transform: Transformation applied to the raw data for augmentation.
            It is applied before the pseudo labels are predicted and only the
            first element of its return value is used.
        roi: Region of interest in the raw data.
            If given, the raw data will only be loaded from the corresponding area.
        dtype: The return data type of the raw data.
        n_samples: The length of this dataset. If None, the length will be set to `len(raw_image_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        with_channels: Whether the raw data has channels.
        labeler_device: The device the raw data is moved to before it is passed to the pseudo labeler.
            If None, the device of the first parameter of the pseudo labeler is used;
            a pseudo labeler without parameters (e.g. a plain function) falls back to the CPU.
    """
    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: Optional[str],
        patch_shape: Tuple[int, ...],
        pseudo_labeler: Callable,
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler: Optional[Callable] = None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        labeler_device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__(
            raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform, roi=roi,
            dtype=dtype, n_samples=n_samples, sampler=sampler, ndim=ndim, with_channels=with_channels
        )
        self.pseudo_labeler = pseudo_labeler
        self.label_transform = label_transform

        if labeler_device is None:
            # Infer the device from the labeler's first parameter when it exposes
            # `parameters()` (as torch modules do). Plain callables and modules
            # without parameters fall back to the CPU instead of raising.
            get_parameters = getattr(pseudo_labeler, "parameters", None)
            first_parameter = next(iter(get_parameters()), None) if callable(get_parameters) else None
            labeler_device = torch.device("cpu") if first_parameter is None else first_parameter.device
        self.labeler_device = labeler_device

    def __getitem__(self, index):
        """Load a raw patch and predict its pseudo labels.

        Args:
            index: The sample index, forwarded to `_get_sample`.

        Returns:
            The raw data and the pseudo labels, each passed through `ensure_tensor_with_channels`.
        """
        raw = self._get_sample(index)

        # Transform for augmentations.
        # Applied only to the raw data, since the labels are generated on the fly by the pseudo_labeler.
        if self.transform is not None:
            raw = self.transform(raw)[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        with torch.no_grad():
            # The labeler sees the raw data before `raw_transform` is applied
            # (Ilastik needs uint as input), so the normalization happens afterwards.
            # `raw[None]` adds a leading batch axis and `[0]` takes the first entry of the prediction.
            labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0]

        # Normalize the raw data.
        if self.raw_transform is not None:
            raw = self.raw_transform(raw.cpu().detach().numpy())
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)

        if self.label_transform is not None:
            labels = self.label_transform(labels)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim)

        return raw, labels

Dataset that uses a prediction function to provide raw data and pseudo labels for segmentation training.

The dataset loads a patch from the raw data and then applies the pseudo labeler to it to predict pseudo labels. The raw data and pseudo labels are returned together as a sample for a batch. The dataset supports all file formats that can be opened with elf.io.open_file, such as hdf5, zarr or n5. Use raw_path to specify the path to the file and raw_key to specify the internal dataset. It also supports regular image formats, such as .tif. For these cases set raw_key=None.

Arguments:
  • raw_path: The file path to the raw image data. May also be a list of file paths.
  • raw_key: The key to the internal dataset containing the raw data.
  • patch_shape: The patch shape for a training sample.
  • pseudo_labeler: The pseudo labeler. Must be a function that accepts the raw data as torch tensor and that returns the predicted labels as torch tensor.
  • raw_transform: Transformation applied to the raw data of a sample.
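  • label_transform: Transformation applied to the label data of a sample.
  • transform: Transformation applied to the raw data for augmentation, before the pseudo labels are predicted.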
  • roi: Region of interest in the raw data. If given, the raw data will only be loaded from the corresponding area.
  • dtype: The return data type of the raw data.
  • n_samples: The length of this dataset. If None, the length will be set to len(raw_image_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data and label data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • with_channels: Whether the raw data has channels.
  • labeler_device: The expected device for the pseudo labeler.
PseudoLabelDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: Optional[str], patch_shape: Tuple[int, ...], pseudo_labeler: Callable, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, roi: Union[slice, Tuple[slice, ...], NoneType] = None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, with_channels: bool = False, labeler_device: Union[str, torch.device, NoneType] = None)
38    def __init__(
39        self,
40        raw_path: Union[List[Any], str, os.PathLike],
41        raw_key: Optional[str],
42        patch_shape: Tuple[int, ...],
43        pseudo_labeler: Callable,
44        raw_transform: Optional[Callable] = None,
45        label_transform: Optional[Callable] = None,
46        transform: Optional[Callable] = None,
47        roi: Optional[Union[slice, Tuple[slice, ...]]] = None,
48        dtype: torch.dtype = torch.float32,
49        n_samples: Optional[int] = None,
50        sampler: Optional[Callable] = None,
51        ndim: Optional[Union[int]] = None,
52        with_channels: bool = False,
53        labeler_device: Optional[Union[str, torch.device]] = None,
54    ):
55        super().__init__(
56            raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform, roi=roi,
57            dtype=dtype, n_samples=n_samples, sampler=sampler, ndim=ndim, with_channels=with_channels
58        )
59        self.pseudo_labeler = pseudo_labeler
60        self.label_transform = label_transform
61        self.labeler_device = next(pseudo_labeler.parameters()).device if labeler_device is None else labeler_device
pseudo_labeler
label_transform
labeler_device