From c20d9d5ea415d1fd67551f9744f797f840e20221 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Mon, 12 Dec 2022 17:32:30 +0100 Subject: [PATCH] the heuristic exact_train_prev is performed via kFCV, using a new function qp.model_selection.cross_val_predict --- TODO.txt | 2 ++ quapy/method/aggregative.py | 11 ++++++++--- quapy/model_selection.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/TODO.txt b/TODO.txt index 90f3301..6cef78c 100644 --- a/TODO.txt +++ b/TODO.txt @@ -3,6 +3,8 @@ clean all the cumbersome methods that have to be implemented for new quantifiers make truly parallel the GridSearchQ make more examples in the "examples" directory merge with master, because I had to fix some problems with QuaNet due to an issue notified via GitHub! +added cross_val_predict in qp.model_selection (i.e., a cross_val_predict for quantification) --would be nice to have + it parallelized Packaging: diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 202b5dd..4cec2cd 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -3,7 +3,7 @@ from copy import deepcopy from typing import Callable, Union import numpy as np from joblib import Parallel, delayed -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, clone from sklearn.calibration import CalibratedClassifierCV from sklearn.metrics import confusion_matrix from sklearn.model_selection import StratifiedKFold, cross_val_predict @@ -503,7 +503,7 @@ class EMQ(AggregativeProbabilisticQuantifier): :param learner: a sklearn's Estimator that generates a classifier :param exact_train_prev: set to True (default) for using, as the initial observation, the true training prevalence; or set to False for computing the training prevalence as an estimate, akin to PCC, i.e., as the expected - value of the posterior probabilities of the trianing documents as suggested in + value of the posterior probabilities of the 
training instances as suggested in `Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_: """ @@ -519,7 +519,12 @@ class EMQ(AggregativeProbabilisticQuantifier): if self.exact_train_prev: self.train_prevalence = F.prevalence_from_labels(data.labels, self.classes_) else: - self.train_prevalence = PCC(learner=self.learner).fit(data, fit_learner=False).quantify(data.X) + self.train_prevalence = qp.model_selection.cross_val_predict( + quantifier=PCC(clone(self.learner)), + data=data, + nfolds=3, + random_state=0 + ) return self def aggregate(self, classif_posteriors, epsilon=EPSILON): diff --git a/quapy/model_selection.py b/quapy/model_selection.py index 41a7a19..f7c5b94 100644 --- a/quapy/model_selection.py +++ b/quapy/model_selection.py @@ -2,6 +2,10 @@ import itertools import signal from copy import deepcopy from typing import Union, Callable + +import numpy as np +from sklearn import clone + import quapy as qp from quapy import evaluation from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol @@ -187,3 +191,28 @@ class GridSearchQ(BaseQuantifier): raise ValueError('best_model called before fit') + + +def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0): + """ + Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_ + but for quantification. + + :param quantifier: a quantifier issuing class prevalence values + :param data: a labelled collection + :param nfolds: number of folds for k-fold cross validation generation + :param random_state: random seed for reproducibility + :return: a vector of class prevalence values + """ + + total_prev = np.zeros(shape=data.n_classes) + + for train, test in data.kFCV(nfolds=nfolds, random_state=random_state): + quantifier.fit(train) + fold_prev = quantifier.quantify(test.X) + rel_size = len(test.X)/len(data) + total_prev += fold_prev*rel_size + + return total_prev + +