torch_em.self_training.pseudo_labeling
import torch


class DefaultPseudoLabeler:
    """Compute pseudo labels.

    Parameters:
        activation [nn.Module, callable] - activation function applied to the teacher prediction.
        confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed (default: None)
        threshold_from_both_sides [bool] - whether to include in the mask both values greater than or
            equal to the threshold and values less than or equal to 1 - threshold, or only values
            greater than or equal to the threshold. The former should be used for binary labels,
            the latter for multiclass labels (default: True)
    """
    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where the prediction is confidently high or confidently low."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return mask.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a float32 mask that is 1 where the prediction reaches the confidence threshold."""
        mask = pseudo_labels >= self.confidence_threshold
        # Cast like the two-sided variant so both modes yield the same mask dtype.
        return mask.to(dtype=torch.float32)

    def __call__(self, teacher, input_):
        """Predict with the teacher and derive pseudo labels plus an optional confidence mask.

        Parameters:
            teacher [callable] - model that is called on `input_`.
            input_ - the input passed to the teacher.

        Returns:
            tuple: (pseudo_labels, label_mask); label_mask is None if no confidence_threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            return pseudo_labels, None
        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask


class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from the Probabilistic UNet.

    Parameters:
        activation [nn.Module, callable] - activation function applied to the teacher prediction.
        confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed (default: None)
        threshold_from_both_sides [bool] - whether to include in the mask both values greater than or
            equal to the threshold and values less than or equal to 1 - threshold, or only values
            greater than or equal to the threshold. The former should be used for binary labels,
            the latter for multiclass labels (default: True)
        prior_samples [int] - the number of samples drawn from the prior distribution per input (default: 16)
        consensus_masking [bool] - whether to activate consensus masking in the label filter (default: False).
            If False, the weighted consensus response is returned: per pixel, the fraction of prior
            samples whose prediction passes the confidence threshold.
            If True, the masked consensus response is returned: 1 where all prior samples pass
            the threshold, 0 elsewhere.
    """
    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True,
                 prior_samples=16, consensus_masking=False):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        # All other constructor arguments are recorded so that rebuilding the labeler
        # from init_kwargs reproduces the same configuration.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples, "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return one float32 two-sided confidence mask per element of `pseudo_labels`."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
                for sample in pseudo_labels]

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return one float32 one-sided confidence mask per element of `pseudo_labels`."""
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(self, teacher, input_):
        """Sample predictions from the teacher; return the mean prediction and an optional mask.

        Parameters:
            teacher - model providing `forward(input_)` (called once) and `sample()`
                (called `prior_samples` times).
            input_ - the input passed to `teacher.forward`.

        Returns:
            tuple: (pseudo_labels, label_mask); label_mask is None if no confidence_threshold is set.
        """
        teacher.forward(input_)
        if self.activation is not None:
            samples = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
        else:
            samples = [teacher.sample() for _ in range(self.prior_samples)]
        # The pseudo labels are the mean over all prior samples.
        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples

        if self.confidence_threshold is None:
            return pseudo_labels, None

        # Compute the mask on every prior sample (not on the averaged prediction) and average the
        # masks, so each pixel holds the fraction of samples that are confident there.
        if self.threshold_from_both_sides:
            sample_masks = self._compute_label_mask_both_sides(samples)
        else:
            sample_masks = self._compute_label_mask_one_side(samples)
        label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
        if self.consensus_masking:
            # Keep only pixels on which every prior sample is confident.
            label_mask = (label_mask == 1).to(dtype=torch.float32)
        return pseudo_labels, label_mask
class
DefaultPseudoLabeler:
class DefaultPseudoLabeler:
    """Compute pseudo labels.

    Parameters:
        activation [nn.Module, callable] - activation function applied to the teacher prediction.
        confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed (default: None)
        threshold_from_both_sides [bool] - whether to include in the mask both values greater than or
            equal to the threshold and values less than or equal to 1 - threshold, or only values
            greater than or equal to the threshold. The former should be used for binary labels,
            the latter for multiclass labels (default: True)
    """
    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        # TODO serialize the class names and kwargs for activation instead
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return a float32 mask that is 1 where the prediction is confidently high or confidently low."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        mask = (pseudo_labels >= upper_threshold) | (pseudo_labels <= lower_threshold)
        return mask.to(dtype=torch.float32)

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return a float32 mask that is 1 where the prediction reaches the confidence threshold."""
        mask = pseudo_labels >= self.confidence_threshold
        # Cast like the two-sided variant so both modes yield the same mask dtype.
        return mask.to(dtype=torch.float32)

    def __call__(self, teacher, input_):
        """Predict with the teacher and derive pseudo labels plus an optional confidence mask.

        Parameters:
            teacher [callable] - model that is called on `input_`.
            input_ - the input passed to the teacher.

        Returns:
            tuple: (pseudo_labels, label_mask); label_mask is None if no confidence_threshold is set.
        """
        pseudo_labels = teacher(input_)
        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)
        if self.confidence_threshold is None:
            return pseudo_labels, None
        if self.threshold_from_both_sides:
            label_mask = self._compute_label_mask_both_sides(pseudo_labels)
        else:
            label_mask = self._compute_label_mask_one_side(pseudo_labels)
        return pseudo_labels, label_mask
Compute pseudo labels.
Arguments:
- activation [nn.Module, callable] - activation function applied to the teacher prediction.
- confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed (default: None)
- threshold_from_both_sides [bool] - whether to include in the mask both values greater than or equal to the threshold and values less than or equal to 1 - threshold, or only values greater than or equal to the threshold. The former should be used for binary labels, the latter for multiclass labels (default: True)
DefaultPseudoLabeler( activation=None, confidence_threshold=None, threshold_from_both_sides=True)
def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True):
    """Store the pseudo-labeling configuration of DefaultPseudoLabeler.

    Parameters:
        activation [nn.Module, callable] - applied to the teacher prediction; kept on the instance
            but recorded as None in `init_kwargs`.
        confidence_threshold [float] - threshold for the confidence mask, or None for no mask.
        threshold_from_both_sides [bool] - whether the mask is computed from both sides of the threshold.
    """
    self.activation = activation
    self.confidence_threshold = confidence_threshold
    self.threshold_from_both_sides = threshold_from_both_sides
    # TODO serialize the class names and kwargs for activation instead
    self.init_kwargs = dict(
        activation=None,
        confidence_threshold=confidence_threshold,
        threshold_from_both_sides=threshold_from_both_sides,
    )
class
ProbabilisticPseudoLabeler:
class ProbabilisticPseudoLabeler:
    """Compute pseudo labels from the Probabilistic UNet.

    Parameters:
        activation [nn.Module, callable] - activation function applied to the teacher prediction.
        confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed (default: None)
        threshold_from_both_sides [bool] - whether to include in the mask both values greater than or
            equal to the threshold and values less than or equal to 1 - threshold, or only values
            greater than or equal to the threshold. The former should be used for binary labels,
            the latter for multiclass labels (default: True)
        prior_samples [int] - the number of samples drawn from the prior distribution per input (default: 16)
        consensus_masking [bool] - whether to activate consensus masking in the label filter (default: False).
            If False, the weighted consensus response is returned: per pixel, the fraction of prior
            samples whose prediction passes the confidence threshold.
            If True, the masked consensus response is returned: 1 where all prior samples pass
            the threshold, 0 elsewhere.
    """
    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True,
                 prior_samples=16, consensus_masking=False):
        self.activation = activation
        self.confidence_threshold = confidence_threshold
        self.threshold_from_both_sides = threshold_from_both_sides
        self.prior_samples = prior_samples
        self.consensus_masking = consensus_masking
        # TODO serialize the class names and kwargs for activation instead
        # All other constructor arguments are recorded so that rebuilding the labeler
        # from init_kwargs reproduces the same configuration.
        self.init_kwargs = {
            "activation": None, "confidence_threshold": confidence_threshold,
            "threshold_from_both_sides": threshold_from_both_sides,
            "prior_samples": prior_samples, "consensus_masking": consensus_masking,
        }

    def _compute_label_mask_both_sides(self, pseudo_labels):
        """Return one float32 two-sided confidence mask per element of `pseudo_labels`."""
        upper_threshold = self.confidence_threshold
        lower_threshold = 1.0 - self.confidence_threshold
        return [((sample >= upper_threshold) | (sample <= lower_threshold)).to(dtype=torch.float32)
                for sample in pseudo_labels]

    def _compute_label_mask_one_side(self, pseudo_labels):
        """Return one float32 one-sided confidence mask per element of `pseudo_labels`."""
        return [(sample >= self.confidence_threshold).to(dtype=torch.float32) for sample in pseudo_labels]

    def __call__(self, teacher, input_):
        """Sample predictions from the teacher; return the mean prediction and an optional mask.

        Parameters:
            teacher - model providing `forward(input_)` (called once) and `sample()`
                (called `prior_samples` times).
            input_ - the input passed to `teacher.forward`.

        Returns:
            tuple: (pseudo_labels, label_mask); label_mask is None if no confidence_threshold is set.
        """
        teacher.forward(input_)
        if self.activation is not None:
            samples = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
        else:
            samples = [teacher.sample() for _ in range(self.prior_samples)]
        # The pseudo labels are the mean over all prior samples.
        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples

        if self.confidence_threshold is None:
            return pseudo_labels, None

        # Compute the mask on every prior sample (not on the averaged prediction) and average the
        # masks, so each pixel holds the fraction of samples that are confident there.
        if self.threshold_from_both_sides:
            sample_masks = self._compute_label_mask_both_sides(samples)
        else:
            sample_masks = self._compute_label_mask_one_side(samples)
        label_mask = torch.stack(sample_masks, dim=0).sum(dim=0) / self.prior_samples
        if self.consensus_masking:
            # Keep only pixels on which every prior sample is confident.
            label_mask = (label_mask == 1).to(dtype=torch.float32)
        return pseudo_labels, label_mask
Compute pseudo labels from the Probabilistic UNet.
Arguments:
- activation [nn.Module, callable] - activation function applied to the teacher prediction.
- confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed (default: None)
- threshold_from_both_sides [bool] - whether to include in the mask both values greater than or equal to the threshold and values less than or equal to 1 - threshold, or only values greater than or equal to the threshold. The former should be used for binary labels, the latter for multiclass labels (default: True)
- prior_samples [int] - the number of samples drawn from the prior distribution per input (default: 16)
- consensus_masking [bool] - whether to activate consensus masking in the label filter (default: False). If False, the weighted consensus response (weighted per-pixel response) is returned. If True, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation=None, confidence_threshold=None, threshold_from_both_sides=True, prior_samples=16, consensus_masking=False)
def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True,
             prior_samples=16, consensus_masking=False):
    """Store the pseudo-labeling configuration of ProbabilisticPseudoLabeler.

    Parameters:
        activation [nn.Module, callable] - applied to each teacher sample; kept on the instance
            but recorded as None in `init_kwargs`.
        confidence_threshold [float] - threshold for the confidence mask, or None for no mask.
        threshold_from_both_sides [bool] - whether the mask is computed from both sides of the threshold.
        prior_samples [int] - number of samples drawn from the prior distribution per input.
        consensus_masking [bool] - whether to return the masked instead of the weighted consensus response.
    """
    self.activation = activation
    self.confidence_threshold = confidence_threshold
    self.threshold_from_both_sides = threshold_from_both_sides
    self.prior_samples = prior_samples
    self.consensus_masking = consensus_masking
    # TODO serialize the class names and kwargs for activation instead
    # All other constructor arguments are recorded so that rebuilding the labeler
    # from init_kwargs reproduces the same configuration.
    self.init_kwargs = {
        "activation": None, "confidence_threshold": confidence_threshold,
        "threshold_from_both_sides": threshold_from_both_sides,
        "prior_samples": prior_samples, "consensus_masking": consensus_masking,
    }