torch_em.self_training.pseudo_labeling

  1from typing import Callable, Literal, Optional, Union
  2
  3import torch
  4import numpy as np
  5
  6
  7class DefaultPseudoLabeler:
  8    """Compute pseudo labels based on model predictions, typically from a teacher model.
  9
 10    Args:
 11        activation: Activation function applied to the teacher prediction.
 12        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
 13            If None is given no mask will be computed.
 14        threshold_from_both_sides: Whether to include both values bigger than the threshold
 15            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 16            The former should be used for binary labels, the latter for for multiclass labels.
 17        mask_channel: A specific channel to use for computing the confidence mask.
 18            By default the confidence mask is computed across all channels independently.
 19            This is useful, if only one of the channels encodes a probability.
 20    """
 21    def __init__(
 22        self,
 23        activation: Optional[torch.nn.Module] = None,
 24        confidence_threshold: Optional[float] = None,
 25        threshold_from_both_sides: bool = True,
 26        mask_channel: Optional[int] = None,
 27    ):
 28        self.activation = activation
 29        self.confidence_threshold = confidence_threshold
 30        self.threshold_from_both_sides = threshold_from_both_sides
 31        self.mask_channel = mask_channel
 32        # TODO serialize the class names and kwargs for activation instead
 33        self.init_kwargs = {
 34            "activation": None, "confidence_threshold": confidence_threshold,
 35            "threshold_from_both_sides": threshold_from_both_sides,
 36            "mask_channel": mask_channel,
 37        }
 38
 39    def _compute_label_mask_both_sides(self, pseudo_labels):
 40        upper_threshold = self.confidence_threshold
 41        lower_threshold = 1.0 - self.confidence_threshold
 42        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
 43        return mask
 44
 45    def _compute_label_mask_one_side(self, pseudo_labels):
 46        mask = (pseudo_labels >= self.confidence_threshold)
 47        return mask
 48
 49    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
 50        """Compute pseudo-labels.
 51
 52        Args:
 53            teacher: The teacher model.
 54            input_: The input for this batch.
 55
 56        Returns:
 57            The pseudo-labels.
 58        """
 59        pseudo_labels = teacher(input_)
 60        if self.activation is not None:
 61            pseudo_labels = self.activation(pseudo_labels)
 62        if self.confidence_threshold is None:
 63            label_mask = None
 64        else:
 65            mask_input = pseudo_labels if self.mask_channel is None\
 66                else pseudo_labels[self.mask_channel:(self.mask_channel+1)]
 67            label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides\
 68                else self._compute_label_mask_one_side(mask_input)
 69            if self.mask_channel is not None:
 70                size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
 71                label_mask = label_mask.expand(*size)
 72        return pseudo_labels, label_mask
 73
 74    def step(self, metric, epoch):
 75        pass
 76
 77
 78class ProbabilisticPseudoLabeler:
 79    """Compute pseudo labels from a Probabilistic UNet.
 80
 81    Args:
 82        activation: Activation function applied to the teacher prediction.
 83        confidence_threshold: Threshold for computing a mask for filterign the pseudo labels.
 84            If none is given no mask will be computed.
 85        threshold_from_both_sides: Whether to include both values bigger than the threshold
 86            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 87            The former should be used for binary labels, the latter for for multiclass labels.
 88        prior_samples: The number of times to sample from the model distribution per input.
 89        consensus_masking: Whether to activate consensus masking in the label filter.
 90            If False, the weighted consensus response (weighted per-pixel response) is returned.
 91            If True, the masked consensus response (complete aggrement of pixels) is returned.
 92    """
 93    def __init__(
 94        self,
 95        activation: Optional[torch.nn.Module] = None,
 96        confidence_threshold: Optional[float] = None,
 97        threshold_from_both_sides: bool = True,
 98        prior_samples: int = 16,
 99        consensus_masking: bool = False,
100    ):
101        self.activation = activation
102        self.confidence_threshold = confidence_threshold
103        self.threshold_from_both_sides = threshold_from_both_sides
104        self.prior_samples = prior_samples
105        self.consensus_masking = consensus_masking
106        # TODO serialize the class names and kwargs for activation instead
107        self.init_kwargs = {
108            "activation": None, "confidence_threshold": confidence_threshold,
109            "threshold_from_both_sides": threshold_from_both_sides
110        }
111
112    def _compute_label_mask_both_sides(self, pseudo_labels):
113        upper_threshold = self.confidence_threshold
114        lower_threshold = 1.0 - self.confidence_threshold
115        mask = [
116            torch.where((sample >= upper_threshold) + (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.))
117            for sample in pseudo_labels
118        ]
119        return mask
120
121    def _compute_label_mask_one_side(self, pseudo_labels):
122        mask = [
123            torch.where((sample >= self.confidence_threshold), torch.tensor(1.), torch.tensor(0.))
124            for sample in pseudo_labels
125        ]
126        return mask
127
128    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
129        """Compute pseudo-labels.
130
131        Args:
132            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
133            input_: The input for this batch.
134
135        Returns:
136            The pseudo-labels.
137        """
138        teacher.forward(input_)
139        if self.activation is not None:
140            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
141        else:
142            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
143        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
144
145        if self.confidence_threshold is None:
146            label_mask = None
147        else:
148            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
149                else self._compute_label_mask_one_side(pseudo_labels)
150            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
151            if self.consensus_masking:
152                label_mask = torch.where(label_mask == 1, 1, 0)
153
154        return pseudo_labels, label_mask
155
156    def step(self, metric, epoch):
157        pass
158
159
class ScheduledPseudoLabeler:
    """
    Implement scheduled pseudo-labeling with dynamic confidence-threshold updates.

    Pseudo labels are generated from a teacher model prediction and can be filtered
    by a confidence mask. The confidence threshold can be adapted over time either
    by decreasing it based on a monitored metric (plateau behavior) or by increasing
    it with a fixed epoch schedule.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given, no mask will be computed.
        increase: If True, increase the confidence threshold over time according to
            a fixed schedule. If False, decrease it based on plateau detection.
        last_step_epoch: Last epoch at which threshold increase is allowed when
            `increase=True`.
        threshold_from_both_sides: Whether to include values larger than the
            threshold and smaller than 1 - the threshold in the mask, or only values
            larger than the threshold. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Update size for confidence-threshold scheduling. Interpreted as a
            multiplicative factor for `threshold_mode='rel'` and as an additive step
            for `threshold_mode='abs'`.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        max_ct: Maximum allowed confidence threshold. The threshold will not be increased above this value.
        eps: A small value to avoid floating-point precision errors during threshold comparison.
        warm_up_epochs: Number of warm-up epochs. At the end of warm-up,
            `confidence_threshold` is set to `max_ct`. This is intended for
            decreasing-threshold scheduling (`increase=False`).
        mask_channel: Specific channel to use for confidence masking. Currently,
            only None is supported.
        verbose: If True, prints messages when the confidence threshold is updated.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        increase: bool = False,
        last_step_epoch: int = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        max_ct: float = 0.95,
        eps: float = 1e-8,
        warm_up_epochs: int = 0,
        mask_channel: Optional[int] = None,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.increase = increase
        self.threshold_from_both_sides = threshold_from_both_sides
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }
        # scheduler arguments
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        self.mode = mode

        assert factor < 1, f"Factor must be smaller than 1, got {factor}"
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in {"rel", "abs"}:
            # NOTE: previously this message interpolated `mode` instead of the invalid value.
            raise ValueError(f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'.")
        self.threshold_mode = threshold_mode

        self.min_ct = min_ct
        self.max_ct = max_ct
        self.eps = eps
        self.warm_up_epochs = warm_up_epochs

        if self.increase and self.warm_up_epochs > 0:
            raise ValueError("warm_up_epochs > 0 is only supported when increase=False.")

        # TODO implement mask_channel functionality; for now only None is supported
        if mask_channel is not None:
            raise NotImplementedError("mask_channel is not implemented yet; only None is supported.")
        self.mask_channel = mask_channel
        self.verbose = verbose

        # Initialize the best metric so that the first reported value always improves on it.
        if mode == "min":
            self.best = float("inf")
        else:  # mode == "max":
            self.best = float("-inf")

        self.num_bad_epochs: int = 0
        self.last_epoch = 0

        if self.increase:
            if last_step_epoch is None:
                # Fail early with a clear message instead of a TypeError in the schedule computation.
                raise ValueError("last_step_epoch must be given when increase=True.")
            self.last_step_epoch = last_step_epoch

            # Number of thresholds on the grid [min_ct, max_ct] with spacing `factor`;
            # factor / 2 is added to the stop so max_ct is included despite float rounding.
            n_ct = len(np.arange(self.min_ct, self.max_ct + self.factor / 2, self.factor))
            n_increments = n_ct - 1

            if n_increments <= 0:
                # nothing to increase
                self.step_epoch = 0
            else:
                # compute initial step size; enforce minimum step size of 1
                self.step_epoch = max(1, int(np.floor(self.last_step_epoch / n_increments)))

                # ensure max_ct is reachable
                required_epochs = self.step_epoch * n_increments
                if self.last_step_epoch < required_epochs:
                    self.last_step_epoch = required_epochs

            print(
                f"ScheduledPseudoLabeler: Increasing confidence_threshold every {self.step_epoch} epochs;",
                f"until epoch {self.last_step_epoch}"
            )

    def _compute_label_mask_both_sides(self, pseudo_labels):
        # Confident pixels are those close to 1 (>= threshold) or close to 0 (<= 1 - threshold).
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        # Only values above the threshold count as confident (multiclass setting).
        mask = (pseudo_labels >= self.confidence_threshold)
        return mask

    def __call__(self, teacher, input_):
        """Compute pseudo-labels and a confidence mask for this batch.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels and the confidence mask (None if no confidence_threshold is set).
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        else:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
                else self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def end_warm_up(self):
        """Finish the warm-up phase by jumping the confidence threshold to `max_ct`."""
        self.confidence_threshold = self.max_ct
        print(f"End of warm-up phase reached: setting confidence threshold to {self.confidence_threshold}")

    def _is_better(self, a, best):
        # Significance test for metric improvement, mirroring torch's ReduceLROnPlateau.
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        # Lower the confidence threshold by `factor` (relative or absolute), clamped at min_ct.
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = max(self.confidence_threshold * (1 - self.factor), self.min_ct)
        else:  # threshold_mode == 'abs':
            new_ct = max(self.confidence_threshold - self.factor, self.min_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
        if self.verbose:
            print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def decrease_step(self, metric, epoch=None):
        """Plateau-style scheduling step: reduce the threshold if the metric stops improving."""
        if epoch is None:
            epoch = self.last_epoch + 1
            self.last_epoch = epoch

        # If no metric is given, reduce the confidence threshold every `patience` epochs.
        if metric is None:
            if epoch == 0:
                return
            if epoch % self.patience == 0:
                self._reduce_ct(epoch)
            return

        else:
            current = float(metric)

            if self._is_better(current, self.best):
                self.best = current
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self._reduce_ct(epoch)
                self.num_bad_epochs = 0

    def _increase_ct(self, epoch):
        # Raise the confidence threshold by `factor` (relative or absolute), clamped at max_ct.
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = min(self.confidence_threshold * (1 + self.factor), self.max_ct)
        else:  # threshold_mode == 'abs':
            new_ct = min(self.confidence_threshold + self.factor, self.max_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
        if self.verbose:
            print(f"Epoch {epoch}: increase confidence threshold from {old_ct} to {self.confidence_threshold}")

    def increase_step(self, epoch):
        """Fixed-schedule step: increase the threshold every `step_epoch` epochs until `last_step_epoch`."""
        if epoch > self.last_step_epoch:
            return

        # step_epoch == 0 means there is nothing to increase (min_ct == max_ct);
        # guard against modulo-by-zero.
        if self.step_epoch == 0:
            return

        if epoch % self.step_epoch == 0 and epoch != 0:
            self._increase_ct(epoch)

    def step(self, metric=None, epoch=None):
        """Advance the schedule by one epoch.

        Args:
            metric: The monitored metric (only used for decreasing schedules).
            epoch: The current epoch. If None, an internal epoch counter is used.
        """
        if epoch is None:
            # Previously this raised a TypeError when comparing None with warm_up_epochs.
            epoch = self.last_epoch + 1
            self.last_epoch = epoch

        if epoch == self.warm_up_epochs and self.warm_up_epochs > 0:
            self.end_warm_up()
        elif epoch > self.warm_up_epochs:
            if self.increase:
                self.increase_step(epoch)
            else:
                self.decrease_step(metric, epoch)
        else:
            return
class DefaultPseudoLabeler:
 8class DefaultPseudoLabeler:
 9    """Compute pseudo labels based on model predictions, typically from a teacher model.
10
11    Args:
12        activation: Activation function applied to the teacher prediction.
13        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
14            If None is given no mask will be computed.
15        threshold_from_both_sides: Whether to include both values bigger than the threshold
16            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
17            The former should be used for binary labels, the latter for for multiclass labels.
18        mask_channel: A specific channel to use for computing the confidence mask.
19            By default the confidence mask is computed across all channels independently.
20            This is useful, if only one of the channels encodes a probability.
21    """
22    def __init__(
23        self,
24        activation: Optional[torch.nn.Module] = None,
25        confidence_threshold: Optional[float] = None,
26        threshold_from_both_sides: bool = True,
27        mask_channel: Optional[int] = None,
28    ):
29        self.activation = activation
30        self.confidence_threshold = confidence_threshold
31        self.threshold_from_both_sides = threshold_from_both_sides
32        self.mask_channel = mask_channel
33        # TODO serialize the class names and kwargs for activation instead
34        self.init_kwargs = {
35            "activation": None, "confidence_threshold": confidence_threshold,
36            "threshold_from_both_sides": threshold_from_both_sides,
37            "mask_channel": mask_channel,
38        }
39
40    def _compute_label_mask_both_sides(self, pseudo_labels):
41        upper_threshold = self.confidence_threshold
42        lower_threshold = 1.0 - self.confidence_threshold
43        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
44        return mask
45
46    def _compute_label_mask_one_side(self, pseudo_labels):
47        mask = (pseudo_labels >= self.confidence_threshold)
48        return mask
49
50    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
51        """Compute pseudo-labels.
52
53        Args:
54            teacher: The teacher model.
55            input_: The input for this batch.
56
57        Returns:
58            The pseudo-labels.
59        """
60        pseudo_labels = teacher(input_)
61        if self.activation is not None:
62            pseudo_labels = self.activation(pseudo_labels)
63        if self.confidence_threshold is None:
64            label_mask = None
65        else:
66            mask_input = pseudo_labels if self.mask_channel is None\
67                else pseudo_labels[self.mask_channel:(self.mask_channel+1)]
68            label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides\
69                else self._compute_label_mask_one_side(mask_input)
70            if self.mask_channel is not None:
71                size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
72                label_mask = label_mask.expand(*size)
73        return pseudo_labels, label_mask
74
75    def step(self, metric, epoch):
76        pass

Compute pseudo labels based on model predictions, typically from a teacher model.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
  • mask_channel: A specific channel to use for computing the confidence mask. By default the confidence mask is computed across all channels independently. This is useful, if only one of the channels encodes a probability.
DefaultPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, mask_channel: Optional[int] = None)
22    def __init__(
23        self,
24        activation: Optional[torch.nn.Module] = None,
25        confidence_threshold: Optional[float] = None,
26        threshold_from_both_sides: bool = True,
27        mask_channel: Optional[int] = None,
28    ):
29        self.activation = activation
30        self.confidence_threshold = confidence_threshold
31        self.threshold_from_both_sides = threshold_from_both_sides
32        self.mask_channel = mask_channel
33        # TODO serialize the class names and kwargs for activation instead
34        self.init_kwargs = {
35            "activation": None, "confidence_threshold": confidence_threshold,
36            "threshold_from_both_sides": threshold_from_both_sides,
37            "mask_channel": mask_channel,
38        }
activation
confidence_threshold
threshold_from_both_sides
mask_channel
init_kwargs
def step(self, metric, epoch):
75    def step(self, metric, epoch):
76        pass
class ProbabilisticPseudoLabeler:
 79class ProbabilisticPseudoLabeler:
 80    """Compute pseudo labels from a Probabilistic UNet.
 81
 82    Args:
 83        activation: Activation function applied to the teacher prediction.
 84        confidence_threshold: Threshold for computing a mask for filterign the pseudo labels.
 85            If none is given no mask will be computed.
 86        threshold_from_both_sides: Whether to include both values bigger than the threshold
 87            and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
 88            The former should be used for binary labels, the latter for for multiclass labels.
 89        prior_samples: The number of times to sample from the model distribution per input.
 90        consensus_masking: Whether to activate consensus masking in the label filter.
 91            If False, the weighted consensus response (weighted per-pixel response) is returned.
 92            If True, the masked consensus response (complete aggrement of pixels) is returned.
 93    """
 94    def __init__(
 95        self,
 96        activation: Optional[torch.nn.Module] = None,
 97        confidence_threshold: Optional[float] = None,
 98        threshold_from_both_sides: bool = True,
 99        prior_samples: int = 16,
100        consensus_masking: bool = False,
101    ):
102        self.activation = activation
103        self.confidence_threshold = confidence_threshold
104        self.threshold_from_both_sides = threshold_from_both_sides
105        self.prior_samples = prior_samples
106        self.consensus_masking = consensus_masking
107        # TODO serialize the class names and kwargs for activation instead
108        self.init_kwargs = {
109            "activation": None, "confidence_threshold": confidence_threshold,
110            "threshold_from_both_sides": threshold_from_both_sides
111        }
112
113    def _compute_label_mask_both_sides(self, pseudo_labels):
114        upper_threshold = self.confidence_threshold
115        lower_threshold = 1.0 - self.confidence_threshold
116        mask = [
117            torch.where((sample >= upper_threshold) + (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.))
118            for sample in pseudo_labels
119        ]
120        return mask
121
122    def _compute_label_mask_one_side(self, pseudo_labels):
123        mask = [
124            torch.where((sample >= self.confidence_threshold), torch.tensor(1.), torch.tensor(0.))
125            for sample in pseudo_labels
126        ]
127        return mask
128
129    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
130        """Compute pseudo-labels.
131
132        Args:
133            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
134            input_: The input for this batch.
135
136        Returns:
137            The pseudo-labels.
138        """
139        teacher.forward(input_)
140        if self.activation is not None:
141            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
142        else:
143            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
144        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
145
146        if self.confidence_threshold is None:
147            label_mask = None
148        else:
149            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
150                else self._compute_label_mask_one_side(pseudo_labels)
151            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
152            if self.consensus_masking:
153                label_mask = torch.where(label_mask == 1, 1, 0)
154
155        return pseudo_labels, label_mask
156
157    def step(self, metric, epoch):
158        pass

Compute pseudo labels from a Probabilistic UNet.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
  • threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
  • prior_samples: The number of times to sample from the model distribution per input.
  • consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
 94    def __init__(
 95        self,
 96        activation: Optional[torch.nn.Module] = None,
 97        confidence_threshold: Optional[float] = None,
 98        threshold_from_both_sides: bool = True,
 99        prior_samples: int = 16,
100        consensus_masking: bool = False,
101    ):
102        self.activation = activation
103        self.confidence_threshold = confidence_threshold
104        self.threshold_from_both_sides = threshold_from_both_sides
105        self.prior_samples = prior_samples
106        self.consensus_masking = consensus_masking
107        # TODO serialize the class names and kwargs for activation instead
108        self.init_kwargs = {
109            "activation": None, "confidence_threshold": confidence_threshold,
110            "threshold_from_both_sides": threshold_from_both_sides
111        }
activation
confidence_threshold
threshold_from_both_sides
prior_samples
consensus_masking
init_kwargs
def step(self, metric, epoch):
157    def step(self, metric, epoch):
158        pass
class ScheduledPseudoLabeler:
    """Compute pseudo labels with a dynamically scheduled confidence threshold.

    Pseudo labels are generated from a teacher model prediction and can be filtered
    by a confidence mask. The confidence threshold can be adapted over time either
    by decreasing it based on a monitored metric (plateau behavior) or by increasing
    it with a fixed epoch schedule.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given, no mask will be computed.
        increase: If True, increase the confidence threshold over time according to
            a fixed schedule. If False, decrease it based on plateau detection.
        last_step_epoch: Last epoch at which threshold increase is allowed when
            `increase=True`. Must not be None in that case.
        threshold_from_both_sides: Whether to include values larger than the
            threshold and smaller than 1 - the threshold in the mask, or only values
            larger than the threshold. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Update size for confidence-threshold scheduling. Interpreted as a
            multiplicative factor for `threshold_mode='rel'` and as an additive step
            for `threshold_mode='abs'`.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        max_ct: Maximum allowed confidence threshold. The threshold will not be increased above this value.
        eps: A small value to avoid floating-point precision errors during threshold comparison.
        warm_up_epochs: Number of warm-up epochs. At the end of warm-up,
            `confidence_threshold` is set to `max_ct`. This is intended for
            decreasing-threshold scheduling (`increase=False`).
        mask_channel: Specific channel to use for confidence masking. Currently,
            only None is supported.
        verbose: If True, prints messages when the confidence threshold is updated.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        increase: bool = False,
        last_step_epoch: Optional[int] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        max_ct: float = 0.95,
        eps: float = 1e-8,
        warm_up_epochs: int = 0,
        mask_channel: Optional[int] = None,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.increase = increase
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }

        # scheduler arguments
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        self.mode = mode

        assert factor < 1, f"Factor must be smaller than 1, got {factor}"
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in {"rel", "abs"}:
            # BUGFIX: this message previously interpolated `mode` instead of `threshold_mode`,
            # reporting the wrong offending value.
            raise ValueError(f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'.")
        self.threshold_mode = threshold_mode

        self.min_ct = min_ct
        self.max_ct = max_ct
        self.eps = eps
        self.warm_up_epochs = warm_up_epochs

        if self.increase and self.warm_up_epochs > 0:
            raise ValueError("warm_up_epochs > 0 is only supported when increase=False.")

        # TODO implement mask_channel functionality; for now only None is supported
        if mask_channel is not None:
            raise NotImplementedError("mask_channel is not implemented yet; only None is supported.")
        self.mask_channel = mask_channel
        self.verbose = verbose

        # Initialize the plateau tracker so that the first reported metric
        # always counts as an improvement.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

        if self.increase:
            # BUGFIX: fail fast with a clear message; previously a missing
            # last_step_epoch crashed with a TypeError in the division below.
            if last_step_epoch is None:
                raise ValueError("last_step_epoch must be given when increase=True.")
            self.last_step_epoch = last_step_epoch

            # Number of threshold values on the grid [min_ct, max_ct] with spacing
            # `factor`; + factor / 2 guards against float rounding at the endpoint.
            n_ct = len(np.arange(self.min_ct, self.max_ct + self.factor / 2, self.factor))
            n_increments = n_ct - 1

            if n_increments <= 0:
                # nothing to increase
                self.step_epoch = 0
            else:
                # compute initial step size; enforce minimum step size of 1
                self.step_epoch = max(1, int(np.floor(self.last_step_epoch / n_increments)))

                # ensure max_ct is reachable
                required_epochs = self.step_epoch * n_increments
                if self.last_step_epoch < required_epochs:
                    self.last_step_epoch = required_epochs

            print(
                f"ScheduledPseudoLabeler: Increasing confidence_threshold every {self.step_epoch} epochs;",
                f"until epoch {self.last_step_epoch}"
            )

    def _compute_label_mask_both_sides(self, pseudo_labels):
        # Keep predictions that are confidently positive (>= ct) OR confidently
        # negative (<= 1 - ct); elementwise OR via boolean addition.
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        # One-sided variant for multiclass labels: keep only confident positives.
        mask = (pseudo_labels >= self.confidence_threshold)
        return mask

    def __call__(self, teacher, input_):
        """Compute pseudo labels and an optional confidence mask for the input.

        Args:
            teacher: The model used to predict the pseudo labels.
            input_: The input for the prediction.

        Returns:
            A tuple of (pseudo_labels, label_mask); the mask is None if
            `confidence_threshold` is None.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        else:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\
                else self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def end_warm_up(self):
        """Set the confidence threshold to its maximum at the end of warm-up."""
        self.confidence_threshold = self.max_ct
        print(f"End of warm-up phase reached: setting confidence threshold to {self.confidence_threshold}")

    def _is_better(self, a, best):
        # Improvement test analogous to torch.optim.lr_scheduler.ReduceLROnPlateau.
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon
        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon
        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        # Lower the threshold by `factor` (relative or absolute), clamped at min_ct.
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = max(self.confidence_threshold * (1 - self.factor), self.min_ct)
        else:  # threshold_mode == 'abs':
            new_ct = max(self.confidence_threshold - self.factor, self.min_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
        if self.verbose:
            print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def decrease_step(self, metric, epoch=None):
        """Plateau-style update: reduce the threshold when the metric stops improving.

        Args:
            metric: The monitored metric; if None, the threshold is reduced on a
                fixed schedule (every `patience` epochs) instead.
            epoch: Current epoch; if None, an internal epoch counter is advanced.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.last_epoch = epoch

        # If the metric is None, reduce the confidence threshold every `patience` epochs
        if metric is None:
            if epoch == 0:
                return
            if epoch % self.patience == 0:
                self._reduce_ct(epoch)
            return

        else:
            current = float(metric)

            if self._is_better(current, self.best):
                self.best = current
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self._reduce_ct(epoch)
                self.num_bad_epochs = 0

    def _increase_ct(self, epoch):
        # Raise the threshold by `factor` (relative or absolute), clamped at max_ct.
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = min(self.confidence_threshold * (1 + self.factor), self.max_ct)
        else:  # threshold_mode == 'abs':
            new_ct = min(self.confidence_threshold + self.factor, self.max_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
        if self.verbose:
            print(f"Epoch {epoch}: increase confidence threshold from {old_ct} to {self.confidence_threshold}")

    def increase_step(self, epoch):
        """Increase the confidence threshold on the fixed epoch schedule.

        Args:
            epoch: Current epoch; no update happens after `last_step_epoch`.
        """
        if epoch > self.last_step_epoch:
            return

        # BUGFIX: guard step_epoch == 0 (nothing to increase), which previously
        # raised ZeroDivisionError in the modulo below.
        if self.step_epoch and epoch % self.step_epoch == 0 and epoch != 0:
            self._increase_ct(epoch)

    def step(self, metric=None, epoch=None):
        """Dispatch one scheduling update (warm-up, increase, or plateau decrease).

        Args:
            metric: Monitored metric, forwarded to `decrease_step` when decreasing.
            epoch: Current epoch; if None, an internal epoch counter is advanced.
        """
        # BUGFIX: resolve epoch=None here; previously the default value crashed
        # with a TypeError on the `epoch > self.warm_up_epochs` comparison.
        if epoch is None:
            epoch = self.last_epoch + 1
            self.last_epoch = epoch
        if epoch == self.warm_up_epochs and self.warm_up_epochs > 0:
            self.end_warm_up()
        elif epoch > self.warm_up_epochs:
            if self.increase:
                self.increase_step(epoch)
            else:
                self.decrease_step(metric, epoch)
        else:
            return

Implement scheduled pseudo-labeling with dynamic confidence-threshold updates.

Pseudo labels are generated from a teacher model prediction and can be filtered by a confidence mask. The confidence threshold can be adapted over time either by decreasing it based on a monitored metric (plateau behavior) or by increasing it with a fixed epoch schedule.

Arguments:
  • activation: Activation function applied to the teacher prediction.
  • confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed.
  • increase: If True, increase the confidence threshold over time according to a fixed schedule. If False, decrease it based on plateau detection.
  • last_step_epoch: Last epoch at which threshold increase is allowed when increase=True.
  • threshold_from_both_sides: Whether to include values larger than the threshold and smaller than 1 - the threshold in the mask, or only values larger than the threshold. The former should be used for binary labels, the latter for multiclass labels.
  • mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
    • 'min': A lower value of the monitored metric is considered better (e.g., loss).
    • 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
  • factor: Update size for confidence-threshold scheduling. Interpreted as a multiplicative factor for threshold_mode='rel' and as an additive step for threshold_mode='abs'.
  • patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
  • threshold: Threshold value for determining a significant improvement in the performance metric to reset the patience counter. This can be relative (percentage improvement) or absolute depending on threshold_mode.
  • threshold_mode: Determines whether the threshold is interpreted as a relative improvement ('rel') or an absolute improvement ('abs').
  • min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
  • max_ct: Maximum allowed confidence threshold. The threshold will not be increased above this value.
  • eps: A small value to avoid floating-point precision errors during threshold comparison.
  • warm_up_epochs: Number of warm-up epochs. At the end of warm-up, confidence_threshold is set to max_ct. This is intended for decreasing-threshold scheduling (increase=False).
  • mask_channel: Specific channel to use for confidence masking. Currently, only None is supported.
  • verbose: If True, prints messages when the confidence threshold is updated.
ScheduledPseudoLabeler( activation: Union[torch.nn.modules.module.Module, Callable, NoneType] = None, confidence_threshold: Optional[float] = None, increase: bool = False, last_step_epoch: int = None, threshold_from_both_sides: bool = True, mode: Literal['min', 'max'] = 'min', factor: float = 0.05, patience: int = 10, threshold: float = 0.0001, threshold_mode: Literal['rel', 'abs'] = 'abs', min_ct: float = 0.5, max_ct: float = 0.95, eps: float = 1e-08, warm_up_epochs: int = 0, mask_channel: Optional[int] = None, verbose: bool = True)
205    def __init__(
206        self,
207        activation: Optional[Union[torch.nn.Module, Callable]] = None,
208        confidence_threshold: Optional[float] = None,
209        increase: bool = False,
210        last_step_epoch: int = None,
211        threshold_from_both_sides: bool = True,
212        mode: Literal["min", "max"] = "min",
213        factor: float = 0.05,
214        patience: int = 10,
215        threshold: float = 1e-4,
216        threshold_mode: Literal["rel", "abs"] = "abs",
217        min_ct: float = 0.5,
218        max_ct: float = 0.95,
219        eps: float = 1e-8,
220        warm_up_epochs: int = 0,
221        mask_channel: Optional[int] = None,
222        verbose: bool = True,
223    ):
224        self.activation = activation
225        self.confidence_threshold = confidence_threshold
226        self.increase = increase
227        self.threshold_from_both_sides = threshold_from_both_sides
228        self.init_kwargs = {
229            "activation": None, "confidence_threshold": confidence_threshold,
230            "threshold_from_both_sides": threshold_from_both_sides,
231            "mask_channel": mask_channel,
232        }
233        # scheduler arguments
234        if mode not in {"min", "max"}:
235            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
236        self.mode = mode
237
238        assert factor < 1, f"Factor must be smaller than 1, got {factor}"
239        self.factor = factor
240
241        self.patience = patience
242        self.threshold = threshold
243
244        if threshold_mode not in {"rel", "abs"}:
245            raise ValueError(f"Invalid threshold mode: {mode}. Threshold mode should be 'rel' or 'abs'.")
246        self.threshold_mode = threshold_mode
247
248        self.min_ct = min_ct
249        self.max_ct = max_ct
250        self.eps = eps
251        self.warm_up_epochs = warm_up_epochs
252
253        if self.increase and self.warm_up_epochs > 0:
254            raise ValueError("warm_up_epochs > 0 is only supported when increase=False.")
255
256        # TODO implement mask_channel functionality; for now only None is supported
257        if mask_channel is not None:
258            raise NotImplementedError("mask_channel is not implemented yet; only None is supported.")
259        self.mask_channel = mask_channel
260        self.verbose = verbose
261
262        if mode == "min":
263            self.best = float("inf")
264        else:  # mode == "max":
265            self.best = float("-inf")
266
267        # self.best = 0
268        self.num_bad_epochs: int = 0
269        self.last_epoch = 0
270
271        if self.increase:
272            self.last_step_epoch = last_step_epoch
273
274            n_ct = len(np.arange(self.min_ct, self.max_ct + self.factor / 2, self.factor))
275            n_increments = n_ct - 1
276
277            if n_increments <= 0:
278                # nothing to increase
279                self.step_epoch = 0
280            else:
281                # compute initial step size; enforce minimum step size of 1
282                self.step_epoch = max(1, int(np.floor(self.last_step_epoch / n_increments)))
283
284                # ensure max_ct is reachable
285                required_epochs = self.step_epoch * n_increments
286                if self.last_step_epoch < required_epochs:
287                    self.last_step_epoch = required_epochs
288
289            print(
290                f"ScheduledPseudoLabeler: Increasing confidence_threshold every {self.step_epoch} epochs;",
291                f"until epoch {self.last_step_epoch}"
292            )
activation
confidence_threshold
increase
threshold_from_both_sides
init_kwargs
mode
factor
patience
threshold
threshold_mode
min_ct
max_ct
eps
warm_up_epochs
mask_channel
verbose
num_bad_epochs: int
last_epoch
def end_warm_up(self):
    """Finish the warm-up phase by raising the confidence threshold to its allowed maximum."""
    self.confidence_threshold = self.max_ct
    print(f"End of warm-up phase reached: setting confidence threshold to {self.confidence_threshold}")
def decrease_step(self, metric, epoch=None):
    """Plateau-style update: lower the confidence threshold when the metric stops improving.

    With `metric=None` the threshold is instead reduced on a fixed schedule,
    every `patience` epochs. With `epoch=None` an internal epoch counter is advanced.
    """
    if epoch is None:
        self.last_epoch = epoch = self.last_epoch + 1

    # Without a metric, fall back to the fixed reduction schedule.
    if metric is None:
        if epoch != 0 and epoch % self.patience == 0:
            self._reduce_ct(epoch)
        return

    current = float(metric)
    if self._is_better(current, self.best):
        self.best = current
        self.num_bad_epochs = 0
    else:
        self.num_bad_epochs += 1

    # Too many epochs without improvement: reduce and restart the patience window.
    if self.num_bad_epochs > self.patience:
        self._reduce_ct(epoch)
        self.num_bad_epochs = 0
def increase_step(self, epoch):
    """Increase the confidence threshold on the fixed epoch schedule.

    Args:
        epoch: Current epoch; no update happens after `last_step_epoch`.
    """
    if epoch > self.last_step_epoch:
        return

    # BUGFIX: guard step_epoch == 0 (the "nothing to increase" case set up by
    # __init__), which previously raised ZeroDivisionError in the modulo below.
    if self.step_epoch and epoch % self.step_epoch == 0 and epoch != 0:
        self._increase_ct(epoch)
def step(self, metric=None, epoch=None):
    """Dispatch one scheduling update (warm-up, increase, or plateau decrease).

    Args:
        metric: Monitored metric, forwarded to `decrease_step` when decreasing.
        epoch: Current epoch; if None, an internal epoch counter is advanced.
    """
    # BUGFIX: resolve epoch=None here; previously the default value crashed
    # with a TypeError on the `epoch > self.warm_up_epochs` comparison.
    if epoch is None:
        epoch = self.last_epoch + 1
        self.last_epoch = epoch
    if epoch == self.warm_up_epochs and self.warm_up_epochs > 0:
        self.end_warm_up()
    elif epoch > self.warm_up_epochs:
        if self.increase:
            self.increase_step(epoch)
        else:
            self.decrease_step(metric, epoch)
    else:
        return