torch_em.data.pseudo_label_dataset

 1import os
 2from typing import Union, Tuple, Optional, List, Any
 3
 4import torch
 5
 6from .raw_dataset import RawDataset
 7from ..util import ensure_tensor_with_channels
 8
 9
class PseudoLabelDataset(RawDataset):
    """Dataset that pairs raw patches with labels predicted on the fly by a pseudo-labeler.

    Patch loading and sampling are delegated to ``RawDataset``. For every sample, the
    (optionally augmented) raw patch is passed through ``pseudo_labeler`` without gradient
    tracking, and the prediction is used as the label. ``raw_transform`` is applied only
    *after* labeling, so the labeler sees the data before normalization.

    Args:
        raw_path, raw_key, patch_shape, roi, dtype, n_samples, sampler, ndim, with_channels:
            Forwarded unchanged to ``RawDataset.__init__``.
        pseudo_labeler: Callable invoked on the raw patch with a leading batch axis of size 1.
            The first entry of its output is used as the label.
        raw_transform: Applied to the raw patch (converted to a numpy array) after labeling.
            Also forwarded to ``RawDataset``.
        label_transform: Applied to the predicted labels.
        transform: Augmentation applied to the raw patch only, before labeling. The first
            element of its return value is used. Also forwarded to ``RawDataset``.
        labeler_device: Device the raw patch is moved to before calling ``pseudo_labeler``.
            Defaults to the device of the labeler's first parameter, or CPU if the labeler
            exposes no parameters.
    """

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        patch_shape: Tuple[int, ...],
        pseudo_labeler,
        raw_transform=None,
        label_transform=None,
        transform=None,
        roi=None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        labeler_device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__(raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform,
                         roi=roi, dtype=dtype, n_samples=n_samples, sampler=sampler,
                         ndim=ndim, with_channels=with_channels)
        self.pseudo_labeler = pseudo_labeler
        self.label_transform = label_transform
        if labeler_device is None:
            # Default to the labeler's own device. Labelers without parameters fall back to
            # CPU instead of raising AttributeError / StopIteration. These are non-torch
            # labelers (e.g. ilastik-based ones) or parameter-free torch modules.
            get_parameters = getattr(pseudo_labeler, "parameters", None)
            first_parameter = next(iter(get_parameters()), None) if callable(get_parameters) else None
            labeler_device = torch.device("cpu") if first_parameter is None else first_parameter.device
        self.labeler_device = labeler_device

    def __getitem__(self, index):
        """Return ``(raw, labels)`` for ``index``, with labels predicted by the pseudo-labeler."""
        raw = self._get_sample(index)

        # Augmentations are applied to raw only: labels are generated from the augmented
        # raw on the fly, so they never need to be transformed jointly with it.
        if self.transform is not None:
            raw = self.transform(raw)[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        # raw_transform (normalization) is deferred until after labeling, because
        # ilastik-based labelers need un-normalized (uint-like) input.
        with torch.no_grad():
            labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0]
        # Move tensor labels back to the CPU. This keeps them on the same device as the
        # returned raw patch and gives label_transform host memory (needed by numpy-based
        # transforms). It also avoids returning CUDA tensors from DataLoader worker processes.
        if torch.is_tensor(labels):
            labels = labels.cpu()

        if self.raw_transform is not None:
            # normalization functions need a numpy array, but raw is a torch tensor here
            raw = self.raw_transform(raw.cpu().detach().numpy())
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)

        if self.label_transform is not None:
            labels = self.label_transform(labels)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim)

        return raw, labels
class PseudoLabelDataset(typing.Generic[+T_co]):
class PseudoLabelDataset(RawDataset):
    """Dataset that pairs raw patches with labels predicted on the fly by a pseudo-labeler.

    Patch loading and sampling are delegated to ``RawDataset``. For every sample, the
    (optionally augmented) raw patch is passed through ``pseudo_labeler`` without gradient
    tracking, and the prediction is used as the label. ``raw_transform`` is applied only
    *after* labeling, so the labeler sees the data before normalization.

    Args:
        raw_path, raw_key, patch_shape, roi, dtype, n_samples, sampler, ndim, with_channels:
            Forwarded unchanged to ``RawDataset.__init__``.
        pseudo_labeler: Callable invoked on the raw patch with a leading batch axis of size 1.
            The first entry of its output is used as the label.
        raw_transform: Applied to the raw patch (converted to a numpy array) after labeling.
            Also forwarded to ``RawDataset``.
        label_transform: Applied to the predicted labels.
        transform: Augmentation applied to the raw patch only, before labeling. The first
            element of its return value is used. Also forwarded to ``RawDataset``.
        labeler_device: Device the raw patch is moved to before calling ``pseudo_labeler``.
            Defaults to the device of the labeler's first parameter, or CPU if the labeler
            exposes no parameters.
    """

    def __init__(
        self,
        raw_path: Union[List[Any], str, os.PathLike],
        raw_key: str,
        patch_shape: Tuple[int, ...],
        pseudo_labeler,
        raw_transform=None,
        label_transform=None,
        transform=None,
        roi=None,
        dtype: torch.dtype = torch.float32,
        n_samples: Optional[int] = None,
        sampler=None,
        ndim: Optional[int] = None,
        with_channels: bool = False,
        labeler_device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__(raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform,
                         roi=roi, dtype=dtype, n_samples=n_samples, sampler=sampler,
                         ndim=ndim, with_channels=with_channels)
        self.pseudo_labeler = pseudo_labeler
        self.label_transform = label_transform
        if labeler_device is None:
            # Default to the labeler's own device. Labelers without parameters fall back to
            # CPU instead of raising AttributeError / StopIteration. These are non-torch
            # labelers (e.g. ilastik-based ones) or parameter-free torch modules.
            get_parameters = getattr(pseudo_labeler, "parameters", None)
            first_parameter = next(iter(get_parameters()), None) if callable(get_parameters) else None
            labeler_device = torch.device("cpu") if first_parameter is None else first_parameter.device
        self.labeler_device = labeler_device

    def __getitem__(self, index):
        """Return ``(raw, labels)`` for ``index``, with labels predicted by the pseudo-labeler."""
        raw = self._get_sample(index)

        # Augmentations are applied to raw only: labels are generated from the augmented
        # raw on the fly, so they never need to be transformed jointly with it.
        if self.transform is not None:
            raw = self.transform(raw)[0]
            if self.trafo_halo is not None:
                raw = self.crop(raw)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        # raw_transform (normalization) is deferred until after labeling, because
        # ilastik-based labelers need un-normalized (uint-like) input.
        with torch.no_grad():
            labels = self.pseudo_labeler(raw[None].to(self.labeler_device))[0]
        # Move tensor labels back to the CPU. This keeps them on the same device as the
        # returned raw patch and gives label_transform host memory (needed by numpy-based
        # transforms). It also avoids returning CUDA tensors from DataLoader worker processes.
        if torch.is_tensor(labels):
            labels = labels.cpu()

        if self.raw_transform is not None:
            # normalization functions need a numpy array, but raw is a torch tensor here
            raw = self.raw_transform(raw.cpu().detach().numpy())
        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)

        if self.label_transform is not None:
            labels = self.label_transform(labels)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim)

        return raw, labels
PseudoLabelDataset( raw_path: Union[List[Any], str, os.PathLike], raw_key: str, patch_shape: Tuple[int, ...], pseudo_labeler, raw_transform=None, label_transform=None, transform=None, roi=None, dtype: torch.dtype = torch.float32, n_samples: Optional[int] = None, sampler=None, ndim: Optional[int] = None, with_channels: bool = False, labeler_device: Union[str, torch.device, NoneType] = None)
def __init__(
    self,
    raw_path: Union[List[Any], str, os.PathLike],
    raw_key: str,
    patch_shape: Tuple[int, ...],
    pseudo_labeler,
    raw_transform=None,
    label_transform=None,
    transform=None,
    roi=None,
    dtype: torch.dtype = torch.float32,
    n_samples: Optional[int] = None,
    sampler=None,
    ndim: Optional[int] = None,
    with_channels: bool = False,
    labeler_device: Optional[Union[str, torch.device]] = None,
):
    """Set up the pseudo-label dataset.

    Args:
        raw_path, raw_key, patch_shape, raw_transform, transform, roi, dtype, n_samples,
        sampler, ndim, with_channels: Forwarded unchanged to ``RawDataset.__init__``.
        pseudo_labeler: Callable that produces the labels from a batched raw patch.
        label_transform: Applied to the predicted labels.
        labeler_device: Device the raw patch is moved to before labeling. Defaults to the
            device of the labeler's first parameter, or CPU if it exposes no parameters.
    """
    super().__init__(raw_path, raw_key, patch_shape, raw_transform=raw_transform, transform=transform,
                     roi=roi, dtype=dtype, n_samples=n_samples, sampler=sampler,
                     ndim=ndim, with_channels=with_channels)
    self.pseudo_labeler = pseudo_labeler
    self.label_transform = label_transform
    if labeler_device is None:
        # Default to the labeler's own device. Labelers without parameters fall back to
        # CPU instead of raising AttributeError / StopIteration. These are non-torch
        # labelers (e.g. ilastik-based ones) or parameter-free torch modules.
        get_parameters = getattr(pseudo_labeler, "parameters", None)
        first_parameter = next(iter(get_parameters()), None) if callable(get_parameters) else None
        labeler_device = torch.device("cpu") if first_parameter is None else first_parameter.device
    self.labeler_device = labeler_device
pseudo_labeler
label_transform
labeler_device