torch_em.self_training.pseudo_labeling
from typing import Callable, Literal, Optional, Union

import torch


class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold or <= 1 - the threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold."""
        # Cast to float32 so that both mask variants have the same dtype.
        return (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels and the label mask. The mask is None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule. Mirrors `ScheduledPseudoLabeler.step`."""
        pass


class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response (weighted per-pixel response) is returned.
            If True, the masked consensus response (complete agreement of pixels) is returned.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        # The sampling settings are included so that a labeler rebuilt from
        # `init_kwargs` does not silently fall back to the defaults.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples, "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= the threshold or <= 1 - the threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [
            ((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
            for sample in pseudo_labels
        ]

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= the threshold."""
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            The pseudo-labels (mean over all samples) and the label mask (fraction of samples
            that pass the confidence threshold, or full agreement if `consensus_masking` is set).
            The mask is None if no confidence threshold is set.
        """
        # The return value is discarded; presumably this prepares the distribution
        # that `teacher.sample()` draws from (to be confirmed against the model).
        teacher.forward(input_)
        if self.activation is not None:
            samples = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
        else:
            samples = [teacher.sample() for _ in range(self.prior_samples)]
        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples

        if self.confidence_threshold is None:
            return pseudo_labels, None

        # The masks are computed per sample (not on the averaged prediction), so that
        # the averaged mask measures the agreement between the samples for each pixel.
        if self.threshold_from_both_sides:
            sample_masks = self._compute_label_mask_both_sides(samples)
        else:
            sample_masks = self._compute_label_mask_one_side(samples)
        label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
        if self.consensus_masking:
            # Keep only the pixels for which all samples pass the threshold.
            label_mask = (label_mask == 1).to(dtype=label_mask.dtype)

        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule. Mirrors `ScheduledPseudoLabeler.step`."""
        pass


class ScheduledPseudoLabeler:
    """Compute pseudo labels with a confidence threshold that is reduced over time.

    The pseudo labels are generated from a teacher model's predictions. The confidence threshold
    for filtering the pseudo labels is reduced either when a monitored metric has not improved
    for more than `patience` epochs, or on a fixed schedule every `patience` epochs if no metric
    is passed to `step`. The mask can be thresholded from both sides (for binary labels)
    or from one side (for multiclass labels).

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and the schedule has no effect.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it,
            or only values bigger than it in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Factor by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the confidence threshold if `threshold_mode` is 'abs'
            and multiplied with it if `threshold_mode` is 'rel'. Must be smaller than 1.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: A small value to avoid floating-point precision errors during threshold comparison.
        verbose: If True, prints messages when the confidence threshold is reduced.

    Raises:
        ValueError: If `mode` or `threshold_mode` is invalid or if `factor` is not smaller than 1.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides=True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides

        # scheduler arguments
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        self.mode = mode

        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )
        self.threshold_mode = threshold_mode

        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # TODO serialize the class names and kwargs for activation instead
        # The scheduler settings are included so that a labeler rebuilt from
        # `init_kwargs` does not silently fall back to the defaults.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode, "factor": factor, "patience": patience,
            "threshold": threshold, "threshold_mode": threshold_mode,
            "min_ct": min_ct, "eps": eps, "verbose": verbose,
        }

        # Start from the worst possible value, so that the first metric always counts as an improvement.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold or <= 1 - the threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold."""
        # Cast to float32 so that both mask variants have the same dtype.
        return (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)

    def __call__(self, teacher, input_):
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels and the label mask. The mask is None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a, best):
        """Check whether the metric value `a` improves on `best` by more than `self.threshold`."""
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        """Reduce the confidence threshold by `self.factor`, but not below `self.min_ct`."""
        old_ct = self.confidence_threshold
        if old_ct is None:
            # Without a confidence threshold no mask is computed, so there is nothing to reduce.
            return
        if self.threshold_mode == "rel":
            new_ct = max(old_ct * self.factor, self.min_ct)
        else:  # threshold_mode == 'abs':
            new_ct = max(old_ct - self.factor, self.min_ct)
        # Only update if the change is larger than the floating point tolerance.
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch=None):
        """Advance the schedule by one epoch and reduce the confidence threshold if required.

        Args:
            metric: The value of the monitored metric for this epoch.
                If None, the threshold is reduced on a fixed schedule, every `patience` epochs.
            epoch: The current epoch. If None, the last epoch plus one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        # If the metric is None, reduce the confidence threshold every `patience` epochs.
        if metric is None:
            if epoch == 0:
                return
            # A patience of zero (or less) means reducing in every epoch; this also avoids a modulo by zero.
            if self.patience <= 0 or epoch % self.patience == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0
class
DefaultPseudoLabeler:
class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold or <= 1 - the threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold."""
        # Cast to float32 so that both mask variants have the same dtype.
        return (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels and the label mask. The mask is None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule. Mirrors `ScheduledPseudoLabeler.step`."""
        pass
Compute pseudo labels based on model predictions, typically from a teacher model.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
DefaultPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True)
def __init__(
    self,
    activation: Optional[torch.nn.Module] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides: bool = True,
):
    """Store the labeler settings and the keyword arguments needed to re-create it.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing the label mask. No mask is computed if None.
        threshold_from_both_sides: Whether the mask is thresholded from both sides or only from above.
    """
    # TODO serialize the class names and kwargs for activation instead
    # The activation itself is not stored in the kwargs, it is always recorded as None.
    self.init_kwargs = dict(
        activation=None,
        confidence_threshold=confidence_threshold,
        threshold_from_both_sides=threshold_from_both_sides,
    )
    self.threshold_from_both_sides = threshold_from_both_sides
    self.confidence_threshold = confidence_threshold
    self.activation = activation
class
ProbabilisticPseudoLabeler:
class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response (weighted per-pixel response) is returned.
            If True, the masked consensus response (complete agreement of pixels) is returned.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        # The sampling settings are included so that a labeler rebuilt from
        # `init_kwargs` does not silently fall back to the defaults.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples, "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= the threshold or <= 1 - the threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [
            ((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
            for sample in pseudo_labels
        ]

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= the threshold."""
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            The pseudo-labels (mean over all samples) and the label mask (fraction of samples
            that pass the confidence threshold, or full agreement if `consensus_masking` is set).
            The mask is None if no confidence threshold is set.
        """
        # The return value is discarded; presumably this prepares the distribution
        # that `teacher.sample()` draws from (to be confirmed against the model).
        teacher.forward(input_)
        if self.activation is not None:
            samples = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
        else:
            samples = [teacher.sample() for _ in range(self.prior_samples)]
        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples

        if self.confidence_threshold is None:
            return pseudo_labels, None

        # The masks are computed per sample (not on the averaged prediction), so that
        # the averaged mask measures the agreement between the samples for each pixel.
        if self.threshold_from_both_sides:
            sample_masks = self._compute_label_mask_both_sides(samples)
        else:
            sample_masks = self._compute_label_mask_one_side(samples)
        label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
        if self.consensus_masking:
            # Keep only the pixels for which all samples pass the threshold.
            label_mask = (label_mask == 1).to(dtype=label_mask.dtype)

        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule. Mirrors `ScheduledPseudoLabeler.step`."""
        pass
Compute pseudo labels from a Probabilistic UNet.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
- prior_samples: The number of times to sample from the model distribution per input.
- consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
def __init__(
    self,
    activation: Optional[torch.nn.Module] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides: bool = True,
    prior_samples: int = 16,
    consensus_masking: bool = False,
):
    """Store the labeler settings and the keyword arguments needed to re-create it.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing the label mask. No mask is computed if None.
        threshold_from_both_sides: Whether the mask is thresholded from both sides or only from above.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
    """
    self.activation = activation
    self.confidence_threshold = confidence_threshold
    self.threshold_from_both_sides = threshold_from_both_sides
    self.prior_samples = prior_samples
    self.consensus_masking = consensus_masking
    # TODO serialize the class names and kwargs for activation instead
    # The sampling settings are included so that a labeler rebuilt from
    # `init_kwargs` does not silently fall back to the defaults.
    self.init_kwargs = {
        "activation": None, "confidence_threshold": confidence_threshold,
        "threshold_from_both_sides": threshold_from_both_sides,
        "prior_samples": prior_samples, "consensus_masking": consensus_masking,
    }
class
ScheduledPseudoLabeler:
class ScheduledPseudoLabeler:
    """Compute pseudo labels with a confidence threshold that is reduced over time.

    The pseudo labels are generated from a teacher model's predictions. The confidence threshold
    for filtering the pseudo labels is reduced either when a monitored metric has not improved
    for more than `patience` epochs, or on a fixed schedule every `patience` epochs if no metric
    is passed to `step`. The mask can be thresholded from both sides (for binary labels)
    or from one side (for multiclass labels).

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and the schedule has no effect.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it,
            or only values bigger than it in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Factor by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the confidence threshold if `threshold_mode` is 'abs'
            and multiplied with it if `threshold_mode` is 'rel'. Must be smaller than 1.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: A small value to avoid floating-point precision errors during threshold comparison.
        verbose: If True, prints messages when the confidence threshold is reduced.

    Raises:
        ValueError: If `mode` or `threshold_mode` is invalid or if `factor` is not smaller than 1.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides=True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides

        # scheduler arguments
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        self.mode = mode

        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )
        self.threshold_mode = threshold_mode

        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # TODO serialize the class names and kwargs for activation instead
        # The scheduler settings are included so that a labeler rebuilt from
        # `init_kwargs` does not silently fall back to the defaults.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode, "factor": factor, "patience": patience,
            "threshold": threshold, "threshold_mode": threshold_mode,
            "min_ct": min_ct, "eps": eps, "verbose": verbose,
        }

        # Start from the worst possible value, so that the first metric always counts as an improvement.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold or <= 1 - the threshold."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= the threshold."""
        # Cast to float32 so that both mask variants have the same dtype.
        return (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)

    def __call__(self, teacher, input_):
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels and the label mask. The mask is None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a, best):
        """Check whether the metric value `a` improves on `best` by more than `self.threshold`."""
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        """Reduce the confidence threshold by `self.factor`, but not below `self.min_ct`."""
        old_ct = self.confidence_threshold
        if old_ct is None:
            # Without a confidence threshold no mask is computed, so there is nothing to reduce.
            return
        if self.threshold_mode == "rel":
            new_ct = max(old_ct * self.factor, self.min_ct)
        else:  # threshold_mode == 'abs':
            new_ct = max(old_ct - self.factor, self.min_ct)
        # Only update if the change is larger than the floating point tolerance.
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch=None):
        """Advance the schedule by one epoch and reduce the confidence threshold if required.

        Args:
            metric: The value of the monitored metric for this epoch.
                If None, the threshold is reduced on a fixed schedule, every `patience` epochs.
            epoch: The current epoch. If None, the last epoch plus one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        # If the metric is None, reduce the confidence threshold every `patience` epochs.
        if metric is None:
            if epoch == 0:
                return
            # A patience of zero (or less) means reducing in every epoch; this also avoids a modulo by zero.
            if self.patience <= 0 or epoch % self.patience == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0
This class implements a scheduled pseudo-labeling mechanism, where pseudo labels are generated from a teacher model's predictions, and the confidence threshold for filtering the pseudo labels can be adjusted over time based on a performance metric or a fixed schedule. It includes options for adjusting thresholds from both sides (for binary classification) or from one side (for multiclass problems). The threshold can be dynamically reduced to improve the quality of the pseudo labels when the model performance does not improve for a given number of epochs (patience).
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If none is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it, or only values bigger than it in the mask. The former should be used for binary labels, the latter for multiclass labels.
- mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
- 'min': A lower value of the monitored metric is considered better (e.g., loss).
- 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
- factor: Factor by which the confidence threshold is reduced when the performance stagnates.
- patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
- threshold: Threshold value for determining a significant improvement in the performance metric to reset the patience counter. This can be relative (percentage improvement) or absolute depending on `threshold_mode`.
- threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel') or an absolute improvement ('abs').
- min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
- eps: A small value to avoid floating-point precision errors during threshold comparison.
- verbose: If True, prints messages when the confidence threshold is reduced.
ScheduledPseudoLabeler( activation: Union[torch.nn.modules.module.Module, Callable, NoneType] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides=True, mode: Literal['min', 'max'] = 'min', factor: float = 0.05, patience: int = 10, threshold: float = 0.0001, threshold_mode: Literal['rel', 'abs'] = 'abs', min_ct: float = 0.5, eps: float = 1e-08, verbose: bool = True)
def __init__(
    self,
    activation: Optional[Union[torch.nn.Module, Callable]] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides=True,
    mode: Literal["min", "max"] = "min",
    factor: float = 0.05,
    patience: int = 10,
    threshold: float = 1e-4,
    threshold_mode: Literal["rel", "abs"] = "abs",
    min_ct: float = 0.5,
    eps: float = 1e-8,
    verbose: bool = True,
):
    """Validate and store the labeler and scheduler settings and reset the scheduler state.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing the label mask. No mask is computed if None.
        threshold_from_both_sides: Whether the mask is thresholded from both sides or only from above.
        mode: Whether a lower ('min') or a higher ('max') value of the monitored metric is better.
        factor: Factor by which the confidence threshold is reduced. Must be smaller than 1.
        patience: Number of epochs (with no improvement) after which the confidence threshold is reduced.
        threshold: Minimal improvement of the metric that resets the patience counter.
        threshold_mode: Whether `threshold` is a relative ('rel') or an absolute ('abs') improvement.
        min_ct: Minimum allowed confidence threshold.
        eps: Tolerance for comparing the old and the new confidence threshold.
        verbose: If True, prints messages when the confidence threshold is reduced.

    Raises:
        ValueError: If `mode` or `threshold_mode` is invalid or if `factor` is not smaller than 1.
    """
    self.activation = activation
    self.confidence_threshold = confidence_threshold
    self.threshold_from_both_sides = threshold_from_both_sides

    # scheduler arguments
    if mode not in {"min", "max"}:
        raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
    self.mode = mode

    if factor >= 1.0:
        raise ValueError("Factor should be < 1.0.")
    self.factor = factor

    self.patience = patience
    self.threshold = threshold

    if threshold_mode not in {"rel", "abs"}:
        # Report the offending threshold mode (not the metric mode) in the error.
        raise ValueError(
            f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
        )
    self.threshold_mode = threshold_mode

    self.min_ct = min_ct
    self.eps = eps
    self.verbose = verbose

    # TODO serialize the class names and kwargs for activation instead
    # The scheduler settings are included so that a labeler rebuilt from
    # `init_kwargs` does not silently fall back to the defaults.
    self.init_kwargs = {
        "activation": None, "confidence_threshold": confidence_threshold,
        "threshold_from_both_sides": threshold_from_both_sides,
        "mode": mode, "factor": factor, "patience": patience,
        "threshold": threshold, "threshold_mode": threshold_mode,
        "min_ct": min_ct, "eps": eps, "verbose": verbose,
    }

    # Start from the worst possible value, so that the first metric always counts as an improvement.
    self.best = float("inf") if mode == "min" else float("-inf")
    self.num_bad_epochs: int = 0
    self.last_epoch = 0
def
step(self, metric, epoch=None):
def step(self, metric, epoch=None):
    """Advance the schedule by one epoch and reduce the confidence threshold if required.

    Args:
        metric: The value of the monitored metric for this epoch.
            If None, the threshold is reduced on a fixed schedule, every `patience` epochs.
        epoch: The current epoch. If None, the last epoch plus one is used.
    """
    if epoch is None:
        epoch = self.last_epoch + 1
    self.last_epoch = epoch

    # If the metric is None, reduce the confidence threshold every `patience` epochs.
    if metric is None:
        if epoch == 0:
            return
        # A patience of zero (or less) means reducing in every epoch; this also avoids a modulo by zero.
        if self.patience <= 0 or epoch % self.patience == 0:
            self._reduce_ct(epoch)
        return

    current = float(metric)
    if self._is_better(current, self.best):
        self.best = current
        self.num_bad_epochs = 0
    else:
        self.num_bad_epochs += 1

    # Reduce only once the number of epochs without improvement exceeds the patience.
    if self.num_bad_epochs > self.patience:
        self._reduce_ct(epoch)
        self.num_bad_epochs = 0