torch_em.self_training.pseudo_labeling

  1import torch
  2
  3
  4class DefaultPseudoLabeler:
  5    """Compute pseudo labels.
  6
  7    Parameters:
  8        activation [nn.Module, callable] - activation function applied to the teacher prediction.
  9        confidence_threshold [float] - threshold for computing a mask for filterign the pseudo labels.
 10            If none is given no mask will be computed (default: None)
 11        threshold_from_both_sides [bool] - whether to include both values bigger than the threshold
 12            and smaller than 1 - it, or only values bigger than it in the mask.
 13            The former should be used for binary labels, the latter for for multiclass labels (default: False)
 14    """
 15    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True):
 16        self.activation = activation
 17        self.confidence_threshold = confidence_threshold
 18        self.threshold_from_both_sides = threshold_from_both_sides
 19        # TODO serialize the class names and kwargs for activation instead
 20        self.init_kwargs = {
 21            "activation": None, "confidence_threshold": confidence_threshold,
 22            "threshold_from_both_sides": threshold_from_both_sides
 23        }
 24
 25    def _compute_label_mask_both_sides(self, pseudo_labels):
 26        upper_threshold = self.confidence_threshold
 27        lower_threshold = 1.0 - self.confidence_threshold
 28        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
 29        return mask
 30
 31    def _compute_label_mask_one_side(self, pseudo_labels):
 32        mask = (pseudo_labels >= self.confidence_threshold)
 33        return mask
 34
 35    def __call__(self, teacher, input_):
 36        pseudo_labels = teacher(input_)
 37        if self.activation is not None:
 38            pseudo_labels = self.activation(pseudo_labels)
 39        if self.confidence_threshold is None:
 40            label_mask = None
 41        else:
 42            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\
 43                else self._compute_label_mask_one_side(pseudo_labels)
 44        return pseudo_labels, label_mask
 45
 46
 47class ProbabilisticPseudoLabeler:
 48    """Compute pseudo labels from the Probabilistic UNet.
 49
 50    Parameters:
 51        activation [nn.Module, callable] - activation function applied to the teacher prediction.
 52        confidence_threshold [float] - threshold for computing a mask for filterign the pseudo labels.
 53            If none is given no mask will be computed (default: None)
 54        threshold_from_both_sides [bool] - whether to include both values bigger than the threshold
 55            and smaller than 1 - it, or only values bigger than it in the mask.
 56            The former should be used for binary labels, the latter for for multiclass labels (default: False)
 57        prior_samples [int] - the number of times we want to sample from the
 58            prior distribution per inputs (default: 16)
 59        consensus_masking [bool] - whether to activate consensus masking in the label filter (default: False)
 60            If false, the weighted consensus response (weighted per-pixel response) is returned
 61            If true, the masked consensus response (complete aggrement of pixels) is returned
 62    """
 63    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True,
 64                 prior_samples=16, consensus_masking=False):
 65        self.activation = activation
 66        self.confidence_threshold = confidence_threshold
 67        self.threshold_from_both_sides = threshold_from_both_sides
 68        self.prior_samples = prior_samples
 69        self.consensus_masking = consensus_masking
 70        # TODO serialize the class names and kwargs for activation instead
 71        self.init_kwargs = {
 72            "activation": None, "confidence_threshold": confidence_threshold,
 73            "threshold_from_both_sides": threshold_from_both_sides
 74        }
 75
 76    def _compute_label_mask_both_sides(self, pseudo_labels):
 77        upper_threshold = self.confidence_threshold
 78        lower_threshold = 1.0 - self.confidence_threshold
 79        mask = [torch.where((sample >= upper_threshold) + (sample <= lower_threshold),
 80                            torch.tensor(1.),
 81                            torch.tensor(0.)) for sample in pseudo_labels]
 82        return mask
 83
 84    def _compute_label_mask_one_side(self, pseudo_labels):
 85        mask = [torch.where((sample >= self.confidence_threshold),
 86                            torch.tensor(1.),
 87                            torch.tensor(0.)) for sample in pseudo_labels]
 88        return mask
 89
 90    def __call__(self, teacher, input_):
 91        teacher.forward(input_)
 92        if self.activation is not None:
 93            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
 94        else:
 95            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
 96        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
 97
 98        if self.confidence_threshold is None:
 99            label_mask = None
100        else:
101            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
102                else self._compute_label_mask_one_side(pseudo_labels)
103            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
104            if self.consensus_masking:
105                label_mask = torch.where(label_mask == 1, 1, 0)
106
107        return pseudo_labels, label_mask
class DefaultPseudoLabeler:
 5class DefaultPseudoLabeler:
 6    """Compute pseudo labels.
 7
 8    Parameters:
 9        activation [nn.Module, callable] - activation function applied to the teacher prediction.
10        confidence_threshold [float] - threshold for computing a mask for filterign the pseudo labels.
11            If none is given no mask will be computed (default: None)
12        threshold_from_both_sides [bool] - whether to include both values bigger than the threshold
13            and smaller than 1 - it, or only values bigger than it in the mask.
14            The former should be used for binary labels, the latter for for multiclass labels (default: False)
15    """
16    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True):
17        self.activation = activation
18        self.confidence_threshold = confidence_threshold
19        self.threshold_from_both_sides = threshold_from_both_sides
20        # TODO serialize the class names and kwargs for activation instead
21        self.init_kwargs = {
22            "activation": None, "confidence_threshold": confidence_threshold,
23            "threshold_from_both_sides": threshold_from_both_sides
24        }
25
26    def _compute_label_mask_both_sides(self, pseudo_labels):
27        upper_threshold = self.confidence_threshold
28        lower_threshold = 1.0 - self.confidence_threshold
29        mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32)
30        return mask
31
32    def _compute_label_mask_one_side(self, pseudo_labels):
33        mask = (pseudo_labels >= self.confidence_threshold)
34        return mask
35
36    def __call__(self, teacher, input_):
37        pseudo_labels = teacher(input_)
38        if self.activation is not None:
39            pseudo_labels = self.activation(pseudo_labels)
40        if self.confidence_threshold is None:
41            label_mask = None
42        else:
43            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\
44                else self._compute_label_mask_one_side(pseudo_labels)
45        return pseudo_labels, label_mask

Compute pseudo labels.

Arguments:
  • activation [nn.Module, callable] - activation function applied to the teacher prediction.
  • confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed (default: None)
  • threshold_from_both_sides [bool] - whether to include both values bigger than the threshold and smaller than 1 - it, or only values bigger than it in the mask. The former should be used for binary labels, the latter for multiclass labels (default: True)
DefaultPseudoLabeler( activation=None, confidence_threshold=None, threshold_from_both_sides=True)
16    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True):
17        self.activation = activation
18        self.confidence_threshold = confidence_threshold
19        self.threshold_from_both_sides = threshold_from_both_sides
20        # TODO serialize the class names and kwargs for activation instead
21        self.init_kwargs = {
22            "activation": None, "confidence_threshold": confidence_threshold,
23            "threshold_from_both_sides": threshold_from_both_sides
24        }
activation
confidence_threshold
threshold_from_both_sides
init_kwargs
class ProbabilisticPseudoLabeler:
 48class ProbabilisticPseudoLabeler:
 49    """Compute pseudo labels from the Probabilistic UNet.
 50
 51    Parameters:
 52        activation [nn.Module, callable] - activation function applied to the teacher prediction.
 53        confidence_threshold [float] - threshold for computing a mask for filterign the pseudo labels.
 54            If none is given no mask will be computed (default: None)
 55        threshold_from_both_sides [bool] - whether to include both values bigger than the threshold
 56            and smaller than 1 - it, or only values bigger than it in the mask.
 57            The former should be used for binary labels, the latter for for multiclass labels (default: False)
 58        prior_samples [int] - the number of times we want to sample from the
 59            prior distribution per inputs (default: 16)
 60        consensus_masking [bool] - whether to activate consensus masking in the label filter (default: False)
 61            If false, the weighted consensus response (weighted per-pixel response) is returned
 62            If true, the masked consensus response (complete aggrement of pixels) is returned
 63    """
 64    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True,
 65                 prior_samples=16, consensus_masking=False):
 66        self.activation = activation
 67        self.confidence_threshold = confidence_threshold
 68        self.threshold_from_both_sides = threshold_from_both_sides
 69        self.prior_samples = prior_samples
 70        self.consensus_masking = consensus_masking
 71        # TODO serialize the class names and kwargs for activation instead
 72        self.init_kwargs = {
 73            "activation": None, "confidence_threshold": confidence_threshold,
 74            "threshold_from_both_sides": threshold_from_both_sides
 75        }
 76
 77    def _compute_label_mask_both_sides(self, pseudo_labels):
 78        upper_threshold = self.confidence_threshold
 79        lower_threshold = 1.0 - self.confidence_threshold
 80        mask = [torch.where((sample >= upper_threshold) + (sample <= lower_threshold),
 81                            torch.tensor(1.),
 82                            torch.tensor(0.)) for sample in pseudo_labels]
 83        return mask
 84
 85    def _compute_label_mask_one_side(self, pseudo_labels):
 86        mask = [torch.where((sample >= self.confidence_threshold),
 87                            torch.tensor(1.),
 88                            torch.tensor(0.)) for sample in pseudo_labels]
 89        return mask
 90
 91    def __call__(self, teacher, input_):
 92        teacher.forward(input_)
 93        if self.activation is not None:
 94            pseudo_labels = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
 95        else:
 96            pseudo_labels = [teacher.sample() for _ in range(self.prior_samples)]
 97        pseudo_labels = torch.stack(pseudo_labels, dim=0).sum(dim=0)/self.prior_samples
 98
 99        if self.confidence_threshold is None:
100            label_mask = None
101        else:
102            label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides \
103                else self._compute_label_mask_one_side(pseudo_labels)
104            label_mask = torch.stack(label_mask, dim=0).sum(dim=0)/self.prior_samples
105            if self.consensus_masking:
106                label_mask = torch.where(label_mask == 1, 1, 0)
107
108        return pseudo_labels, label_mask

Compute pseudo labels from the Probabilistic UNet.

Arguments:
  • activation [nn.Module, callable] - activation function applied to the teacher prediction.
  • confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels. If None is given, no mask will be computed (default: None)
  • threshold_from_both_sides [bool] - whether to include both values bigger than the threshold and smaller than 1 - it, or only values bigger than it in the mask. The former should be used for binary labels, the latter for multiclass labels (default: True)
  • prior_samples [int] - the number of times we want to sample from the prior distribution per input (default: 16)
  • consensus_masking [bool] - whether to activate consensus masking in the label filter (default: False). If false, the weighted consensus response (weighted per-pixel response) is returned. If true, the masked consensus response (complete agreement of pixels) is returned.
ProbabilisticPseudoLabeler( activation=None, confidence_threshold=None, threshold_from_both_sides=True, prior_samples=16, consensus_masking=False)
64    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True,
65                 prior_samples=16, consensus_masking=False):
66        self.activation = activation
67        self.confidence_threshold = confidence_threshold
68        self.threshold_from_both_sides = threshold_from_both_sides
69        self.prior_samples = prior_samples
70        self.consensus_masking = consensus_masking
71        # TODO serialize the class names and kwargs for activation instead
72        self.init_kwargs = {
73            "activation": None, "confidence_threshold": confidence_threshold,
74            "threshold_from_both_sides": threshold_from_both_sides
75        }
activation
confidence_threshold
threshold_from_both_sides
prior_samples
consensus_masking
init_kwargs