From d0444d3bbb24005beadc3dbd20b84359f9c2e54b Mon Sep 17 00:00:00 2001
From: Alejandro Moreo
Date: Fri, 8 Mar 2024 17:04:10 +0100
Subject: [PATCH] added artificial accuracy protocol

---
 ClassifierAccuracy/util/commons.py | 41 +++++++++++++++++++++++++++---
 1 file changed, 38 insertions(+), 3 deletions(-)

diff --git a/ClassifierAccuracy/util/commons.py b/ClassifierAccuracy/util/commons.py
index 67df860..ac01374 100644
--- a/ClassifierAccuracy/util/commons.py
+++ b/ClassifierAccuracy/util/commons.py
@@ -14,6 +14,7 @@ from sklearn.model_selection import GridSearchCV
 
 from ClassifierAccuracy.models_multiclass import *
 from ClassifierAccuracy.util.tabular import Table
+from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol
 from quapy.method.aggregative import EMQ, ACC, KDEyML
 from quapy.data import LabelledCollection
 
@@ -101,11 +102,11 @@ def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledColle
 
 def gen_CAP(h, acc_fn, with_oracle=False)->[str, ClassifierAccuracyPrediction]:
     #yield 'SebCAP', SebastianiCAP(h, acc_fn, ACC)
-    yield 'SebCAP-SLD', SebastianiCAP(h, acc_fn, EMQ, predict_train_prev=not with_oracle)
+    # yield 'SebCAP-SLD', SebastianiCAP(h, acc_fn, EMQ, predict_train_prev=not with_oracle)
     #yield 'SebCAP-KDE', SebastianiCAP(h, acc_fn, KDEyML)
     #yield 'SebCAPweight', SebastianiCAP(h, acc_fn, ACC, alpha=0)
     #yield 'PabCAP', PabloCAP(h, acc_fn, ACC)
-    yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
+    # yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
     yield 'ATC-MC', ATC(h, acc_fn, scoring_fn='maxconf')
     # yield 'ATC-NE', ATC(h, acc_fn, scoring_fn='neg_entropy')
     yield 'DoC', DoC(h, acc_fn, sample_size=qp.environ['SAMPLE_SIZE'])
@@ -116,7 +117,7 @@ def gen_CAP_cont_table(h)->[str,CAPContingencyTable]:
     yield 'Naive', NaiveCAP(h, acc_fn)
     yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
     # yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
-    yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
+    # yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
     #yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
     #yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
     #yield 'QuAcc(EMQ)nxn-MC', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxconf=True)
@@ -372,6 +373,7 @@ def gen_tables(basedir, datasets):
 
             tex = table.latexTabular()
             table_name = f'{basedir}_{classifier}_{metric}.tex'
+            table_name = table_name.replace('/', '_')
             with open(f'./tables/{table_name}', 'wt') as foo:
                 foo.write('\\begin{table}[h]\n')
                 foo.write('\\centering\n')
@@ -398,3 +400,36 @@ def gen_tables(basedir, datasets):
     os.system('rm main.aux main.log')
 
 
+class ArtificialAccuracyProtocol(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
+
+    def __init__(self, data: LabelledCollection, h: BaseEstimator, sample_size=None, n_prevalences=101, repeats=10, random_state=0):
+        super(ArtificialAccuracyProtocol, self).__init__(random_state)
+        self.data = data
+        self.h = h
+        self.sample_size = qp._get_sample_size(sample_size)
+        self.n_prevalences = n_prevalences
+        self.repeats = repeats
+        self.collator = OnLabelledCollectionProtocol.get_collator('labelled_collection')
+
+    def accuracy_grid(self):
+        grid = np.linspace(0, 1, self.n_prevalences)
+        grid = np.repeat(grid, self.repeats, axis=0)
+        return grid
+
+    def samples_parameters(self):
+        # issue predictions
+        label_predictions = self.h.predict(self.data.X)
+        correct = label_predictions == self.data.y
+        self.data_evaluated = LabelledCollection(self.data.X, labels=correct, classes=[0,1])
+        indexes = []
+        for acc_value in self.accuracy_grid():
+            index = self.data_evaluated.sampling_index(self.sample_size, acc_value)
+            indexes.append(index)
+        return indexes
+
+    def sample(self, index):
+        return self.data.sampling_from_index(index)
+
+    def total(self):
+        return self.n_prevalences * self.repeats
+
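
Usage note (not part of the patch): the sketch below shows one way the new ArtificialAccuracyProtocol could be driven once the patch is applied. It relies only on names visible in the diff (the class itself, LabelledCollection, the total() method) plus the usual quapy idiom of iterating a protocol by calling it; the synthetic dataset built with sklearn's make_classification, the classifier, and the settings sample_size=100, n_prevalences=21, repeats=5 are illustrative assumptions, not part of the commit.

    # illustrative driver for the protocol added in this patch
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from quapy.data import LabelledCollection
    from ClassifierAccuracy.util.commons import ArtificialAccuracyProtocol

    # synthetic binary data stands in for any LabelledCollection
    X, y = make_classification(n_samples=5000, n_features=20, random_state=0)
    train, test = LabelledCollection(X, y).split_stratified(train_prop=0.6)

    # classifier whose accuracy the protocol spreads across the [0, 1] grid
    h = LogisticRegression().fit(train.X, train.y)

    # 21 grid values x 5 repetitions = 105 samples of 100 instances each
    prot = ArtificialAccuracyProtocol(test, h, sample_size=100, n_prevalences=21, repeats=5, random_state=0)
    accuracies = [np.mean(h.predict(U.X) == U.y) for U in prot()]
    print(f'{prot.total()} samples generated; empirical accuracy spans '
          f'[{min(accuracies):.3f}, {max(accuracies):.3f}]')

Because sample() draws the selected indexes from the original collection rather than from the correct/incorrect relabelling built in samples_parameters(), every yielded sample keeps its true class labels, which is what allows the accuracy to be recomputed directly in the loop above.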