torch_em.self_training.pseudo_labeling
"""Pseudo-labeling strategies for self-training: compute pseudo labels (and optional
confidence masks) from a teacher model's predictions.
"""
from typing import Callable, Literal, Optional, Tuple, Union

import numpy as np
import torch


class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        mask_channel: A specific channel to use for computing the confidence mask.
            By default the confidence mask is computed across all channels independently.
            This is useful, if only one of the channels encodes a probability.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        mask_channel: Optional[int] = None,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.mask_channel = mask_channel
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        # Keep pixels that are confidently positive (>= t) or confidently negative (<= 1 - t).
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        # Cast to float32 for consistency with _compute_label_mask_both_sides.
        mask = (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)
        return mask

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels and the confidence mask (None if no confidence_threshold was set).
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        else:
            # Select the requested channel along dim 1 (the channel axis of (B, C, ...) predictions).
            # Slicing with [c:c+1] keeps the channel axis so the mask can be expanded back below.
            mask_input = pseudo_labels if self.mask_channel is None\
                else pseudo_labels[:, self.mask_channel:(self.mask_channel + 1)]
            label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides\
                else self._compute_label_mask_one_side(mask_input)
            if self.mask_channel is not None:
                # Broadcast the single-channel mask over all prediction channels.
                size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
                label_mask = label_mask.expand(*size)
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """No-op; present for API compatibility with scheduled pseudo-labelers."""
        pass


class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response (weighted per-pixel response) is returned.
            If True, the masked consensus response (complete agreement of pixels) is returned.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, samples):
        # One float mask per prior sample. The dtype cast (instead of torch.where with CPU
        # scalar tensors) keeps the mask on the same device as the samples.
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [
            ((sample >= upper_threshold) + (sample <= lower_threshold)).to(dtype=torch.float32)
            for sample in samples
        ]

    def _compute_label_mask_one_side(self, samples):
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in samples]

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            The pseudo-labels (averaged over the prior samples) and the confidence mask
            (None if no confidence_threshold was set).
        """
        teacher.forward(input_)  # Conditions the teacher's prior on this input.
        samples = [teacher.sample() for _ in range(self.prior_samples)]
        if self.activation is not None:
            samples = [self.activation(sample) for sample in samples]

        if self.confidence_threshold is None:
            label_mask = None
        else:
            # Compute the mask per prior sample *before* averaging, then average the masks,
            # so that the mask measures agreement across the distribution samples.
            masks = self._compute_label_mask_both_sides(samples) if self.threshold_from_both_sides \
                else self._compute_label_mask_one_side(samples)
            label_mask = torch.stack(masks, dim=0).sum(dim=0) / self.prior_samples
            if self.consensus_masking:
                # Keep only pixels where every sample is confident (complete agreement).
                label_mask = torch.where(label_mask == 1, 1, 0)

        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples
        return pseudo_labels, label_mask

    def step(self, metric, epoch):
        """No-op; present for API compatibility with scheduled pseudo-labelers."""
        pass


class ScheduledPseudoLabeler:
    """Implement scheduled pseudo-labeling with dynamic confidence-threshold updates.

    Pseudo labels are generated from a teacher model prediction and can be filtered
    by a confidence mask. The confidence threshold can be adapted over time either
    by decreasing it based on a monitored metric (plateau behavior) or by increasing
    it with a fixed epoch schedule.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given, no mask will be computed.
        increase: If True, increase the confidence threshold over time according to
            a fixed schedule. If False, decrease it based on plateau detection.
        last_step_epoch: Last epoch at which threshold increase is allowed when `increase=True`.
        threshold_from_both_sides: Whether to include values larger than the threshold and
            smaller than 1 - the threshold in the mask, or only values larger than the
            threshold. The former should be used for binary labels, the latter for multiclass labels.
        mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
            - 'min': A lower value of the monitored metric is considered better (e.g., loss).
            - 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
        factor: Update size for confidence-threshold scheduling. Interpreted as a
            multiplicative factor for `threshold_mode='rel'` and as an additive step
            for `threshold_mode='abs'`.
        patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
        threshold: Threshold value for determining a significant improvement in the performance metric
            to reset the patience counter. This can be relative (percentage improvement)
            or absolute depending on `threshold_mode`.
        threshold_mode: Determines whether `threshold` is interpreted as a relative
            improvement ('rel') or an absolute improvement ('abs').
        min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
        max_ct: Maximum allowed confidence threshold. The threshold will not be increased above this value.
        eps: A small value to avoid floating-point precision errors during threshold comparison.
        warm_up_epochs: Number of warm-up epochs. At the end of warm-up, `confidence_threshold`
            is set to `max_ct`. This is intended for decreasing-threshold scheduling (`increase=False`).
        mask_channel: Specific channel to use for confidence masking. Currently, only None is supported.
        verbose: If True, prints messages when the confidence threshold is updated.
    """

    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        increase: bool = False,
        last_step_epoch: int = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        max_ct: float = 0.95,
        eps: float = 1e-8,
        warm_up_epochs: int = 0,
        mask_channel: Optional[int] = None,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.increase = increase
        self.threshold_from_both_sides = threshold_from_both_sides
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }
        # scheduler arguments
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        self.mode = mode

        # Use an exception rather than assert, which is stripped under 'python -O'.
        if factor >= 1:
            raise ValueError(f"Factor must be smaller than 1, got {factor}")
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'.")
        self.threshold_mode = threshold_mode

        self.min_ct = min_ct
        self.max_ct = max_ct
        self.eps = eps
        self.warm_up_epochs = warm_up_epochs

        if self.increase and self.warm_up_epochs > 0:
            raise ValueError("warm_up_epochs > 0 is only supported when increase=False.")

        # TODO implement mask_channel functionality; for now only None is supported
        if mask_channel is not None:
            raise NotImplementedError("mask_channel is not implemented yet; only None is supported.")
        self.mask_channel = mask_channel
        self.verbose = verbose

        # Initialize the best metric so that any real value counts as an improvement.
        if mode == "min":
            self.best = float("inf")
        else:  # mode == "max":
            self.best = float("-inf")

        self.num_bad_epochs: int = 0
        self.last_epoch = 0

        if self.increase:
            if last_step_epoch is None:
                raise ValueError("last_step_epoch must be given when increase=True.")
            self.last_step_epoch = last_step_epoch

            # Number of threshold values on the grid [min_ct, max_ct] with spacing `factor`.
            # The half-step added to the endpoint guards against float round-off in np.arange.
            n_ct = len(np.arange(self.min_ct, self.max_ct + self.factor / 2, self.factor))
            n_increments = n_ct - 1

            if n_increments <= 0:
                # min_ct == max_ct: nothing to increase.
                self.step_epoch = 0
            else:
                # Compute the step size in epochs; enforce a minimum step size of 1.
                self.step_epoch = max(1, int(np.floor(self.last_step_epoch / n_increments)))

                # Ensure max_ct is reachable within the schedule.
                required_epochs = self.step_epoch * n_increments
                if self.last_step_epoch < required_epochs:
                    self.last_step_epoch = required_epochs

                print(
                    f"ScheduledPseudoLabeler: Increasing confidence_threshold every {self.step_epoch} epochs;",
                    f"until epoch {self.last_step_epoch}"
                )

    def _compute_label_mask_both_sides(self, pseudo_labels):
        # Keep pixels that are confidently positive (>= t) or confidently negative (<= 1 - t).
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        # Cast to float32 for consistency with _compute_label_mask_both_sides.
        mask = (pseudo_labels >= self.confidence_threshold).to(dtype=torch.float32)
        return mask

    def __call__(self, teacher, input_):
        """Compute pseudo-labels and the confidence mask for a batch.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels and the confidence mask (None if no confidence_threshold is set).
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        else:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\
                else self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def end_warm_up(self):
        """Finish the warm-up phase by setting the confidence threshold to its maximum."""
        self.confidence_threshold = self.max_ct
        print(f"End of warm-up phase reached: setting confidence threshold to {self.confidence_threshold}")

    def _is_better(self, a, best):
        # Mirrors torch.optim.lr_scheduler.ReduceLROnPlateau's improvement test.
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon

        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold

        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        # Lower the confidence threshold by `factor` (relative or absolute), clamped at min_ct.
        if self.confidence_threshold is None:
            return
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = max(self.confidence_threshold * (1 - self.factor), self.min_ct)
        else:  # threshold_mode == 'abs':
            new_ct = max(self.confidence_threshold - self.factor, self.min_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def decrease_step(self, metric, epoch=None):
        """Decrease the confidence threshold on a metric plateau (or on a fixed schedule if no metric is given)."""
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        # Without a metric, reduce the confidence threshold every `patience` epochs.
        if metric is None:
            if epoch == 0:
                return
            if epoch % self.patience == 0:
                self._reduce_ct(epoch)
            return

        else:
            current = float(metric)

            if self._is_better(current, self.best):
                self.best = current
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self._reduce_ct(epoch)
                self.num_bad_epochs = 0

    def _increase_ct(self, epoch):
        # Raise the confidence threshold by `factor` (relative or absolute), clamped at max_ct.
        if self.confidence_threshold is None:
            return
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = min(self.confidence_threshold * (1 + self.factor), self.max_ct)
        else:  # threshold_mode == 'abs':
            new_ct = min(self.confidence_threshold + self.factor, self.max_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
            if self.verbose:
                print(f"Epoch {epoch}: increase confidence threshold from {old_ct} to {self.confidence_threshold}")

    def increase_step(self, epoch):
        """Increase the confidence threshold every `step_epoch` epochs until `last_step_epoch`."""
        # step_epoch == 0 means there is nothing to increase (min_ct == max_ct).
        if self.step_epoch == 0 or epoch > self.last_step_epoch:
            return

        if epoch % self.step_epoch == 0 and epoch != 0:
            self._increase_ct(epoch)

    def step(self, metric=None, epoch=None):
        """Update the confidence threshold for the given epoch.

        Args:
            metric: The monitored metric (only used for plateau-based decreasing).
            epoch: The current epoch. If None, an internal epoch counter is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        if epoch == self.warm_up_epochs and self.warm_up_epochs > 0:
            self.end_warm_up()
        elif epoch > self.warm_up_epochs:
            if self.increase:
                self.increase_step(epoch)
            else:
                self.decrease_step(metric, epoch)
        else:
            return
class
DefaultPseudoLabeler:
8class DefaultPseudoLabeler: 9 """Compute pseudo labels based on model predictions, typically from a teacher model. 10 11 Args: 12 activation: Activation function applied to the teacher prediction. 13 confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. 14 If None is given no mask will be computed. 15 threshold_from_both_sides: Whether to include both values bigger than the threshold 16 and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask. 17 The former should be used for binary labels, the latter for for multiclass labels. 18 mask_channel: A specific channel to use for computing the confidence mask. 19 By default the confidence mask is computed across all channels independently. 20 This is useful, if only one of the channels encodes a probability. 21 """ 22 def __init__( 23 self, 24 activation: Optional[torch.nn.Module] = None, 25 confidence_threshold: Optional[float] = None, 26 threshold_from_both_sides: bool = True, 27 mask_channel: Optional[int] = None, 28 ): 29 self.activation = activation 30 self.confidence_threshold = confidence_threshold 31 self.threshold_from_both_sides = threshold_from_both_sides 32 self.mask_channel = mask_channel 33 # TODO serialize the class names and kwargs for activation instead 34 self.init_kwargs = { 35 "activation": None, "confidence_threshold": confidence_threshold, 36 "threshold_from_both_sides": threshold_from_both_sides, 37 "mask_channel": mask_channel, 38 } 39 40 def _compute_label_mask_both_sides(self, pseudo_labels): 41 upper_threshold = self.confidence_threshold 42 lower_threshold = 1.0 - self.confidence_threshold 43 mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32) 44 return mask 45 46 def _compute_label_mask_one_side(self, pseudo_labels): 47 mask = (pseudo_labels >= self.confidence_threshold) 48 return mask 49 50 def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: 51 
"""Compute pseudo-labels. 52 53 Args: 54 teacher: The teacher model. 55 input_: The input for this batch. 56 57 Returns: 58 The pseudo-labels. 59 """ 60 pseudo_labels = teacher(input_) 61 if self.activation is not None: 62 pseudo_labels = self.activation(pseudo_labels) 63 if self.confidence_threshold is None: 64 label_mask = None 65 else: 66 mask_input = pseudo_labels if self.mask_channel is None\ 67 else pseudo_labels[self.mask_channel:(self.mask_channel+1)] 68 label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides\ 69 else self._compute_label_mask_one_side(mask_input) 70 if self.mask_channel is not None: 71 size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2))) 72 label_mask = label_mask.expand(*size) 73 return pseudo_labels, label_mask 74 75 def step(self, metric, epoch): 76 pass
Compute pseudo labels based on model predictions, typically from a teacher model.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
- mask_channel: A specific channel to use for computing the confidence mask. By default the confidence mask is computed across all channels independently. This is useful, if only one of the channels encodes a probability.
DefaultPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, mask_channel: Optional[int] = None)
22 def __init__( 23 self, 24 activation: Optional[torch.nn.Module] = None, 25 confidence_threshold: Optional[float] = None, 26 threshold_from_both_sides: bool = True, 27 mask_channel: Optional[int] = None, 28 ): 29 self.activation = activation 30 self.confidence_threshold = confidence_threshold 31 self.threshold_from_both_sides = threshold_from_both_sides 32 self.mask_channel = mask_channel 33 # TODO serialize the class names and kwargs for activation instead 34 self.init_kwargs = { 35 "activation": None, "confidence_threshold": confidence_threshold, 36 "threshold_from_both_sides": threshold_from_both_sides, 37 "mask_channel": mask_channel, 38 }
class
ProbabilisticPseudoLabeler:
79class ProbabilisticPseudoLabeler: 80 """Compute pseudo labels from a Probabilistic UNet. 81 82 Args: 83 activation: Activation function applied to the teacher prediction. 84 confidence_threshold: Threshold for computing a mask for filterign the pseudo labels. 85 If none is given no mask will be computed. 86 threshold_from_both_sides: Whether to include both values bigger than the threshold 87 and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask. 88 The former should be used for binary labels, the latter for for multiclass labels. 89 prior_samples: The number of times to sample from the model distribution per input. 90 consensus_masking: Whether to activate consensus masking in the label filter. 91 If False, the weighted consensus response (weighted per-pixel response) is returned. 92 If True, the masked consensus response (complete aggrement of pixels) is returned. 93 """ 94 def __init__( 95 self, 96 activation: Optional[torch.nn.Module] = None, 97 confidence_threshold: Optional[float] = None, 98 threshold_from_both_sides: bool = True, 99 prior_samples: int = 16, 100 consensus_masking: bool = False, 101 ): 102 self.activation = activation 103 self.confidence_threshold = confidence_threshold 104 self.threshold_from_both_sides = threshold_from_both_sides 105 self.prior_samples = prior_samples 106 self.consensus_masking = consensus_masking 107 # TODO serialize the class names and kwargs for activation instead 108 self.init_kwargs = { 109 "activation": None, "confidence_threshold": confidence_threshold, 110 "threshold_from_both_sides": threshold_from_both_sides 111 } 112 113 def _compute_label_mask_both_sides(self, pseudo_labels): 114 upper_threshold = self.confidence_threshold 115 lower_threshold = 1.0 - self.confidence_threshold 116 mask = [ 117 torch.where((sample >= upper_threshold) + (sample <= lower_threshold), torch.tensor(1.), torch.tensor(0.)) 118 for sample in pseudo_labels 119 ] 120 return mask 121 122 def 
_compute_label_mask_one_side(self, pseudo_labels): 123 mask = [ 124 torch.where((sample >= self.confidence_threshold), torch.tensor(1.), torch.tensor(0.)) 125 for sample in pseudo_labels 126 ] 127 return mask 128 129 def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: 130 """Compute pseudo-labels. 131 132 Args: 133 teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`. 134 input_: The input for this batch. 135 136 Returns: 137 The pseudo-labels. 138 """ 139 teacher.forward(input_) 140 if self.activation is not None: 141 pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)] 142 else: 143 pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)] 144 pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples 145 146 if self.confidence_threshold is None: 147 label_mask = None 148 else: 149 label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \ 150 else self._compute_label_mask_one_side(pseudo_labels) 151 label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples 152 if self.consensus_masking: 153 label_mask = torch.where(label_mask == 1, 1, 0) 154 155 return pseudo_labels, label_mask 156 157 def step(self, metric, epoch): 158 pass
Compute pseudo labels from a Probabilistic UNet.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values bigger than the threshold and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
- prior_samples: The number of times to sample from the model distribution per input.
- consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
94 def __init__( 95 self, 96 activation: Optional[torch.nn.Module] = None, 97 confidence_threshold: Optional[float] = None, 98 threshold_from_both_sides: bool = True, 99 prior_samples: int = 16, 100 consensus_masking: bool = False, 101 ): 102 self.activation = activation 103 self.confidence_threshold = confidence_threshold 104 self.threshold_from_both_sides = threshold_from_both_sides 105 self.prior_samples = prior_samples 106 self.consensus_masking = consensus_masking 107 # TODO serialize the class names and kwargs for activation instead 108 self.init_kwargs = { 109 "activation": None, "confidence_threshold": confidence_threshold, 110 "threshold_from_both_sides": threshold_from_both_sides 111 }
class
ScheduledPseudoLabeler:
161class ScheduledPseudoLabeler: 162 """ 163 Implement scheduled pseudo-labeling with dynamic confidence-threshold updates. 164 165 Pseudo labels are generated from a teacher model prediction and can be filtered 166 by a confidence mask. The confidence threshold can be adapted over time either 167 by decreasing it based on a monitored metric (plateau behavior) or by increasing 168 it with a fixed epoch schedule. 169 170 Args: 171 activation: Activation function applied to the teacher prediction. 172 confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. 173 If None is given, no mask will be computed. 174 increase: If True, increase the confidence threshold over time according to 175 a fixed schedule. If False, decrease it based on plateau detection. 176 last_step_epoch: Last epoch at which threshold increase is allowed when 177 `increase=True`. 178 threshold_from_both_sides: Whether to include values larger than the 179 threshold and smaller than 1 - the threshold in the mask, or only values 180 larger than the threshold. The former should be used for binary labels, 181 the latter for multiclass labels. 182 mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric. 183 - 'min': A lower value of the monitored metric is considered better (e.g., loss). 184 - 'max': A higher value of the monitored metric is considered better (e.g., accuracy). 185 factor: Update size for confidence-threshold scheduling. Interpreted as a 186 multiplicative factor for `threshold_mode='rel'` and as an additive step 187 for `threshold_mode='abs'`. 188 patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced. 189 threshold: Threshold value for determining a significant improvement in the performance metric 190 to reset the patience counter. This can be relative (percentage improvement) 191 or absolute depending on `threshold_mode`. 
192 threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel') 193 or an absolute improvement ('abs'). 194 min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value. 195 max_ct: Maximum allowed confidence threshold. The threshold will not be increased above this value. 196 eps: A small value to avoid floating-point precision errors during threshold comparison. 197 warm_up_epochs: Number of warm-up epochs. At the end of warm-up, 198 `confidence_threshold` is set to `max_ct`. This is intended for 199 decreasing-threshold scheduling (`increase=False`). 200 mask_channel: Specific channel to use for confidence masking. Currently, 201 only None is supported. 202 verbose: If True, prints messages when the confidence threshold is updated. 203 """ 204 205 def __init__( 206 self, 207 activation: Optional[Union[torch.nn.Module, Callable]] = None, 208 confidence_threshold: Optional[float] = None, 209 increase: bool = False, 210 last_step_epoch: int = None, 211 threshold_from_both_sides: bool = True, 212 mode: Literal["min", "max"] = "min", 213 factor: float = 0.05, 214 patience: int = 10, 215 threshold: float = 1e-4, 216 threshold_mode: Literal["rel", "abs"] = "abs", 217 min_ct: float = 0.5, 218 max_ct: float = 0.95, 219 eps: float = 1e-8, 220 warm_up_epochs: int = 0, 221 mask_channel: Optional[int] = None, 222 verbose: bool = True, 223 ): 224 self.activation = activation 225 self.confidence_threshold = confidence_threshold 226 self.increase = increase 227 self.threshold_from_both_sides = threshold_from_both_sides 228 self.init_kwargs = { 229 "activation": None, "confidence_threshold": confidence_threshold, 230 "threshold_from_both_sides": threshold_from_both_sides, 231 "mask_channel": mask_channel, 232 } 233 # scheduler arguments 234 if mode not in {"min", "max"}: 235 raise ValueError(f"Invalid mode: {mode}. 
class ScheduledPseudoLabeler:
    """Compute pseudo labels from teacher predictions, with a confidence
    threshold that is updated dynamically over the course of training.

    The threshold is either decreased when a monitored metric plateaus
    (``increase=False``, in the style of ``ReduceLROnPlateau``) or increased
    on a fixed epoch schedule (``increase=True``).

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the
            pseudo labels. If None is given no mask will be computed.
        increase: Whether to increase the threshold on a fixed schedule (True)
            or decrease it based on plateau detection (False).
        last_step_epoch: Last epoch at which a threshold increase is allowed.
            Required when increase=True.
        threshold_from_both_sides: Whether to include both values bigger than the
            threshold and smaller than 1 - the threshold, or only values bigger
            than the threshold, in the mask. The former should be used for binary
            labels, the latter for multiclass labels.
        mode: Whether a lower ('min') or higher ('max') metric value is better.
        factor: Update size for the threshold scheduling; interpreted as a
            multiplicative factor for threshold_mode='rel' and as an additive
            step for threshold_mode='abs'.
        patience: Number of epochs with no improvement after which the
            confidence threshold will be reduced.
        threshold: Minimum metric improvement that resets the patience counter.
        threshold_mode: Whether `threshold` is relative ('rel') or absolute ('abs').
        min_ct: Minimum allowed confidence threshold.
        max_ct: Maximum allowed confidence threshold.
        eps: Minimal effective threshold change; smaller updates are ignored.
        warm_up_epochs: Number of warm-up epochs. At the end of the warm-up the
            threshold is set to max_ct. Only supported for increase=False.
        mask_channel: Channel to use for the confidence mask. Only None is
            currently supported.
        verbose: Whether to print a message whenever the threshold is updated.
    """
    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        increase: bool = False,
        last_step_epoch: Optional[int] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        max_ct: float = 0.95,
        eps: float = 1e-8,
        warm_up_epochs: int = 0,
        mask_channel: Optional[int] = None,
        verbose: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.increase = increase
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }

        # Scheduler arguments.
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        self.mode = mode

        # Raise instead of assert so the check also runs under `python -O`.
        if factor >= 1:
            raise ValueError(f"Factor must be smaller than 1, got {factor}")
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in {"rel", "abs"}:
            # BUGFIX: the original message interpolated `mode` instead of `threshold_mode`.
            raise ValueError(f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'.")
        self.threshold_mode = threshold_mode

        self.min_ct = min_ct
        self.max_ct = max_ct
        self.eps = eps
        self.warm_up_epochs = warm_up_epochs

        if self.increase and self.warm_up_epochs > 0:
            raise ValueError("warm_up_epochs > 0 is only supported when increase=False.")

        # TODO implement mask_channel functionality; for now only None is supported.
        if mask_channel is not None:
            raise NotImplementedError("mask_channel is not implemented yet; only None is supported.")
        self.mask_channel = mask_channel
        self.verbose = verbose

        # Best metric seen so far; initialized so the first metric always improves it.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

        if self.increase:
            # BUGFIX: fail early with a clear message instead of a TypeError below.
            if last_step_epoch is None:
                raise ValueError("last_step_epoch must be given when increase=True.")
            self.last_step_epoch = last_step_epoch

            # Number of threshold values on the grid [min_ct, max_ct] with spacing `factor`
            # (factor / 2 guards against float endpoint exclusion in np.arange).
            n_ct = len(np.arange(self.min_ct, self.max_ct + self.factor / 2, self.factor))
            n_increments = n_ct - 1

            if n_increments <= 0:
                # Nothing to increase.
                self.step_epoch = 0
            else:
                # Compute initial step size; enforce minimum step size of 1 epoch.
                self.step_epoch = max(1, int(np.floor(self.last_step_epoch / n_increments)))

            # Ensure max_ct is reachable within the schedule.
            required_epochs = self.step_epoch * n_increments
            if self.last_step_epoch < required_epochs:
                self.last_step_epoch = required_epochs

            print(
                f"ScheduledPseudoLabeler: Increasing confidence_threshold every {self.step_epoch} epochs;",
                f"until epoch {self.last_step_epoch}"
            )

    def _compute_label_mask_both_sides(self, pseudo_labels):
        # Keep confident foreground (>= t) and confident background (<= 1 - t).
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
        return mask

    def _compute_label_mask_one_side(self, pseudo_labels):
        # Keep only values above the threshold (multiclass case).
        mask = (pseudo_labels >= self.confidence_threshold)
        return mask

    def __call__(self, teacher, input_):
        """Compute pseudo labels and (optionally) a confidence mask.

        Args:
            teacher: The teacher model.
            input_: The input for the teacher.

        Returns:
            Tuple of pseudo labels and the label mask (None if no
            confidence_threshold is set).
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask

    def end_warm_up(self):
        """Finish the warm-up phase by setting the threshold to max_ct."""
        self.confidence_threshold = self.max_ct
        print(f"End of warm-up phase reached: setting confidence threshold to {self.confidence_threshold}")

    def _is_better(self, a, best):
        # Significant-improvement test, following ReduceLROnPlateau semantics.
        if self.mode == "min" and self.threshold_mode == "rel":
            return a < best * (1.0 - self.threshold)
        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        elif self.mode == "max" and self.threshold_mode == "rel":
            return a > best * (self.threshold + 1.0)
        else:  # mode == 'max' and threshold_mode == 'abs'
            return a > best + self.threshold

    def _reduce_ct(self, epoch):
        # Reduce the threshold, clipped at min_ct; ignore sub-eps changes.
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = max(self.confidence_threshold * (1 - self.factor), self.min_ct)
        else:  # threshold_mode == 'abs'
            new_ct = max(self.confidence_threshold - self.factor, self.min_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
        if self.verbose:
            print(f"Epoch {epoch}: reducing confidence threshold from {old_ct} to {self.confidence_threshold}")

    def decrease_step(self, metric, epoch=None):
        """Advance the decreasing-threshold schedule by one epoch.

        Args:
            metric: The monitored metric. If None, the threshold is reduced on
                a fixed schedule (every `patience` epochs) instead of on
                plateau detection.
            epoch: The current epoch. If None, the epoch counter is advanced by one.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        # Without a metric, fall back to reducing every `patience` epochs.
        if metric is None:
            if epoch != 0 and epoch % self.patience == 0:
                self._reduce_ct(epoch)
            return

        current = float(metric)
        if self._is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self._reduce_ct(epoch)
            self.num_bad_epochs = 0

    def _increase_ct(self, epoch):
        # Increase the threshold, clipped at max_ct; ignore sub-eps changes.
        old_ct = self.confidence_threshold
        if self.threshold_mode == "rel":
            new_ct = min(self.confidence_threshold * (1 + self.factor), self.max_ct)
        else:  # threshold_mode == 'abs'
            new_ct = min(self.confidence_threshold + self.factor, self.max_ct)
        if abs(old_ct - new_ct) > self.eps:
            self.confidence_threshold = new_ct
        if self.verbose:
            print(f"Epoch {epoch}: increase confidence threshold from {old_ct} to {self.confidence_threshold}")

    def increase_step(self, epoch):
        """Advance the increasing-threshold schedule for the given epoch."""
        if epoch > self.last_step_epoch:
            return
        # BUGFIX: guard step_epoch == 0 (nothing to increase); the original
        # raised ZeroDivisionError in that case.
        if self.step_epoch and epoch != 0 and epoch % self.step_epoch == 0:
            self._increase_ct(epoch)

    def step(self, metric=None, epoch=None):
        """Update the confidence threshold for one epoch.

        Args:
            metric: The monitored metric (used when increase=False).
            epoch: The current epoch. If None, the epoch counter is advanced by one.
        """
        if epoch is None:
            # BUGFIX: the original compared None with warm_up_epochs (TypeError).
            epoch = self.last_epoch + 1
            self.last_epoch = epoch
        if self.warm_up_epochs > 0 and epoch == self.warm_up_epochs:
            self.end_warm_up()
        elif epoch > self.warm_up_epochs:
            if self.increase:
                self.increase_step(epoch)
            else:
                self.decrease_step(metric, epoch)
Implement scheduled pseudo-labeling with dynamic confidence-threshold updates.
Pseudo labels are generated from a teacher model prediction and can be filtered by a confidence mask. The confidence threshold can be adapted over time either by decreasing it based on a monitored metric (plateau behavior) or by increasing it with a fixed epoch schedule.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed.
- increase: If True, increase the confidence threshold over time according to a fixed schedule. If False, decrease it based on plateau detection.
- last_step_epoch: Last epoch at which threshold increase is allowed when `increase=True`.
- threshold_from_both_sides: Whether to include values larger than the threshold and smaller than 1 - the threshold in the mask, or only values larger than the threshold. The former should be used for binary labels, the latter for multiclass labels.
- mode: Determines whether the confidence threshold reduction is triggered by a "min" or "max" metric.
- 'min': A lower value of the monitored metric is considered better (e.g., loss).
- 'max': A higher value of the monitored metric is considered better (e.g., accuracy).
- factor: Update size for confidence-threshold scheduling. Interpreted as a multiplicative factor for `threshold_mode='rel'` and as an additive step for `threshold_mode='abs'`.
- patience: Number of epochs (with no improvement) after which the confidence threshold will be reduced.
- threshold: Threshold value for determining a significant improvement in the performance metric to reset the patience counter. This can be relative (percentage improvement) or absolute, depending on `threshold_mode`.
- threshold_mode: Determines whether the `threshold` is interpreted as a relative improvement ('rel') or an absolute improvement ('abs').
- min_ct: Minimum allowed confidence threshold. The threshold will not be reduced below this value.
- max_ct: Maximum allowed confidence threshold. The threshold will not be increased above this value.
- eps: A small value to avoid floating-point precision errors during threshold comparison.
- warm_up_epochs: Number of warm-up epochs. At the end of warm-up, `confidence_threshold` is set to `max_ct`. This is intended for decreasing-threshold scheduling (`increase=False`).
- mask_channel: Specific channel to use for confidence masking. Currently, only None is supported.
- verbose: If True, prints messages when the confidence threshold is updated.
ScheduledPseudoLabeler( activation: Union[torch.nn.modules.module.Module, Callable, NoneType] = None, confidence_threshold: Optional[float] = None, increase: bool = False, last_step_epoch: int = None, threshold_from_both_sides: bool = True, mode: Literal['min', 'max'] = 'min', factor: float = 0.05, patience: int = 10, threshold: float = 0.0001, threshold_mode: Literal['rel', 'abs'] = 'abs', min_ct: float = 0.5, max_ct: float = 0.95, eps: float = 1e-08, warm_up_epochs: int = 0, mask_channel: Optional[int] = None, verbose: bool = True)
class ScheduledPseudoLabeler:
    """Pseudo labeler whose confidence threshold is updated during training,
    either decreased on metric plateaus or increased on a fixed epoch schedule.
    """
    def __init__(
        self,
        activation: Optional[Union[torch.nn.Module, Callable]] = None,
        confidence_threshold: Optional[float] = None,
        increase: bool = False,
        last_step_epoch: Optional[int] = None,
        threshold_from_both_sides: bool = True,
        mode: Literal["min", "max"] = "min",
        factor: float = 0.05,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: Literal["rel", "abs"] = "abs",
        min_ct: float = 0.5,
        max_ct: float = 0.95,
        eps: float = 1e-8,
        warm_up_epochs: int = 0,
        mask_channel: Optional[int] = None,
        verbose: bool = True,
    ):
        """Initialize the pseudo labeler and, for increase=True, precompute the
        epoch schedule on which the threshold is raised from min_ct to max_ct.
        """
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.increase = increase
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "mask_channel": mask_channel,
        }

        # Scheduler arguments.
        if mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {mode}. Mode should be 'min' or 'max'.")
        self.mode = mode

        # Raise instead of assert so the check also runs under `python -O`.
        if factor >= 1:
            raise ValueError(f"Factor must be smaller than 1, got {factor}")
        self.factor = factor

        self.patience = patience
        self.threshold = threshold

        if threshold_mode not in {"rel", "abs"}:
            # BUGFIX: the original message interpolated `mode` instead of `threshold_mode`.
            raise ValueError(f"Invalid threshold mode: {threshold_mode}. Threshold mode should be 'rel' or 'abs'.")
        self.threshold_mode = threshold_mode

        self.min_ct = min_ct
        self.max_ct = max_ct
        self.eps = eps
        self.warm_up_epochs = warm_up_epochs

        if self.increase and self.warm_up_epochs > 0:
            raise ValueError("warm_up_epochs > 0 is only supported when increase=False.")

        # TODO implement mask_channel functionality; for now only None is supported.
        if mask_channel is not None:
            raise NotImplementedError("mask_channel is not implemented yet; only None is supported.")
        self.mask_channel = mask_channel
        self.verbose = verbose

        # Best metric seen so far; initialized so the first metric always improves it.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs: int = 0
        self.last_epoch = 0

        if self.increase:
            # BUGFIX: fail early with a clear message instead of a TypeError below.
            if last_step_epoch is None:
                raise ValueError("last_step_epoch must be given when increase=True.")
            self.last_step_epoch = last_step_epoch

            # Number of threshold values on the grid [min_ct, max_ct] with spacing `factor`
            # (factor / 2 guards against float endpoint exclusion in np.arange).
            n_ct = len(np.arange(self.min_ct, self.max_ct + self.factor / 2, self.factor))
            n_increments = n_ct - 1

            if n_increments <= 0:
                # Nothing to increase.
                self.step_epoch = 0
            else:
                # Compute initial step size; enforce minimum step size of 1 epoch.
                self.step_epoch = max(1, int(np.floor(self.last_step_epoch / n_increments)))

            # Ensure max_ct is reachable within the schedule.
            required_epochs = self.step_epoch * n_increments
            if self.last_step_epoch < required_epochs:
                self.last_step_epoch = required_epochs

            print(
                f"ScheduledPseudoLabeler: Increasing confidence_threshold every {self.step_epoch} epochs;",
                f"until epoch {self.last_step_epoch}"
            )
def decrease_step(self, metric, epoch=None):
    """Advance the decreasing-threshold schedule by one epoch.

    Args:
        metric: The monitored metric. If None, the confidence threshold is
            reduced on a fixed schedule instead of on plateau detection.
        epoch: The current epoch. If None, the epoch counter is advanced by one.
    """
    if epoch is None:
        epoch = self.last_epoch + 1
    self.last_epoch = epoch

    # Without a metric, fall back to reducing the threshold every `patience`
    # epochs. (The original comment claimed "every epoch", which did not match
    # the code.)
    if metric is None:
        if epoch != 0 and epoch % self.patience == 0:
            self._reduce_ct(epoch)
        return

    current = float(metric)
    if self._is_better(current, self.best):
        self.best = current
        self.num_bad_epochs = 0
    else:
        self.num_bad_epochs += 1

    # More than `patience` bad epochs in a row: reduce and restart the count.
    if self.num_bad_epochs > self.patience:
        self._reduce_ct(epoch)
        self.num_bad_epochs = 0