torch_em.self_training.pseudo_labeling

from typing import Callable, Literal, Optional, Tuple, Union

import torch
  4
  5
  6class DefaultPseudoLabeler:
  7    """Compute pseudo labels based on model predictions, typically from a teacher model.
  8
  9    Args:
 10        activation: Activation function applied to the teacher prediction.
 11        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
 12            If None is given no mask will be computed.
 13        threshold_from_both_sides: Whether to include both values bigger than the threshold
 14            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 15            The former should be used for binary labels, the latter for for multiclass labels.
 16        mask_channel: A specific channel to use for computing the confidence mask.
 17            By default the confidence mask is computed across all channels independently.
 18            This is useful, if only one of the channels encodes a probability.
 19    """
 20    def __init__(
 21        self,
 22        activation: Optional[torch.nn.Module] = None,
 23        confidence_threshold: Optional[float] = None,
 24        threshold_from_both_sides: bool = True,
 25        mask_channel: Optional[int] = None,
 26    ):
 27        self.activation = activation
 28        self.confidence_threshold = confidence_threshold
 29        self.threshold_from_both_sides = threshold_from_both_sides
 30        self.mask_channel = mask_channel
 31        # TODO serialize the class names and kwargs for activation instead
 32        self.init_kwargs = {
 33            "activation": None, "confidence_threshold": confidence_threshold,
 34            "threshold_from_both_sides": threshold_from_both_sides,
 35            "mask_channel": mask_channel,
 36        }
 37
 38    def _compute_label_mask_both_sides(self, pseudo_labels):
 39        upper_threshold = self.confidence_threshold
 40        lower_threshold = 1.0 - self.confidence_threshold
 41        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
 42        return mask
 43
 44    def _compute_label_mask_one_side(self, pseudo_labels):
 45        mask = (pseudo_labels >= self.confidence_threshold)
 46        return mask
 47
 48    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
 49        """Compute pseudo-labels.
 50
 51        Args:
 52            teacher: The teacher model.
 53            input_: The input for this batch.
 54
 55        Returns:
 56            The pseudo-labels.
 57        """
 58        pseudo_labels = teacher(input_)
 59        if self.activation is not None:
 60            pseudo_labels = self.activation(pseudo_labels)
 61        if self.confidence_threshold is None:
 62            label_mask = None
 63        else:
 64            mask_input = pseudo_labels if self.mask_channel is None\
 65                else pseudo_labels[self.mask_channel:(self.mask_channel+1)]
 66            label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides\
 67                else self._compute_label_mask_one_side(mask_input)
 68            if self.mask_channel is not None:
 69                size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
 70                label_mask = label_mask.expand(*size)
 71        return pseudo_labels, label_mask
 72
 73    def step(self, metric, epoch):
 74        pass
 75
 76
 77class ProbabilisticPseudoLabeler:
 78    """Compute pseudo labels from a Probabilistic UNet.
 79
 80    Args:
 81        activation: Activation function applied to the teacher prediction.
 82        confidence_threshold: Threshold for computing a mask for filterign the pseudo labels.
 83            If none is given no mask will be computed.
 84        threshold_from_both_sides: Whether to include both values bigger than the threshold
 85            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 86            The former should be used for binary labels, the latter for for multiclass labels.
 87        prior_samples: The number of times to sample from the model distribution per input.
 88        consensus_masking: Whether to activate consensus masking in the label filter.
 89            If False, the weighted consensus response (weighted per-pixel response) is returned.
 90            If True, the masked consensus response (complete aggrement of pixels) is returned.
 91    """
 92    def __init__(
 93        self,
 94        activation: Optional[torch.nn.Module] = None,
 95        confidence_threshold: Optional[float] = None,
 96        threshold_from_both_sides: bool = True,
 97        prior_samples: int = 16,
 98        consensus_masking: bool = False,
 99    ):
100        self.activation = activation
101        self.confidence_threshold = confidence_threshold
102        self.threshold_from_both_sides = threshold_from_both_sides
103        self.prior_samples = prior_samples
104        self.consensus_masking = consensus_masking
105        # TODO serialize the class names and kwargs for activation instead
106        self.init_kwargs = {
107            "activation": None, "confidence_threshold": confidence_threshold,
108            "threshold_from_both_sides": threshold_from_both_sides
109        }
110
111    def _compute_label_mask_both_sides(self, pseudo_labels):
112        upper_threshold = self.confidence_threshold
113        lower_threshold = 1.0 - self.confidence_threshold
114        mask = [
115            torch.where((sample >= upper_threshold) + (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.))
116            for sample in pseudo_labels
117        ]
118        return mask
119
120    def _compute_label_mask_one_side(self, pseudo_labels):
121        mask = [
122            torch.where((sample >= self.confidence_threshold), torch.tensor(1.), torch.tensor(0.))
123            for sample in pseudo_labels
124        ]
125        return mask
126
127    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
128        """Compute pseudo-labels.
129
130        Args:
131            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
132            input_: The input for this batch.
133
134        Returns:
135            The pseudo-labels.
136        """
137        teacher.forward(input_)
138        if self.activation is not None:
139            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
140        else:
141            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
142        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
143
144        if self.confidence_threshold is None:
145            label_mask = None
146        else:
147            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
148                else self._compute_label_mask_one_side(pseudo_labels)
149            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
150            if self.consensus_masking:
151                label_mask = torch.where(label_mask == 1, 1, 0)
152
153        return pseudo_labels, label_mask
154
155    def step(self, metric, epoch):
156        pass
157
158
class ScheduledPseudoLabeler:
    """Compute pseudo labels with a confidence threshold that is lowered over time.

    Pseudo labels are generated from a teacher model's predictions. The confidence threshold
    for filtering them is reduced either when a monitored metric has not improved for more than
    `patience` calls to `step`, or, if no metric is passed to `step`, every `patience` epochs.
    The mask can be thresholded from both sides (for binary labels) or from one side
    (for multiclass labels).

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and the schedule has no effect.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it,
            or only values bigger than it in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Amount by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the confidence threshold if `threshold_mode` is 'abs'
            and multiplied with it if `threshold_mode` is 'rel'. Must be < 1.0.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs'). It also selects how `factor` is applied (see above).
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: Minimal decrease of the confidence threshold; smaller reductions are ignored.
        verbose: If True, prints messages when the confidence threshold is reduced.

    Raises:
        ValueError: If `mode` or `threshold_mode` is invalid or if `factor` is >= 1.0.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )

        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # Keep the scheduler arguments too, so that the labeler can be re-created with the same schedule.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode, "factor": factor, "patience": patience,
            "threshold": threshold, "threshold_mode": threshold_mode,
            "min_ct": min_ct, "eps": eps, "verbose": verbose,
        }

        # scheduler arguments
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # scheduler state: start from the worst possible value so the first metric always counts as better
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        """Return a float32 mask that is 1 where the value is >= threshold or <= 1 - threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        """Return a float32 mask that is 1 where the value is >= threshold."""
        # Cast to float32 so both masking modes yield the same dtype.
        return (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)

    def __call__(self, teacher, input_) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The confidence mask, or None if no confidence threshold was given.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            return pseudo_labels, None
        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a, best):
        """Check whether metric value `a` improves on `best` by more than the configured threshold."""
        if self.mode == "min" and self.threshold_mode == "rel":
            return a < best * (1.0 - self.threshold)
        if self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        if self.mode == "max" and self.threshold_mode == "rel":
            return a > best * (self.threshold + 1.0)
        # mode == "max" and threshold_mode == "abs"
        return a > best + self.threshold

    def _reduce_ct(self, epoch):
        """Lower the confidence threshold by `factor`, but not below `min_ct`."""
        old_ct = self.confidence_threshold
        if old_ct is None:
            # No threshold is used, so there is nothing to reduce.
            return
        if self.threshold_mode == "rel":
            new_ct = max(old_ct * self.factor, self.min_ct)
        else:  # threshold_mode == "abs"
            new_ct = max(old_ct - self.factor, self.min_ct)
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            # Only report reductions that actually happened.
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch=None):
        """Advance the schedule by one epoch and reduce the confidence threshold if required.

        Args:
            metric: The value of the monitored metric. If None, the confidence threshold
                is reduced every `patience` epochs instead.
            epoch: The current epoch. If None, the epoch after the last recorded one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        # Record the epoch in both cases, so that later calls without an epoch continue from here.
        self.last_epoch = epoch

        # Without a metric the threshold follows a fixed schedule.
        if metric is None:
            if epoch != 0 and epoch % self.patience == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0
class DefaultPseudoLabeler:
 7class DefaultPseudoLabeler:
 8    """Compute pseudo labels based on model predictions, typically from a teacher model.
 9
10    Args:
11        activation: Activation function applied to the teacher prediction.
12        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
13            If None is given no mask will be computed.
14        threshold_from_both_sides: Whether to include both values bigger than the threshold
15            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
16            The former should be used for binary labels, the latter for for multiclass labels.
17        mask_channel: A specific channel to use for computing the confidence mask.
18            By default the confidence mask is computed across all channels independently.
19            This is useful, if only one of the channels encodes a probability.
20    """
21    def __init__(
22        self,
23        activation: Optional[torch.nn.Module] = None,
24        confidence_threshold: Optional[float] = None,
25        threshold_from_both_sides: bool = True,
26        mask_channel: Optional[int] = None,
27    ):
28        self.activation = activation
29        self.confidence_threshold = confidence_threshold
30        self.threshold_from_both_sides = threshold_from_both_sides
31        self.mask_channel = mask_channel
32        # TODO serialize the class names and kwargs for activation instead
33        self.init_kwargs = {
34            "activation": None, "confidence_threshold": confidence_threshold,
35            "threshold_from_both_sides": threshold_from_both_sides,
36            "mask_channel": mask_channel,
37        }
38
39    def _compute_label_mask_both_sides(self, pseudo_labels):
40        upper_threshold = self.confidence_threshold
41        lower_threshold = 1.0 - self.confidence_threshold
42        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
43        return mask
44
45    def _compute_label_mask_one_side(self, pseudo_labels):
46        mask = (pseudo_labels >= self.confidence_threshold)
47        return mask
48
49    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
50        """Compute pseudo-labels.
51
52        Args:
53            teacher: The teacher model.
54            input_: The input for this batch.
55
56        Returns:
57            The pseudo-labels.
58        """
59        pseudo_labels = teacher(input_)
60        if self.activation is not None:
61            pseudo_labels = self.activation(pseudo_labels)
62        if self.confidence_threshold is None:
63            label_mask = None
64        else:
65            mask_input = pseudo_labels if self.mask_channel is None\
66                else pseudo_labels[self.mask_channel:(self.mask_channel+1)]
67            label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides\
68                else self._compute_label_mask_one_side(mask_input)
69            if self.mask_channel is not None:
70                size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
71                label_mask = label_mask.expand(*size)
72        return pseudo_labels, label_mask
73
74    def step(self, metric, epoch):
75        pass

Compute pseudo labels based on model predictions, typically from a teacher model.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
  • mask_channel: A specific channel to use for computing the confidence mask. By default the confidence mask is computed across all channels independently. This is useful, if only one of the channels encodes a probability.
DefaultPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, mask_channel: Optional[int] = None)
21    def __init__(
22        self,
23        activation: Optional[torch.nn.Module] = None,
24        confidence_threshold: Optional[float] = None,
25        threshold_from_both_sides: bool = True,
26        mask_channel: Optional[int] = None,
27    ):
28        self.activation = activation
29        self.confidence_threshold = confidence_threshold
30        self.threshold_from_both_sides = threshold_from_both_sides
31        self.mask_channel = mask_channel
32        # TODO serialize the class names and kwargs for activation instead
33        self.init_kwargs = {
34            "activation": None, "confidence_threshold": confidence_threshold,
35            "threshold_from_both_sides": threshold_from_both_sides,
36            "mask_channel": mask_channel,
37        }
activation
confidence_threshold
threshold_from_both_sides
mask_channel
init_kwargs
def step(self, metric, epoch):
74    def step(self, metric, epoch):
75        pass
class ProbabilisticPseudoLabeler:
 78class ProbabilisticPseudoLabeler:
 79    """Compute pseudo labels from a Probabilistic UNet.
 80
 81    Args:
 82        activation: Activation function applied to the teacher prediction.
 83        confidence_threshold: Threshold for computing a mask for filterign the pseudo labels.
 84            If none is given no mask will be computed.
 85        threshold_from_both_sides: Whether to include both values bigger than the threshold
 86            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 87            The former should be used for binary labels, the latter for for multiclass labels.
 88        prior_samples: The number of times to sample from the model distribution per input.
 89        consensus_masking: Whether to activate consensus masking in the label filter.
 90            If False, the weighted consensus response (weighted per-pixel response) is returned.
 91            If True, the masked consensus response (complete aggrement of pixels) is returned.
 92    """
 93    def __init__(
 94        self,
 95        activation: Optional[torch.nn.Module] = None,
 96        confidence_threshold: Optional[float] = None,
 97        threshold_from_both_sides: bool = True,
 98        prior_samples: int = 16,
 99        consensus_masking: bool = False,
100    ):
101        self.activation = activation
102        self.confidence_threshold = confidence_threshold
103        self.threshold_from_both_sides = threshold_from_both_sides
104        self.prior_samples = prior_samples
105        self.consensus_masking = consensus_masking
106        # TODO serialize the class names and kwargs for activation instead
107        self.init_kwargs = {
108            "activation": None, "confidence_threshold": confidence_threshold,
109            "threshold_from_both_sides": threshold_from_both_sides
110        }
111
112    def _compute_label_mask_both_sides(self, pseudo_labels):
113        upper_threshold = self.confidence_threshold
114        lower_threshold = 1.0 - self.confidence_threshold
115        mask = [
116            torch.where((sample >= upper_threshold) + (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.))
117            for sample in pseudo_labels
118        ]
119        return mask
120
121    def _compute_label_mask_one_side(self, pseudo_labels):
122        mask = [
123            torch.where((sample >= self.confidence_threshold), torch.tensor(1.), torch.tensor(0.))
124            for sample in pseudo_labels
125        ]
126        return mask
127
128    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
129        """Compute pseudo-labels.
130
131        Args:
132            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
133            input_: The input for this batch.
134
135        Returns:
136            The pseudo-labels.
137        """
138        teacher.forward(input_)
139        if self.activation is not None:
140            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
141        else:
142            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
143        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
144
145        if self.confidence_threshold is None:
146            label_mask = None
147        else:
148            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
149                else self._compute_label_mask_one_side(pseudo_labels)
150            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
151            if self.consensus_masking:
152                label_mask = torch.where(label_mask == 1, 1, 0)
153
154        return pseudo_labels, label_mask
155
156    def step(self, metric, epoch):
157        pass

Compute pseudo labels from a Probabilistic UNet.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
  • prior_samples: The number of times to sample from the model distribution per input.
  • consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
 93    def __init__(
 94        self,
 95        activation: Optional[torch.nn.Module] = None,
 96        confidence_threshold: Optional[float] = None,
 97        threshold_from_both_sides: bool = True,
 98        prior_samples: int = 16,
 99        consensus_masking: bool = False,
100    ):
101        self.activation = activation
102        self.confidence_threshold = confidence_threshold
103        self.threshold_from_both_sides = threshold_from_both_sides
104        self.prior_samples = prior_samples
105        self.consensus_masking = consensus_masking
106        # TODO serialize the class names and kwargs for activation instead
107        self.init_kwargs = {
108            "activation": None, "confidence_threshold": confidence_threshold,
109            "threshold_from_both_sides": threshold_from_both_sides
110        }
activation
confidence_threshold
threshold_from_both_sides
prior_samples
consensus_masking
init_kwargs
def step(self, metric, epoch):
156    def step(self, metric, epoch):
157        pass
class ScheduledPseudoLabeler:
class ScheduledPseudoLabeler:
    """Compute pseudo labels from teacher predictions with a scheduled confidence threshold.

    Pseudo labels are generated from a teacher model's predictions. The confidence threshold
    used for filtering the pseudo labels can be lowered over time, either based on a performance
    metric (when it has not improved for more than `patience` epochs) or on a fixed schedule
    (every `patience` epochs, when no metric is passed to `step`).
    The mask can be computed from both sides (for binary classification) or from one side
    (for multiclass problems).

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and the threshold is never scheduled.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it,
            or only values bigger than it in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Amount by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the threshold if `threshold_mode` is 'abs' and multiplied with it
            if `threshold_mode` is 'rel'. Must be smaller than 1.0.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: Minimal decrease of the confidence threshold; smaller updates are ignored.
        verbose: If True, prints messages when the confidence threshold is reduced.

    Raises:
        ValueError: If `mode` or `threshold_mode` is not one of the supported values, or if `factor >= 1.0`.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        # Validate the scheduler arguments before storing any state.
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        if threshold_mode not in {"rel", "abs"}:
            # Report the offending threshold mode (not the metric mode).
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )

        # Pseudo labeling arguments.
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides

        # Scheduler arguments.
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # TODO serialize the class names and kwargs for activation instead
        # The scheduler arguments are recorded as well, so that an instance rebuilt
        # from `init_kwargs` follows the same schedule instead of the defaults.
        self.init_kwargs = {
            "activation": None,
            "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode,
            "factor": factor,
            "patience": patience,
            "threshold": threshold,
            "threshold_mode": threshold_mode,
            "min_ct": min_ct,
            "eps": eps,
            "verbose": verbose,
        }

        # Scheduler state: start from the worst possible metric value for the chosen mode.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where the prediction is >= ct or <= 1 - ct."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        # Adding two boolean tensors acts as an elementwise logical or.
        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return the result of comparing the predictions against the confidence threshold."""
        # NOTE(review): unlike the both-sides variant this mask is not cast to float32;
        # kept as-is because callers may rely on the current dtype. TODO confirm.
        mask = (pseudo_labels >= self.confidence_threshold)
        return mask

    def __call__(self, teacher, input_):
        """Compute pseudo labels and the corresponding confidence mask.

        Args:
            teacher: Callable (typically the teacher model) applied to `input_`.
            input_: The input passed to the teacher.

        Returns:
            The (activated) teacher predictions and the label mask,
            which is None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            return pseudo_labels, None

        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a, best):
        """Check whether metric value `a` improves on `best` by more than `self.threshold`."""
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        """Lower the confidence threshold by one scheduling step, but never below `min_ct`."""
        old_ct = self.confidence_threshold
        if old_ct is None:
            # Without a confidence threshold no mask is computed, so there is nothing to schedule.
            return

        if self.threshold_mode == "rel":
            candidate = old_ct * self.factor
        else:  # threshold_mode == 'abs':
            candidate = old_ct - self.factor
        new_ct = max(candidate, self.min_ct)

        # Ignore updates that are smaller than eps (this includes an already clamped threshold).
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            # Only report actual reductions.
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch=None):
        """Advance the schedule by one epoch and reduce the confidence threshold if required.

        Args:
            metric: The monitored metric value. If None, the confidence threshold is
                reduced on a fixed schedule, every `patience` epochs.
            epoch: The current epoch. If None, the last recorded epoch plus one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        # Always record the epoch, so that a later call without `epoch` continues from here.
        self.last_epoch = epoch

        # If the metric is None, reduce the confidence threshold every `patience` epochs.
        if metric is None:
            # Guard against a division by zero for patience=0, which is treated as "every epoch".
            period = max(self.patience, 1)
            if epoch != 0 and epoch % period == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0

This class implements a scheduled pseudo-labeling mechanism, where pseudo labels are generated from a teacher model's predictions, and the confidence threshold for filtering the pseudo labels can be adjusted over time based on a performance metric or a fixed schedule. It includes options for adjusting thresholds from both sides (for binary classification) or from one side (for multiclass problems). The threshold can be dynamically reduced to improve the quality of the pseudo labels when the model performance does not improve for a given number of epochs (patience).

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If none is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it, or only values bigger than it in the mask. The former should be used for binary labels, the latter for multiclass labels.
  • mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
    • 'min': A lower value of the monitored metric is considered better (e.g., loss).
    • 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
  • factor: Factor by which the confidence threshold is reduced when the performance stagnates.
  • patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
  • threshold: Threshold value for determining a significant improvement in the performance metric to reset the patience counter. This can be relative (percentage improvement) or absolute depending on threshold_mode.
  • threshold_mode: Determines whether the threshold is interpreted as a relative improvement ('rel') or an absolute improvement ('abs').
  • min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
  • eps: A small value to avoid floating-point precision errors during threshold comparison.
  • verbose: If True, prints messages when the confidence threshold is reduced.
ScheduledPseudoLabeler( activation: Union[torch.nn.modules.module.Module, Callable, NoneType] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides=True, mode: Literal['min', 'max'] = 'min', factor: float = 0.05, patience: int = 10, threshold: float = 0.0001, threshold_mode: Literal['rel', 'abs'] = 'abs', min_ct: float = 0.5, eps: float = 1e-08, verbose: bool = True)
192    def __init__(
193        self,
194        activation: Optional[Union[torch.nn.Module, Callable]] = None,
195        confidence_threshold: Optional[float] = None,
196        threshold_from_both_sides=True,
197        mode: Literal["min", "max"] = "min",
198        factor: float = 0.05,
199        patience: int = 10,
200        threshold: float = 1e-4,
201        threshold_mode: Literal["rel", "abs"] = "abs",
202        min_ct: float = 0.5,
203        eps: float = 1e-8,
204        verbose: bool = True,
205    ):
206        self.activation = activation
207        self.confidence_threshold = confidence_threshold
208        self.threshold_from_both_sides = threshold_from_both_sides
209        self.init_kwargs = {
210            "activation": None, "confidence_threshold": confidence_threshold,
211            "threshold_from_both_sides": threshold_from_both_sides
212        }
213        # scheduler arguments
214        if mode not in {"min", "max"}:
215            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
216        self.mode = mode
217
218        if factor >= 1.0:
219            raise ValueError("Factor should be < 1.0.")
220        self.factor = factor
221
222        self.patience = patience
223        self.threshold = threshold
224
225        if threshold_mode not in {"rel", "abs"}:
226            raise ValueError(f"Invalid threshold mode: {mode}. Threshold mode should be 'rel' or 'abs'.")
227        self.threshold_mode = threshold_mode
228
229        self.min_ct = min_ct
230        self.eps = eps
231        self.verbose = verbose
232
233        if mode == "min":
234            self.best = float("inf")
235        else:  # mode == "max":
236            self.best = float("-inf")
237
238        # self.best = 0
239        self.num_bad_epochs: int = 0
240        self.last_epoch = 0
activation
confidence_threshold
threshold_from_both_sides
init_kwargs
mode
factor
patience
threshold
threshold_mode
min_ct
eps
verbose
num_bad_epochs: int
last_epoch
def step(self, metric, epoch=None):
289    def step(self, metric, epoch=None):
290        if epoch is None:
291            epoch = self.last_epoch + 1
292            self.last_epoch = epoch
293
294        # If the metric is None, reduce the confidence threshold every epoch
295        if metric is None:
296            if epoch == 0:
297                return
298            if epoch % self.patience == 0:
299                self._reduce_ct(epoch)
300            return
301
302        else:
303            current = float(metric)
304
305            if self._is_better(current, self.best):
306                self.best = current
307                self.num_bad_epochs = 0
308            else:
309                self.num_bad_epochs += 1
310
311            if self.num_bad_epochs > self.patience:
312                self._reduce_ct(epoch)
313                self.num_bad_epochs = 0