torch_em.self_training.pseudo_labeling

from typing import Optional, Tuple

import torch


class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = ((pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        # Cast to float so that both masking modes return the same dtype.
        mask = (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)
        return mask

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The label mask, or None if no confidence_threshold was given.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        else:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
                else self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask


class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response (weighted per-pixel response) is returned.
            If True, the masked consensus response (complete agreement of pixels) is returned.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples, "consensus_masking": consensus_masking
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = [
            torch.where((sample >= upper_threshold) | (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.))
            for sample in pseudo_labels
        ]
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        mask = [
            torch.where(sample >= self.confidence_threshold, torch.tensor(1.), torch.tensor(0.))
            for sample in pseudo_labels
        ]
        return mask

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The label mask, or None if no confidence_threshold was given.
        """
        # Run the forward pass; the individual predictions are then drawn via `sample`.
        teacher.forward(input_)
        if self.activation is not None:
            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
        else:
            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]

        if self.confidence_threshold is None:
            label_mask = None
        else:
            # Compute one mask per prior sample and average them into the consensus response.
            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
                else self._compute_label_mask_one_side(pseudo_labels)
            label_mask = torch.stack(label_mask, dim=0).sum(dim=0) / self.prior_samples
            if self.consensus_masking:
                label_mask = torch.where(label_mask == 1, 1, 0)

        # Average the sampled predictions to obtain the pseudo-labels.
        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0) / self.prior_samples

        return pseudo_labels, label_mask

class DefaultPseudoLabeler:

Compute pseudo labels based on model predictions, typically from a teacher model.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
DefaultPseudoLabeler(activation: Optional[torch.nn.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True)
Instance attributes: activation, confidence_threshold, threshold_from_both_sides, init_kwargs
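
A minimal usage sketch for DefaultPseudoLabeler in a binary setting. The teacher below is just a stand-in Conv2d layer and the tensor shapes are made up for illustration; in practice the teacher would be a trained segmentation model.

import torch

from torch_em.self_training.pseudo_labeling import DefaultPseudoLabeler

# Stand-in teacher; in practice a trained segmentation network would be used here.
teacher = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

# Apply a sigmoid to the teacher logits and only keep confident pixels
# (probability >= 0.9 or <= 0.1, since threshold_from_both_sides defaults to True).
pseudo_labeler = DefaultPseudoLabeler(activation=torch.nn.Sigmoid(), confidence_threshold=0.9)

x = torch.randn(4, 1, 64, 64)  # a batch of unlabeled inputs
pseudo_labels, label_mask = pseudo_labeler(teacher, x)
print(pseudo_labels.shape)  # torch.Size([4, 1, 64, 64])
print(label_mask.mean())    # fraction of pixels that pass the confidence filter

The label mask is typically passed on to the self-training loss so that only confident pixels contribute to the student update.
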
class ProbabilisticPseudoLabeler:

Compute pseudo labels from a Probabilistic UNet.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
  • prior_samples: The number of times to sample from the model distribution per input.
  • consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler(activation: Optional[torch.nn.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
Instance attributes: activation, confidence_threshold, threshold_from_both_sides, prior_samples, consensus_masking, init_kwargs
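
The sketch below shows how ProbabilisticPseudoLabeler is configured and makes the two consensus modes concrete by reproducing the mask aggregation from __call__ with hand-written per-sample masks. Calling the labeler itself requires a torch_em.model.probabilistic_unet.ProbabilisticUNet as teacher, which is omitted here; the mask values are invented purely for illustration.

import torch

from torch_em.self_training.pseudo_labeling import ProbabilisticPseudoLabeler

# Average 8 prior samples and only keep pixels on which all samples pass the confidence filter.
pseudo_labeler = ProbabilisticPseudoLabeler(
    activation=torch.nn.Sigmoid(), confidence_threshold=0.9,
    prior_samples=8, consensus_masking=True,
)
# pseudo_labels, label_mask = pseudo_labeler(probabilistic_unet, x)  # teacher omitted in this sketch

# The aggregation performed internally, shown for 4 hand-written 2x2 per-sample masks:
prior_samples = 4
per_sample_masks = [
    torch.tensor([[1., 1.], [0., 1.]]),
    torch.tensor([[1., 0.], [0., 1.]]),
    torch.tensor([[1., 1.], [0., 1.]]),
    torch.tensor([[1., 1.], [1., 1.]]),
]

# consensus_masking=False: weighted consensus response, the per-pixel fraction
# of prior samples that passed the confidence filter.
weighted = torch.stack(per_sample_masks, dim=0).sum(dim=0) / prior_samples
print(weighted)   # values: [[1.00, 0.75], [0.25, 1.00]]

# consensus_masking=True: masked consensus response, only pixels on which all
# prior samples agree are kept.
consensus = torch.where(weighted == 1, 1, 0)
print(consensus)  # values: [[1, 0], [0, 1]]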