torch_em.self_training.pseudo_labeling
from typing import Callable, Literal, Optional, Union

import torch


class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        mask_channel: A specific channel to use for computing the confidence mask.
            By default the confidence mask is computed across all channels independently.
            This is useful, if only one of the channels encodes a probability.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mask_channel: Optional[int] = None,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.mask_channel = mask_channel
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= t or <= 1 - t, for threshold t."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a boolean mask that is True where a value is >= the confidence threshold."""
        return pseudo_labels >= self.confidence_threshold

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The confidence mask, or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            return pseudo_labels, None

        if self.mask_channel is None:
            mask_input = pseudo_labels
        else:
            # Select the channel on axis 1 and keep it as a singleton axis. Slicing the
            # first axis instead would pick a batch element and broadcast its mask to the batch.
            mask_input = pseudo_labels[:, self.mask_channel].unsqueeze(1)

        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(mask_input)
        else:
            label_mask = self._compute_label_mask_one_side(mask_input)

        if self.mask_channel is not None:
            # Broadcast the single-channel mask to all channels of the pseudo labels.
            size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
            label_mask = label_mask.expand(*size)
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule."""
        pass


class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response (weighted per-pixel response) is returned.
            If True, the masked consensus response (complete agreement of pixels) is returned.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples,
            "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= t or <= 1 - t."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [
            ((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
            for sample in pseudo_labels
        ]

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= the confidence threshold."""
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            The pseudo-labels, averaged over all drawn samples.
            The confidence mask, or None if no confidence threshold is set.
        """
        teacher.forward(input_)
        if self.activation is None:
            samples = [teacher.sample() for _ in range(self.prior_samples)]
        else:
            samples = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]

        if self.confidence_threshold is None:
            label_mask = None
        else:
            # The mask has to be computed per drawn sample, i.e. before the samples are averaged.
            if self.threshold_from_both_sides:
                sample_masks = self._compute_label_mask_both_sides(samples)
            else:
                sample_masks = self._compute_label_mask_one_side(samples)
            # Fraction of samples that are confident at each position.
            label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
            if self.consensus_masking:
                # Keep only positions where all samples are confident.
                label_mask = torch.where(label_mask == 1, 1, 0)

        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule."""
        pass


class ScheduledPseudoLabeler:
    """
    This class implements a scheduled pseudo-labeling mechanism, where pseudo labels
    are generated from a teacher model's predictions, and the confidence threshold
    for filtering the pseudo labels can be adjusted over time based on a performance
    metric or a fixed schedule. It includes options for adjusting thresholds from
    both sides (for binary classification) or from one side (for multiclass problems).
    The threshold can be dynamically reduced when the model performance does not
    improve for a given number of epochs (patience).

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and the schedule has no effect.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it,
            or only values bigger than it in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Amount by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the threshold if `threshold_mode` is 'abs' and multiplied with it if it is 'rel'.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: A small value to avoid floating-point precision errors during threshold comparison.
        verbose: If True, prints messages when the confidence threshold is reduced.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides

        # scheduler arguments
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode, "factor": factor, "patience": patience,
            "threshold": threshold, "threshold_mode": threshold_mode,
            "min_ct": min_ct, "eps": eps, "verbose": verbose,
        }

        # Start from the worst possible value so that the first metric always counts as an improvement.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= t or <= 1 - t, for threshold t."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a boolean mask that is True where a value is >= the confidence threshold."""
        return pseudo_labels >= self.confidence_threshold

    def __call__(self, teacher, input_):
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The confidence mask, or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            return pseudo_labels, None
        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a, best):
        """Check whether metric value `a` improves on `best` by more than `threshold`."""
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        """Lower the confidence threshold by `factor`, but never below `min_ct`."""
        old_ct = self.confidence_threshold
        if old_ct is None:
            # Without a threshold no mask is computed, so there is nothing to reduce.
            return
        if self.threshold_mode == "rel":
            new_ct = max(old_ct * self.factor, self.min_ct)
        else:  # threshold_mode == 'abs':
            new_ct = max(old_ct - self.factor, self.min_ct)
        # Ignore changes that are within floating point noise.
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch=None):
        """Advance the schedule by one epoch and reduce the confidence threshold if it is due.

        Args:
            metric: The value of the monitored metric for this epoch. If None, the threshold
                is reduced on a fixed schedule, every `patience` epochs.
            epoch: The current epoch. If None, the last recorded epoch plus one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if metric is None:
            # Fixed schedule: reduce every `patience` epochs, but never at epoch 0.
            # A patience below 1 is treated as 1 to avoid a modulo by zero.
            interval = max(self.patience, 1)
            if epoch != 0 and epoch % interval == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0
class
DefaultPseudoLabeler:
class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        mask_channel: A specific channel to use for computing the confidence mask.
            By default the confidence mask is computed across all channels independently.
            This is useful, if only one of the channels encodes a probability.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mask_channel: Optional[int] = None,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.mask_channel = mask_channel
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= t or <= 1 - t, for threshold t."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a boolean mask that is True where a value is >= the confidence threshold."""
        return pseudo_labels >= self.confidence_threshold

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The confidence mask, or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            return pseudo_labels, None

        if self.mask_channel is None:
            mask_input = pseudo_labels
        else:
            # Select the channel on axis 1 and keep it as a singleton axis. Slicing the
            # first axis instead would pick a batch element and broadcast its mask to the batch.
            mask_input = pseudo_labels[:, self.mask_channel].unsqueeze(1)

        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(mask_input)
        else:
            label_mask = self._compute_label_mask_one_side(mask_input)

        if self.mask_channel is not None:
            # Broadcast the single-channel mask to all channels of the pseudo labels.
            size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
            label_mask = label_mask.expand(*size)
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule."""
        pass
Compute pseudo labels based on model predictions, typically from a teacher model.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
- mask_channel: A specific channel to use for computing the confidence mask. By default the confidence mask is computed across all channels independently. This is useful, if only one of the channels encodes a probability.
DefaultPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, mask_channel: Optional[int] = None)
def __init__(
    self,
    activation: Optional[torch.nn.Module] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides: bool = True,
    mask_channel: Optional[int] = None,
):
    """Store the labeler settings and the keyword arguments needed to re-create it.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing the confidence mask, or None for no mask.
        threshold_from_both_sides: Whether the mask also includes values below 1 - the threshold.
        mask_channel: A specific channel to use for computing the confidence mask.
    """
    # TODO serialize the class names and kwargs for activation instead
    # The activation itself is not recorded, it is always stored as None.
    self.init_kwargs = dict(
        activation=None,
        confidence_threshold=confidence_threshold,
        threshold_from_both_sides=threshold_from_both_sides,
        mask_channel=mask_channel,
    )
    self.mask_channel = mask_channel
    self.threshold_from_both_sides = threshold_from_both_sides
    self.confidence_threshold = confidence_threshold
    self.activation = activation
class
ProbabilisticPseudoLabeler:
class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response (weighted per-pixel response) is returned.
            If True, the masked consensus response (complete agreement of pixels) is returned.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples,
            "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= t or <= 1 - t."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [
            ((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
            for sample in pseudo_labels
        ]

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return one float32 mask per sample, 1 where a value is >= the confidence threshold."""
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            The pseudo-labels, averaged over all drawn samples.
            The confidence mask, or None if no confidence threshold is set.
        """
        teacher.forward(input_)
        if self.activation is None:
            samples = [teacher.sample() for _ in range(self.prior_samples)]
        else:
            samples = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]

        if self.confidence_threshold is None:
            label_mask = None
        else:
            # The mask has to be computed per drawn sample, i.e. before the samples are averaged.
            if self.threshold_from_both_sides:
                sample_masks = self._compute_label_mask_both_sides(samples)
            else:
                sample_masks = self._compute_label_mask_one_side(samples)
            # Fraction of samples that are confident at each position.
            label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
            if self.consensus_masking:
                # Keep only positions where all samples are confident.
                label_mask = torch.where(label_mask == 1, 1, 0)

        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """Do nothing; this labeler has no schedule."""
        pass
Compute pseudo labels from a Probabilistic UNet.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
- prior_samples: The number of times to sample from the model distribution per input.
- consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
def __init__(
    self,
    activation: Optional[torch.nn.Module] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides: bool = True,
    prior_samples: int = 16,
    consensus_masking: bool = False,
):
    """Store the labeler settings and the keyword arguments needed to re-create it.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing the confidence mask, or None for no mask.
        threshold_from_both_sides: Whether the mask also includes values below 1 - the threshold.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
    """
    self.activation = activation
    self.confidence_threshold = confidence_threshold
    self.threshold_from_both_sides = threshold_from_both_sides
    self.prior_samples = prior_samples
    self.consensus_masking = consensus_masking
    # TODO serialize the class names and kwargs for activation instead
    # All constructor arguments are recorded, so that a labeler re-created from
    # `init_kwargs` samples and masks in the same way as this one.
    self.init_kwargs = {
        "activation": None, "confidence_threshold": confidence_threshold,
        "threshold_from_both_sides": threshold_from_both_sides,
        "prior_samples": prior_samples,
        "consensus_masking": consensus_masking,
    }
class
ScheduledPseudoLabeler:
class ScheduledPseudoLabeler:
    """
    This class implements a scheduled pseudo-labeling mechanism, where pseudo labels
    are generated from a teacher model's predictions, and the confidence threshold
    for filtering the pseudo labels can be adjusted over time based on a performance
    metric or a fixed schedule. It includes options for adjusting thresholds from
    both sides (for binary classification) or from one side (for multiclass problems).
    The threshold can be dynamically reduced when the model performance does not
    improve for a given number of epochs (patience).

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed and the schedule has no effect.
        threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it,
            or only values bigger than it in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Amount by which the confidence threshold is reduced when the performance stagnates.
            It is subtracted from the threshold if `threshold_mode` is 'abs' and multiplied with it if it is 'rel'.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel')
            or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        eps: A small value to avoid floating-point precision errors during threshold comparison.
        verbose: If True, prints messages when the confidence threshold is reduced.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        eps: float = 1e-8,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides

        # scheduler arguments
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(
                f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
            )
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.min_ct = min_ct
        self.eps = eps
        self.verbose = verbose

        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mode": mode, "factor": factor, "patience": patience,
            "threshold": threshold, "threshold_mode": threshold_mode,
            "min_ct": min_ct, "eps": eps, "verbose": verbose,
        }

        # Start from the worst possible value so that the first metric always counts as an improvement.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where a value is >= t or <= 1 - t, for threshold t."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        confident = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return confident.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a boolean mask that is True where a value is >= the confidence threshold."""
        return pseudo_labels >= self.confidence_threshold

    def __call__(self, teacher, input_):
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels.
            The confidence mask, or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            return pseudo_labels, None
        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def _is_better(self, a, best):
        """Check whether metric value `a` improves on `best` by more than `threshold`."""
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        """Lower the confidence threshold by `factor`, but never below `min_ct`."""
        old_ct = self.confidence_threshold
        if old_ct is None:
            # Without a threshold no mask is computed, so there is nothing to reduce.
            return
        if self.threshold_mode == "rel":
            new_ct = max(old_ct * self.factor, self.min_ct)
        else:  # threshold_mode == 'abs':
            new_ct = max(old_ct - self.factor, self.min_ct)
        # Ignore changes that are within floating point noise.
        if old_ct - new_ct > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def step(self, metric, epoch=None):
        """Advance the schedule by one epoch and reduce the confidence threshold if it is due.

        Args:
            metric: The value of the monitored metric for this epoch. If None, the threshold
                is reduced on a fixed schedule, every `patience` epochs.
            epoch: The current epoch. If None, the last recorded epoch plus one is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if metric is None:
            # Fixed schedule: reduce every `patience` epochs, but never at epoch 0.
            # A patience below 1 is treated as 1 to avoid a modulo by zero.
            interval = max(self.patience, 1)
            if epoch != 0 and epoch % interval == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0
This class implements a scheduled pseudo-labeling mechanism, where pseudo labels are generated from a teacher model's predictions, and the confidence threshold for filtering the pseudo labels can be adjusted over time based on a performance metric or a fixed schedule. It includes options for adjusting thresholds from both sides (for binary classification) or from one side (for multiclass problems). The threshold can be dynamically reduced to improve the quality of the pseudo labels when the model performance does not improve for a given number of epochs (patience).
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If none is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - it, or only values bigger than it in the mask. The former should be used for binary labels, the latter for multiclass labels.
- mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
- 'min': A lower value of the monitored metric is considered better (e.g., loss).
- 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
- factor: Factor by which the confidence threshold is reduced when the performance stagnates.
- patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
- threshold: Threshold value for determining a significant improvement in the performance metric to reset the patience counter. This can be relative (percentage improvement) or absolute depending on `threshold_mode`.
- threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel') or an absolute improvement ('abs').
- min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
- eps: A small value to avoid floating-point precision errors during threshold comparison.
- verbose: If True, prints messages when the confidence threshold is reduced.
ScheduledPseudoLabeler( activation: Union[torch.nn.modules.module.Module, Callable, NoneType] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides=True, mode: Literal['min', 'max'] = 'min', factor: float = 0.05, patience: int = 10, threshold: float = 0.0001, threshold_mode: Literal['rel', 'abs'] = 'abs', min_ct: float = 0.5, eps: float = 1e-08, verbose: bool = True)
def __init__(
    self,
    activation: Optional[Union[torch.nn.Module, Callable]] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides: bool = True,
    mode: Literal["min", "max"] = "min",
    factor: float = 0.05,
    patience: int = 10,
    threshold: float = 1e-4,
    threshold_mode: Literal["rel", "abs"] = "abs",
    min_ct: float = 0.5,
    eps: float = 1e-8,
    verbose: bool = True,
):
    """Validate and store the labeler and scheduler settings.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing the confidence mask, or None for no mask.
        threshold_from_both_sides: Whether the mask also includes values below 1 - the threshold.
        mode: Whether a lower ('min') or higher ('max') metric value is considered better.
        factor: Amount by which the confidence threshold is reduced. Must be smaller than 1.
        patience: Number of epochs with no improvement after which the threshold is reduced.
        threshold: Minimal change of the metric that counts as an improvement.
        threshold_mode: Whether `threshold` is relative ('rel') or absolute ('abs').
        min_ct: Minimum allowed confidence threshold.
        eps: Minimal change of the confidence threshold that is applied.
        verbose: If True, prints messages when the confidence threshold is reduced.

    Raises:
        ValueError: If `mode` or `threshold_mode` is not a supported value, or if `factor` >= 1.
    """
    self.activation = activation
    self.confidence_threshold = confidence_threshold
    self.threshold_from_both_sides = threshold_from_both_sides

    # scheduler arguments
    if mode not in {"min", "max"}:
        raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
    if factor >= 1.0:
        raise ValueError("Factor should be < 1.0.")
    if threshold_mode not in {"rel", "abs"}:
        raise ValueError(
            f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'."
        )
    self.mode = mode
    self.factor = factor
    self.patience = patience
    self.threshold = threshold
    self.threshold_mode = threshold_mode
    self.min_ct = min_ct
    self.eps = eps
    self.verbose = verbose

    # TODO serialize the class names and kwargs for activation instead
    # The scheduler arguments are recorded as well, so that a labeler re-created
    # from `init_kwargs` follows the same schedule as this one.
    self.init_kwargs = {
        "activation": None, "confidence_threshold": confidence_threshold,
        "threshold_from_both_sides": threshold_from_both_sides,
        "mode": mode, "factor": factor, "patience": patience,
        "threshold": threshold, "threshold_mode": threshold_mode,
        "min_ct": min_ct, "eps": eps, "verbose": verbose,
    }

    # Start from the worst possible value so that the first metric always counts as an improvement.
    self.best = float("inf") if mode == "min" else float("-inf")
    self.num_bad_epochs: int = 0
    self.last_epoch = 0
def
step(self, metric, epoch=None):
def step(self, metric, epoch=None):
    """Advance the schedule by one epoch and reduce the confidence threshold if it is due.

    Args:
        metric: The value of the monitored metric for this epoch. If None, the threshold
            is reduced on a fixed schedule, every `patience` epochs.
        epoch: The current epoch. If None, the last recorded epoch plus one is used.
    """
    if epoch is None:
        epoch = self.last_epoch + 1
    self.last_epoch = epoch

    if metric is None:
        # Fixed schedule: reduce every `patience` epochs, but never at epoch 0.
        # A patience below 1 is treated as 1 to avoid a modulo by zero.
        interval = max(self.patience, 1)
        if epoch != 0 and epoch % interval == 0:
            self._reduce_ct(epoch)
        return

    current = float(metric)
    if self._is_better(current, self.best):
        self.best = current
        self.num_bad_epochs = 0
    else:
        self.num_bad_epochs += 1

    # Reduce only once the number of epochs without improvement exceeds the patience.
    if self.num_bad_epochs > self.patience:
        self._reduce_ct(epoch)
        self.num_bad_epochs = 0