refactoring w/o labelled collection
This commit is contained in:
parent
075be93a23
commit
5738821d10
|
|
@ -729,6 +729,11 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
posteriors = self.calibration_function(posteriors)
|
posteriors = self.calibration_function(posteriors)
|
||||||
return posteriors
|
return posteriors
|
||||||
|
|
||||||
|
def classifier_fit_predict(self, X, y):
|
||||||
|
classif_predictions = super().classifier_fit_predict(X, y)
|
||||||
|
self.train_prevalence = F.prevalence_from_labels(y, classes=self.classes_)
|
||||||
|
return classif_predictions
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions):
|
def aggregation_fit(self, classif_predictions):
|
||||||
"""
|
"""
|
||||||
Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
|
Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
|
||||||
|
|
@ -756,9 +761,7 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
y = np.searchsorted(self.classes_, y)
|
y = np.searchsorted(self.classes_, y)
|
||||||
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
|
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
|
||||||
|
|
||||||
if self.exact_train_prev:
|
if not self.exact_train_prev:
|
||||||
self.train_prevalence = F.prevalence_from_labels(y, self.classes_)
|
|
||||||
else:
|
|
||||||
train_posteriors = classif_predictions.X
|
train_posteriors = classif_predictions.X
|
||||||
if self.recalib is not None:
|
if self.recalib is not None:
|
||||||
train_posteriors = self.calibration_function(train_posteriors)
|
train_posteriors = self.calibration_function(train_posteriors)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue