torch_em.self_training.pseudo_labeling
from typing import Optional, Tuple

import torch


class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given, no mask will be computed.
        threshold_from_both_sides: Whether to include both values greater than or equal to the threshold
            and values smaller than or equal to 1 - the threshold, or only values greater than or equal
            to the threshold, in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        # Keep confident values at either end: >= threshold or <= 1 - threshold.
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return mask.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        # Keep only values >= threshold. Cast to float32 so both mask variants share one dtype.
        mask = pseudo_labels >= self.confidence_threshold
        return mask.to(dtype=torch.float32)

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            The pseudo-labels, i.e. the teacher prediction after the optional activation.
            The label mask as a float32 tensor of the same shape as the pseudo-labels,
            or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask


class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given, no mask will be computed.
        threshold_from_both_sides: Whether to include both values greater than or equal to the threshold
            and values smaller than or equal to 1 - the threshold, or only values greater than or equal
            to the threshold, in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input. Must be at least 1.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response is returned: per pixel, the fraction of
            prior samples that pass the confidence threshold.
            If True, the masked consensus response is returned: 1 where all prior samples pass
            the confidence threshold (complete agreement), 0 elsewhere.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        # Fail early: with zero samples there is nothing to average and torch.stack would fail at call time.
        if prior_samples < 1:
            raise ValueError(f"prior_samples must be at least 1, got {prior_samples}.")
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples, "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        # Returns one float32 mask per entry along the first dimension of `pseudo_labels`
        # (the prior-sample dimension when called from __call__).
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [
            ((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
            for sample in pseudo_labels
        ]

    def _compute_label_mask_one_side(self, pseudo_labels):
        # Returns one float32 mask per entry along the first dimension of `pseudo_labels`.
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            The pseudo-labels: the mean over `prior_samples` draws of `teacher.sample()`
            (after the optional activation).
            The label mask with the same shape as the pseudo-labels, or None if no confidence
            threshold is set. See `consensus_masking` for how the per-sample masks are combined.
        """
        # forward() is called first so that the subsequent sample() calls draw for this input.
        teacher.forward(input_)
        samples = []
        for _ in range(self.prior_samples):
            sample = teacher.sample()
            if self.activation is not None:
                sample = self.activation(sample)
            samples.append(sample)
        # Shape: (prior_samples, *sample_shape).
        samples = torch.stack(samples, dim=0)
        pseudo_labels = samples.sum(dim=0) / self.prior_samples

        if self.confidence_threshold is None:
            return pseudo_labels, None

        # Threshold every prior sample individually. These masks must come from the stacked samples,
        # not the averaged prediction; otherwise the iteration runs over the batch instead of the samples.
        if self.threshold_from_both_sides:
            sample_masks = self._compute_label_mask_both_sides(samples)
        else:
            sample_masks = self._compute_label_mask_one_side(samples)
        # Per-pixel fraction of prior samples that pass the threshold.
        label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
        if self.consensus_masking:
            # Keep only pixels on which all prior samples agree.
            label_mask = (label_mask == 1).to(dtype=torch.float32)
        return pseudo_labels, label_mask
class
DefaultPseudoLabeler:
class DefaultPseudoLabeler:
    """Compute pseudo labels based on model predictions, typically from a teacher model.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given, no mask will be computed.
        threshold_from_both_sides: Whether to include both values greater than or equal to the threshold
            and values smaller than or equal to 1 - the threshold, or only values greater than or equal
            to the threshold, in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
    ):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        # Keep confident values at either end: >= threshold or <= 1 - threshold.
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return mask.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels: torch.Tensor) -> torch.Tensor:
        # Keep only values >= threshold. Cast to float32 so both mask variants share one dtype.
        mask = pseudo_labels >= self.confidence_threshold
        return mask.to(dtype=torch.float32)

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> tuple:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch.

        Returns:
            A tuple `(pseudo_labels, label_mask)`. `pseudo_labels` is the teacher prediction
            after the optional activation. `label_mask` is a float32 tensor of the same shape,
            or None if no confidence threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            label_mask = None
        elif self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask
Compute pseudo labels based on model predictions, typically from a teacher model.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given no mask will be computed.
- threshold_from_both_sides: Whether to include both values greater than or equal to the threshold and values smaller than or equal to 1 - the threshold, or only values greater than or equal to the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
DefaultPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True)
def __init__(
    self,
    activation: Optional[torch.nn.Module] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides: bool = True,
):
    """Store the labeling configuration and the keyword arguments used to re-create the labeler.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for the label mask, or None for no mask.
        threshold_from_both_sides: Whether the mask keeps both tails (>= t and <= 1 - t) or only >= t.
    """
    self.activation = activation
    self.confidence_threshold, self.threshold_from_both_sides = confidence_threshold, threshold_from_both_sides
    # The activation itself is not serialized; it is always recorded as None here.
    # TODO serialize the class names and kwargs for activation instead
    self.init_kwargs = dict(
        activation=None,
        confidence_threshold=confidence_threshold,
        threshold_from_both_sides=threshold_from_both_sides,
    )
class
ProbabilisticPseudoLabeler:
class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from a Probabilistic UNet.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given, no mask will be computed.
        threshold_from_both_sides: Whether to include both values greater than or equal to the threshold
            and values smaller than or equal to 1 - the threshold, or only values greater than or equal
            to the threshold, in the mask. The former should be used for binary labels,
            the latter for multiclass labels.
        prior_samples: The number of times to sample from the model distribution per input. Must be at least 1.
        consensus_masking: Whether to activate consensus masking in the label filter.
            If False, the weighted consensus response is returned: per pixel, the fraction of
            prior samples that pass the confidence threshold.
            If True, the masked consensus response is returned: 1 where all prior samples pass
            the confidence threshold (complete agreement), 0 elsewhere.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        prior_samples: int = 16,
        consensus_masking: bool = False,
    ):
        # Fail early: with zero samples there is nothing to average and torch.stack would fail at call time.
        if prior_samples < 1:
            raise ValueError(f"prior_samples must be at least 1, got {prior_samples}.")
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples, "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        # Returns one float32 mask per entry along the first dimension of `pseudo_labels`
        # (the prior-sample dimension when called from __call__).
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [
            ((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
            for sample in pseudo_labels
        ]

    def _compute_label_mask_one_side(self, pseudo_labels):
        # Returns one float32 mask per entry along the first dimension of `pseudo_labels`.
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> tuple:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model. Must be a `torch_em.model.probabilistic_unet.ProbabilisticUNet`.
            input_: The input for this batch.

        Returns:
            A tuple `(pseudo_labels, label_mask)`. `pseudo_labels` is the mean over `prior_samples`
            draws of `teacher.sample()` (after the optional activation). `label_mask` has the same
            shape as the pseudo-labels, or is None if no confidence threshold is set. See
            `consensus_masking` for how the per-sample masks are combined.
        """
        # forward() is called first so that the subsequent sample() calls draw for this input.
        teacher.forward(input_)
        samples = []
        for _ in range(self.prior_samples):
            sample = teacher.sample()
            if self.activation is not None:
                sample = self.activation(sample)
            samples.append(sample)
        # Shape: (prior_samples, *sample_shape).
        samples = torch.stack(samples, dim=0)
        pseudo_labels = samples.sum(dim=0) / self.prior_samples

        if self.confidence_threshold is None:
            return pseudo_labels, None

        # Threshold every prior sample individually. These masks must come from the stacked samples,
        # not the averaged prediction; otherwise the iteration runs over the batch instead of the samples.
        if self.threshold_from_both_sides:
            sample_masks = self._compute_label_mask_both_sides(samples)
        else:
            sample_masks = self._compute_label_mask_one_side(samples)
        # Per-pixel fraction of prior samples that pass the threshold.
        label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
        if self.consensus_masking:
            # Keep only pixels on which all prior samples agree.
            label_mask = (label_mask == 1).to(dtype=torch.float32)
        return pseudo_labels, label_mask
Compute pseudo labels from a Probabilistic UNet.
Arguments:
- activation: Activation function applied to the teacher prediction.
- confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed.
- threshold_from_both_sides: Whether to include both values greater than or equal to the threshold and values smaller than or equal to 1 - the threshold, or only values greater than or equal to the threshold, in the mask. The former should be used for binary labels, the latter for multiclass labels.
- prior_samples: The number of times to sample from the model distribution per input.
- consensus_masking: Whether to activate consensus masking in the label filter. If False, the weighted consensus response is returned, i.e. for each pixel the fraction of prior samples that pass the confidence threshold. If True, the masked consensus response is returned, i.e. only pixels on which all prior samples agree (complete agreement) are kept.
ProbabilisticPseudoLabeler( activation: Optional[torch.nn.modules.module.Module] = None, confidence_threshold: Optional[float] = None, threshold_from_both_sides: bool = True, prior_samples: int = 16, consensus_masking: bool = False)
def __init__(
    self,
    activation: Optional[torch.nn.Module] = None,
    confidence_threshold: Optional[float] = None,
    threshold_from_both_sides: bool = True,
    prior_samples: int = 16,
    consensus_masking: bool = False,
):
    """Store the labeling configuration and the keyword arguments used to re-create the labeler.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for the label mask, or None for no mask.
        threshold_from_both_sides: Whether the mask keeps both tails (>= t and <= 1 - t) or only >= t.
        prior_samples: The number of times to sample from the model distribution per input. Must be at least 1.
        consensus_masking: Whether the mask keeps only pixels on which all prior samples agree.

    Raises:
        ValueError: If `prior_samples` is smaller than 1.
    """
    # Fail early: with zero samples there is nothing to average and torch.stack would fail at call time.
    if prior_samples < 1:
        raise ValueError(f"prior_samples must be at least 1, got {prior_samples}.")
    self.activation = activation
    self.confidence_threshold = confidence_threshold
    self.threshold_from_both_sides = threshold_from_both_sides
    self.prior_samples = prior_samples
    self.consensus_masking = consensus_masking
    # Record every constructor argument so the labeler can be re-created from init_kwargs.
    # TODO serialize the class names and kwargs for activation instead
    self.init_kwargs = {
        "activation": None, "confidence_threshold": confidence_threshold,
        "threshold_from_both_sides": threshold_from_both_sides,
        "prior_samples": prior_samples, "consensus_masking": consensus_masking,
    }