torch_em.self_training.pseudo_labeling

from typing import Callable, Literal, Optional, Tuple, Union

import torch
  4
  5
  6class DefaultPseudoLabeler:
  7    """Compute pseudo labels based on model predictions, typically from a teacher model.
  8
  9    Args:
 10        activation: Activation function applied to the teacher prediction.
 11        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
 12            If None is given no mask will be computed.
 13        threshold_from_both_sides: Whether to include both values bigger than the threshold
 14            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 15            The former should be used for binary labels, the latter for for multiclass labels.
 16    """
 17    def __init__(
 18        self,
 19        activation: Optional[torch.nn.Module] = None,
 20        confidence_threshold: Optional[float] = None,
 21        threshold_from_both_sides: bool = True,
 22    ):
 23        self.activation = activation
 24        self.confidence_threshold = confidence_threshold
 25        self.threshold_from_both_sides = threshold_from_both_sides
 26        # TODO serialize the class names and kwargs for activation instead
 27        self.init_kwargs = {
 28            "activation": None, "confidence_threshold": confidence_threshold,
 29            "threshold_from_both_sides": threshold_from_both_sides
 30        }
 31
 32    def _compute_label_mask_both_sides(self, pseudo_labels):
 33        upper_threshold = self.confidence_threshold
 34        lower_threshold = 1.0 - self.confidence_threshold
 35        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
 36        return mask
 37
 38    def _compute_label_mask_one_side(self, pseudo_labels):
 39        mask = (pseudo_labels >= self.confidence_threshold)
 40        return mask
 41
 42    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
 43        """Compute pseudo-labels.
 44
 45        Args:
 46            teacher: The teacher model.
 47            input_: The input for this batch.
 48
 49        Returns:
 50            The pseudo-labels.
 51        """
 52        pseudo_labels = teacher(input_)
 53        if self.activation is not None:
 54            pseudo_labels = self.activation(pseudo_labels)
 55        if self.confidence_threshold is None:
 56            label_mask = None
 57        else:
 58            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\
 59                else self._compute_label_mask_one_side(pseudo_labels)
 60        return pseudo_labels, label_mask
 61
 62    def step(self, metric, epoch):
 63        pass
 64
 65
 66class ProbabilisticPseudoLabeler:
 67    """Compute pseudo labels from a Probabilistic UNet.
 68
 69    Args:
 70        activation: Activation function applied to the teacher prediction.
 71        confidence_threshold: Threshold for computing a mask for filterign the pseudo labels.
 72            If none is given no mask will be computed.
 73        threshold_from_both_sides: Whether to include both values bigger than the threshold
 74            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 75            The former should be used for binary labels, the latter for for multiclass labels.
 76        prior_samples: The number of times to sample from the model distribution per input.
 77        consensus_masking: Whether to activate consensus masking in the label filter.
 78            If False, the weighted consensus response (weighted per-pixel response) is returned.
 79            If True, the masked consensus response (complete aggrement of pixels) is returned.
 80    """
 81    def __init__(
 82        self,
 83        activation: Optional[torch.nn.Module] = None,
 84        confidence_threshold: Optional[float] = None,
 85        threshold_from_both_sides: bool = True,
 86        prior_samples: int = 16,
 87        consensus_masking: bool = False,
 88    ):
 89        self.activation = activation
 90        self.confidence_threshold = confidence_threshold
 91        self.threshold_from_both_sides = threshold_from_both_sides
 92        self.prior_samples = prior_samples
 93        self.consensus_masking = consensus_masking
 94        # TODO serialize the class names and kwargs for activation instead
 95        self.init_kwargs = {
 96            "activation": None, "confidence_threshold": confidence_threshold,
 97            "threshold_from_both_sides": threshold_from_both_sides
 98        }
 99
100    def _compute_label_mask_both_sides(self, pseudo_labels):
101        upper_threshold = self.confidence_threshold
102        lower_threshold = 1.0 - self.confidence_threshold
103        mask = [
104            torch.where((sample >= upper_threshold) + (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.))
105            for sample in pseudo_labels
106        ]
107        return mask
108
109    def _compute_label_mask_one_side(self, pseudo_labels):
110        mask = [
111            torch.where((sample >= self.confidence_threshold), torch.tensor(1.), torch.tensor(0.))
112            for sample in pseudo_labels
113        ]
114        return mask
115
116    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
117        """Compute pseudo-labels.
118
119        Args:
120            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
121            input_: The input for this batch.
122
123        Returns:
124            The pseudo-labels.
125        """
126        teacher.forward(input_)
127        if self.activation is not None:
128            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
129        else:
130            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
131        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
132
133        if self.confidence_threshold is None:
134            label_mask = None
135        else:
136            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
137                else self._compute_label_mask_one_side(pseudo_labels)
138            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
139            if self.consensus_masking:
140                label_mask = torch.where(label_mask == 1, 1, 0)
141
142        return pseudo_labels, label_mask
143
144    def step(self, metric, epoch):
145        pass
146
147
class ScheduledPseudoLabeler:
    """Compute pseudo labels with a confidence threshold that is lowered over time.

    Pseudo labels are generated from a teacher model's predictions, and the confidence threshold
    for filtering them can be reduced over time, either based on a monitored performance metric
    or on a fixed schedule. The metric-based schedule follows the same logic as
    `torch.optim.lr_scheduler.ReduceLROnPlateau`: the threshold is reduced once the metric has
    not improved for more than `patience` consecutive calls to `step`. If `step` is called without
    a metric, the threshold is reduced every `patience` epochs instead.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and there is no threshold to schedule.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and
            smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Amount by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the threshold if `threshold_mode` is 'abs', and multiplied with
            the threshold if `threshold_mode` is 'rel'. Must be < 1.0.
            NOTE: with 'rel' and the default factor of 0.05, a single reduction will typically
            drop the threshold straight to `min_ct`.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs'). It also determines how `factor` is applied.
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: Minimal decrease of the confidence threshold; smaller updates are ignored,
            which avoids updates caused by floating-point noise.
        verbose: If True, prints messages when the confidence threshold is reduced.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )

        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode, "factor": factor, "patience": patience, "threshold": threshold,
            "threshold_mode": threshold_mode, "min_ct": min_ct, "eps": eps, "verbose": verbose,
        }

        # Scheduler arguments.
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # Scheduler state.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        """Return a float32 mask that is 1 where the prediction is >= threshold or <= 1 - threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return mask.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        """Return a float32 mask that is 1 where the prediction is >= threshold."""
        # Same dtype as the two-sided mask, so callers get a consistent mask type.
        return (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels with the current confidence threshold.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The label mask (float32, 1 for kept values), or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a: float, best: float) -> bool:
        """Check whether metric value `a` is a significant improvement over `best`."""
        if self.mode == "min" and self.threshold_mode == "rel":
            return a < best * (1.0 - self.threshold)
        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        elif self.mode == "max" and self.threshold_mode == "rel":
            return a > best * (self.threshold + 1.0)
        else:  # mode == "max" and threshold_mode == "abs"
            return a > best + self.threshold

    def _reduce_ct(self, epoch: int) -> None:
        """Lower the confidence threshold by `factor`, but not below `min_ct`."""
        # Without a confidence threshold there is nothing to schedule.
        if self.confidence_threshold is None:
            return
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = max(old_ct * self.factor, self.min_ct)
        else:  # threshold_mode == "abs"
            new_ct = max(old_ct - self.factor, self.min_ct)
        # Ignore negligible updates, e.g. once min_ct has been reached.
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch: Optional[int] = None) -> None:
        """Update the confidence threshold schedule after an epoch.

        Args:
            metric: The monitored metric value (converted with `float`), or None to use
                the fixed schedule (reduce every `patience` epochs).
            epoch: The current epoch. If None, the epoch after the last one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        # Without a metric, reduce the confidence threshold every `patience` epochs (never at epoch 0).
        if metric is None:
            # max(..., 1) avoids a division by zero for patience=0 (reduce every epoch).
            if epoch != 0 and epoch % max(self.patience, 1) == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0
class DefaultPseudoLabeler:
 7class DefaultPseudoLabeler:
 8    """Compute pseudo labels based on model predictions, typically from a teacher model.
 9
10    Args:
11        activation: Activation function applied to the teacher prediction.
12        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
13            If None is given no mask will be computed.
14        threshold_from_both_sides: Whether to include both values bigger than the threshold
15            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
16            The former should be used for binary labels, the latter for for multiclass labels.
17    """
18    def __init__(
19        self,
20        activation: Optional[torch.nn.Module] = None,
21        confidence_threshold: Optional[float] = None,
22        threshold_from_both_sides: bool = True,
23    ):
24        self.activation = activation
25        self.confidence_threshold = confidence_threshold
26        self.threshold_from_both_sides = threshold_from_both_sides
27        # TODO serialize the class names and kwargs for activation instead
28        self.init_kwargs = {
29            "activation": None, "confidence_threshold": confidence_threshold,
30            "threshold_from_both_sides": threshold_from_both_sides
31        }
32
33    def _compute_label_mask_both_sides(self, pseudo_labels):
34        upper_threshold = self.confidence_threshold
35        lower_threshold = 1.0 - self.confidence_threshold
36        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
37        return mask
38
39    def _compute_label_mask_one_side(self, pseudo_labels):
40        mask = (pseudo_labels >= self.confidence_threshold)
41        return mask
42
43    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
44        """Compute pseudo-labels.
45
46        Args:
47            teacher: The teacher model.
48            input_: The input for this batch.
49
50        Returns:
51            The pseudo-labels.
52        """
53        pseudo_labels = teacher(input_)
54        if self.activation is not None:
55            pseudo_labels = self.activation(pseudo_labels)
56        if self.confidence_threshold is None:
57            label_mask = None
58        else:
59            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\
60                else self._compute_label_mask_one_side(pseudo_labels)
61        return pseudo_labels, label_mask
62
63    def step(self, metric, epoch):
64        pass

Compute pseudo labels based on model predictions, typically from a teacher model.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
DefaultPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True)
18    def __init__(
19        self,
20        activation: Optional[torch.nn.Module] = None,
21        confidence_threshold: Optional[float] = None,
22        threshold_from_both_sides: bool = True,
23    ):
24        self.activation = activation
25        self.confidence_threshold = confidence_threshold
26        self.threshold_from_both_sides = threshold_from_both_sides
27        # TODO serialize the class names and kwargs for activation instead
28        self.init_kwargs = {
29            "activation": None, "confidence_threshold": confidence_threshold,
30            "threshold_from_both_sides": threshold_from_both_sides
31        }
activation
confidence_threshold
threshold_from_both_sides
init_kwargs
def step(self, metric, epoch):
63    def step(self, metric, epoch):
64        pass
class ProbabilisticPseudoLabeler:
 67class ProbabilisticPseudoLabeler:
 68    """Compute pseudo labels from a Probabilistic UNet.
 69
 70    Args:
 71        activation: Activation function applied to the teacher prediction.
 72        confidence_threshold: Threshold for computing a mask for filterign the pseudo labels.
 73            If none is given no mask will be computed.
 74        threshold_from_both_sides: Whether to include both values bigger than the threshold
 75            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 76            The former should be used for binary labels, the latter for for multiclass labels.
 77        prior_samples: The number of times to sample from the model distribution per input.
 78        consensus_masking: Whether to activate consensus masking in the label filter.
 79            If False, the weighted consensus response (weighted per-pixel response) is returned.
 80            If True, the masked consensus response (complete aggrement of pixels) is returned.
 81    """
 82    def __init__(
 83        self,
 84        activation: Optional[torch.nn.Module] = None,
 85        confidence_threshold: Optional[float] = None,
 86        threshold_from_both_sides: bool = True,
 87        prior_samples: int = 16,
 88        consensus_masking: bool = False,
 89    ):
 90        self.activation = activation
 91        self.confidence_threshold = confidence_threshold
 92        self.threshold_from_both_sides = threshold_from_both_sides
 93        self.prior_samples = prior_samples
 94        self.consensus_masking = consensus_masking
 95        # TODO serialize the class names and kwargs for activation instead
 96        self.init_kwargs = {
 97            "activation": None, "confidence_threshold": confidence_threshold,
 98            "threshold_from_both_sides": threshold_from_both_sides
 99        }
100
101    def _compute_label_mask_both_sides(self, pseudo_labels):
102        upper_threshold = self.confidence_threshold
103        lower_threshold = 1.0 - self.confidence_threshold
104        mask = [
105            torch.where((sample >= upper_threshold) + (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.))
106            for sample in pseudo_labels
107        ]
108        return mask
109
110    def _compute_label_mask_one_side(self, pseudo_labels):
111        mask = [
112            torch.where((sample >= self.confidence_threshold), torch.tensor(1.), torch.tensor(0.))
113            for sample in pseudo_labels
114        ]
115        return mask
116
117    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
118        """Compute pseudo-labels.
119
120        Args:
121            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
122            input_: The input for this batch.
123
124        Returns:
125            The pseudo-labels.
126        """
127        teacher.forward(input_)
128        if self.activation is not None:
129            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
130        else:
131            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
132        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
133
134        if self.confidence_threshold is None:
135            label_mask = None
136        else:
137            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
138                else self._compute_label_mask_one_side(pseudo_labels)
139            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
140            if self.consensus_masking:
141                label_mask = torch.where(label_mask == 1, 1, 0)
142
143        return pseudo_labels, label_mask
144
145    def step(self, metric, epoch):
146        pass

Compute pseudo labels from a Probabilistic UNet.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
  • prior_samples: The number of times to sample from the model distribution per input.
  • consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
82    def __init__(
83        self,
84        activation: Optional[torch.nn.Module] = None,
85        confidence_threshold: Optional[float] = None,
86        threshold_from_both_sides: bool = True,
87        prior_samples: int = 16,
88        consensus_masking: bool = False,
89    ):
90        self.activation = activation
91        self.confidence_threshold = confidence_threshold
92        self.threshold_from_both_sides = threshold_from_both_sides
93        self.prior_samples = prior_samples
94        self.consensus_masking = consensus_masking
95        # TODO serialize the class names and kwargs for activation instead
96        self.init_kwargs = {
97            "activation": None, "confidence_threshold": confidence_threshold,
98            "threshold_from_both_sides": threshold_from_both_sides
99        }
activation
confidence_threshold
threshold_from_both_sides
prior_samples
consensus_masking
init_kwargs
def step(self, metric, epoch):
145    def step(self, metric, epoch):
146        pass
class ScheduledPseudoLabeler:
class ScheduledPseudoLabeler:
    """Compute pseudo labels with a confidence threshold that is lowered over time.

    Pseudo labels are generated from a teacher model's predictions, and the confidence threshold
    for filtering them can be reduced over time, either based on a monitored performance metric
    or on a fixed schedule. The metric-based schedule follows the same logic as
    `torch.optim.lr_scheduler.ReduceLROnPlateau`: the threshold is reduced once the metric has
    not improved for more than `patience` consecutive calls to `step`. If `step` is called without
    a metric, the threshold is reduced every `patience` epochs instead.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and there is no threshold to schedule.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and
            smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Amount by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the threshold if `threshold_mode` is 'abs', and multiplied with
            the threshold if `threshold_mode` is 'rel'. Must be < 1.0.
            NOTE: with 'rel' and the default factor of 0.05, a single reduction will typically
            drop the threshold straight to `min_ct`.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs'). It also determines how `factor` is applied.
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: Minimal decrease of the confidence threshold; smaller updates are ignored,
            which avoids updates caused by floating-point noise.
        verbose: If True, prints messages when the confidence threshold is reduced.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )

        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode, "factor": factor, "patience": patience, "threshold": threshold,
            "threshold_mode": threshold_mode, "min_ct": min_ct, "eps": eps, "verbose": verbose,
        }

        # Scheduler arguments.
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # Scheduler state.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        """Return a float32 mask that is 1 where the prediction is >= threshold or <= 1 - threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return mask.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        """Return a float32 mask that is 1 where the prediction is >= threshold."""
        # Same dtype as the two-sided mask, so callers get a consistent mask type.
        return (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels with the current confidence threshold.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The label mask (float32, 1 for kept values), or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a: float, best: float) -> bool:
        """Check whether metric value `a` is a significant improvement over `best`."""
        if self.mode == "min" and self.threshold_mode == "rel":
            return a < best * (1.0 - self.threshold)
        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        elif self.mode == "max" and self.threshold_mode == "rel":
            return a > best * (self.threshold + 1.0)
        else:  # mode == "max" and threshold_mode == "abs"
            return a > best + self.threshold

    def _reduce_ct(self, epoch: int) -> None:
        """Lower the confidence threshold by `factor`, but not below `min_ct`."""
        # Without a confidence threshold there is nothing to schedule.
        if self.confidence_threshold is None:
            return
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = max(old_ct * self.factor, self.min_ct)
        else:  # threshold_mode == "abs"
            new_ct = max(old_ct - self.factor, self.min_ct)
        # Ignore negligible updates, e.g. once min_ct has been reached.
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch: Optional[int] = None) -> None:
        """Update the confidence threshold schedule after an epoch.

        Args:
            metric: The monitored metric value (converted with `float`), or None to use
                the fixed schedule (reduce every `patience` epochs).
            epoch: The current epoch. If None, the epoch after the last one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        # Without a metric, reduce the confidence threshold every `patience` epochs (never at epoch 0).
        if metric is None:
            # max(..., 1) avoids a division by zero for patience=0 (reduce every epoch).
            if epoch != 0 and epoch % max(self.patience, 1) == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0

This class implements a scheduled pseudo-labeling mechanism, where pseudo labels are generated from a teacher model's predictions, and the confidence threshold for filtering the pseudo labels can be adjusted over time based on a performance metric or a fixed schedule. It includes options for adjusting thresholds from both sides (for binary classification) or from one side (for multiclass problems). The threshold can be dynamically reduced to improve the quality of the pseudo labels when the model performance does not improve for a given number of epochs (patience).

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed.
  • threshold_from_both_sides: Whether to include in the mask both values greater than or equal to the threshold and values less than or equal to 1 - threshold, or only values greater than or equal to the threshold. The former should be used for binary labels, the latter for multiclass labels.
  • mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
    • 'min': A lower value of the monitored metric is considered better (e.g., loss).
    • 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
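  • factor: Factor by which the confidence threshold is reduced when the performance stagnates. With threshold_mode 'rel' the confidence threshold is multiplied by this factor; with 'abs' the factor is subtracted from it. Must be smaller than 1.0.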
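  • patience: Number of consecutive epochs without improvement that are tolerated; the confidence threshold is reduced once this number is exceeded. If no metric is given to step, the confidence threshold is instead reduced every patience epochs.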
  • threshold: Threshold value for determining a significant improvement in the performance metric to reset the patience counter. This can be relative (percentage improvement) or absolute depending on threshold_mode.
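  • threshold_mode: Determines whether the threshold is interpreted as a relative improvement ('rel') or an absolute improvement ('abs'). It also determines how the confidence threshold is reduced: by multiplying it with factor ('rel') or by subtracting factor from it ('abs').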
  • min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
  • eps: Minimal change of the confidence threshold. Reductions that are not larger than eps are ignored, which avoids updates caused by floating-point precision errors.
  • verbose: If True, prints messages when the confidence threshold is reduced.
ScheduledPseudoLabeler( activation: Union[torch.nn.modules.module.Module, Callable, NoneType] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides=True, mode: Literal['min', 'max'] = 'min', factor: float = 0.05, patience: int = 10, threshold: float = 0.0001, threshold_mode: Literal['rel', 'abs'] = 'abs', min_ct: float = 0.5, eps: float = 1e-08, verbose: bool = True)
181    def __init__(
182        self,
183        activation: Optional[Union[torch.nn.Module, Callable]] = None,
184        confidence_threshold: Optional[float] = None,
185        threshold_from_both_sides=True,
186        mode: Literal["min", "max"] = "min",
187        factor: float = 0.05,
188        patience: int = 10,
189        threshold: float = 1e-4,
190        threshold_mode: Literal["rel", "abs"] = "abs",
191        min_ct: float = 0.5,
192        eps: float = 1e-8,
193        verbose: bool = True,
194    ):
195        self.activation = activation
196        self.confidence_threshold = confidence_threshold
197        self.threshold_from_both_sides = threshold_from_both_sides
198        self.init_kwargs = {
199            "activation": None, "confidence_threshold": confidence_threshold,
200            "threshold_from_both_sides": threshold_from_both_sides
201        }
202        # scheduler arguments
203        if mode not in {"min", "max"}:
204            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
205        self.mode = mode
206
207        if factor >= 1.0:
208            raise ValueError("Factor should be < 1.0.")
209        self.factor = factor
210
211        self.patience = patience
212        self.threshold = threshold
213
214        if threshold_mode not in {"rel", "abs"}:
215            raise ValueError(f"Invalid threshold mode: {mode}. Threshold mode should be 'rel' or 'abs'.")
216        self.threshold_mode = threshold_mode
217
218        self.min_ct = min_ct
219        self.eps = eps
220        self.verbose = verbose
221
222        if mode == "min":
223            self.best = float("inf")
224        else:  # mode == "max":
225            self.best = float("-inf")
226
227        # self.best = 0
228        self.num_bad_epochs: int = 0
229        self.last_epoch = 0
activation
confidence_threshold
threshold_from_both_sides
init_kwargs
mode
factor
patience
threshold
threshold_mode
min_ct
eps
verbose
num_bad_epochs: int
last_epoch
def step(self, metric, epoch=None):
278    def step(self, metric, epoch=None):
279        if epoch is None:
280            epoch = self.last_epoch + 1
281            self.last_epoch = epoch
282
283        # If the metric is None, reduce the confidence threshold every epoch
284        if metric is None:
285            if epoch == 0:
286                return
287            if epoch % self.patience == 0:
288                self._reduce_ct(epoch)
289            return
290
291        else:
292            current = float(metric)
293
294            if self._is_better(current, self.best):
295                self.best = current
296                self.num_bad_epochs = 0
297            else:
298                self.num_bad_epochs += 1
299
300            if self.num_bad_epochs > self.patience:
301                self._reduce_ct(epoch)
302                self.num_bad_epochs = 0