diff --git a/ClassifierAccuracy/experiments.py b/ClassifierAccuracy/experiments.py
index 73ad1b1..6af1f16 100644
--- a/ClassifierAccuracy/experiments.py
+++ b/ClassifierAccuracy/experiments.py
@@ -1,7 +1,10 @@
-from commons import *
+from ClassifierAccuracy.util.commons import *
+from ClassifierAccuracy.util.plotting import plot_diagonal
 
 PROBLEM = 'multiclass'
-basedir = PROBLEM
+ORACLE = False
+basedir = PROBLEM+('-oracle' if ORACLE else '')
+
 
 if PROBLEM == 'binary':
     qp.environ['SAMPLE_SIZE'] = 1000
@@ -31,15 +34,15 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
     # instances of ClassifierAccuracyPrediction are bound to the evaluation measure, so they
     # must be nested in the acc-for
     for acc_name, acc_fn in gen_acc_measure():
-        for (method_name, method) in gen_CAP(h, acc_fn):
+        for (method_name, method) in gen_CAP(h, acc_fn, with_oracle=ORACLE):
             result_path = getpath(basedir, cls_name, acc_name, dataset_name, method_name)
             if os.path.exists(result_path):
                 print(f'\t{method_name}-{acc_name} exists, skipping')
                 continue
 
-            print(f'\t{method_name}-{acc_name} computing...')
+            print(f'\t{method_name} computing...')
             method, t_train = fit_method(method, V)
-            estim_accs, t_test_ave = predictionsCAP(method, test_prot)
+            estim_accs, t_test_ave = predictionsCAP(method, test_prot, ORACLE)
             save_json_result(result_path, true_accs[acc_name], estim_accs, t_train, t_test_ave)
 
     # instances of CAPContingencyTable instead are generic, and the evaluation measure can
@@ -52,7 +55,7 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
 
         print(f'\tmethod {method_name} computing...')
         method, t_train = fit_method(method, V)
-        estim_accs_dict, t_test_ave = predictionsCAPcont_table(method, test_prot, gen_acc_measure)
+        estim_accs_dict, t_test_ave = predictionsCAPcont_table(method, test_prot, gen_acc_measure, ORACLE)
         for acc_name in estim_accs_dict.keys():
             result_path = getpath(basedir, cls_name, acc_name, dataset_name, method_name)
             save_json_result(result_path, true_accs[acc_name], estim_accs_dict[acc_name], t_train, t_test_ave)
@@ -63,11 +66,9 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
 print('generating plots')
 for (cls_name, _), (acc_name, _) in itertools.product(gen_classifiers(), gen_acc_measure()):
     methods = get_method_names()
-    results = open_results(basedir, cls_name, acc_name, method_name=methods)
-    plot_diagonal(cls_name, acc_name, results, base_dir=f'plots/{basedir}/all')
+    plot_diagonal(basedir, cls_name, acc_name)
     for dataset_name, _ in gen_datasets(only_names=True):
-        results = open_results(basedir, cls_name, acc_name, dataset_name=dataset_name, method_name=methods)
-        plot_diagonal(cls_name, acc_name, results, base_dir=f'plots/{basedir}/{dataset_name}')
+        plot_diagonal(basedir, cls_name, acc_name, dataset_name=dataset_name)
 
 print('generating tables')
 gen_tables(basedir, datasets=[d for d,_ in gen_datasets(only_names=True)])
diff --git a/ClassifierAccuracy/gen_tables.py b/ClassifierAccuracy/gen_tables.py
index e306e74..fc8d389 100644
--- a/ClassifierAccuracy/gen_tables.py
+++ b/ClassifierAccuracy/gen_tables.py
@@ -1,3 +1,3 @@
-from commons import gen_tables
+from ClassifierAccuracy.util.commons import gen_tables
 
 gen_tables()
\ No newline at end of file
diff --git a/ClassifierAccuracy/models_multiclass.py b/ClassifierAccuracy/models_multiclass.py
index 23b9a7c..cdc256b 100644
--- a/ClassifierAccuracy/models_multiclass.py
+++ b/ClassifierAccuracy/models_multiclass.py
@@ -2,7 +2,7 @@ from copy import deepcopy
 import numpy as np
 
 from sklearn.base import BaseEstimator
-from sklearn.linear_model import LogisticRegression
+from sklearn.linear_model import LogisticRegression, LinearRegression
 
 import quapy as qp
 from sklearn import clone
@@ -29,11 +29,15 @@ class ClassifierAccuracyPrediction(ABC):
     def fit(self, val: LabelledCollection):
         ...
 
-    def predict(self, X):
+    @abstractmethod
+    def predict(self, X, oracle_prev=None):
         """
         Evaluates the accuracy function on the predicted contingency table
 
         :param X: test data
+        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
+            an oracle. This is meant to test the effect of the errors in CAP that are explained by
+            the errors in quantification performance
         :return: float
         """
         return ...
@@ -51,28 +55,30 @@ class CAPContingencyTable(ClassifierAccuracyPrediction):
         self.h = h
         self.acc = acc
 
-    @abstractmethod
-    def fit(self, val: LabelledCollection):
-        ...
-
-    def predict(self, X):
+    def predict(self, X, oracle_prev=None):
         """
         Evaluates the accuracy function on the predicted contingency table
 
         :param X: test data
+        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
+            an oracle. This is meant to test the effect of the errors in CAP that are explained by
+            the errors in quantification performance
         :return: float
         """
-        cont_table = self.predict_ct(X)
+        cont_table = self.predict_ct(X, oracle_prev)
         raw_acc = self.acc(cont_table)
         norm_acc = np.clip(raw_acc, 0, 1)
         return norm_acc
 
     @abstractmethod
-    def predict_ct(self, X):
+    def predict_ct(self, X, oracle_prev=None):
         """
         Predicts the contingency table for the test data
 
         :param X: test data
+        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
+            an oracle. This is meant to test the effect of the errors in CAP that are explained by
+            the errors in quantification performance
         :return: a contingency table
         """
         ...
@@ -92,13 +98,14 @@ class NaiveCAP(CAPContingencyTable):
         self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_)
         return self
 
-    def predict_ct(self, test):
+    def predict_ct(self, test, oracle_prev=None):
         """
         This method disregards the test set, under the assumption that it is IID wrt the training. This meaning that
         the confusion matrix for the test data should coincide with the one computed for training (using any cross
         validation strategy).
 
         :param test: test collection (ignored)
+        :param oracle_prev: ignored
         :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
         """
         return self.cont_table
@@ -133,17 +140,23 @@ class ContTableTransferCAP(CAPContingencyTableQ):
 
     def fit(self, val: LabelledCollection):
         y_hat = self.h.predict(val.X)
         y_true = val.y
-        self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_)
+        self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_, normalize='all')
         self.train_prev = val.prevalence()
         self.quantifier_fit(val)
         return self
 
-    def predict_ct(self, test):
+    def predict_ct(self, test, oracle_prev=None):
         """
         :param test: test collection (ignored)
+        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
+            an oracle. This is meant to test the effect of the errors in CAP that are explained by
+            the errors in quantification performance
         :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
         """
-        prev_hat = self.q.quantify(test)
+        if oracle_prev is None:
+            prev_hat = self.q.quantify(test)
+        else:
+            prev_hat = oracle_prev
         adjustment = prev_hat / self.train_prev
         return self.cont_table * adjustment[:, np.newaxis]
@@ -212,9 +225,12 @@ class NsquaredEquationsCAP(CAPContingencyTableQ):
 
         return A, b
 
-    def predict_ct(self, test):
+    def predict_ct(self, test, oracle_prev=None):
         """
         :param test: test collection (ignored)
+        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
+            an oracle. This is meant to test the effect of the errors in CAP that are explained by
+            the errors in quantification performance
         :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
         """
 
@@ -222,7 +238,10 @@ class NsquaredEquationsCAP(CAPContingencyTableQ):
 
         h_label_preds = self.h.predict(test)
         cc_prev_estim = F.prevalence_from_labels(h_label_preds, self.h.classes_)
-        q_prev_estim = self.q.quantify(test)
+        if oracle_prev is None:
+            q_prev_estim = self.q.quantify(test)
+        else:
+            q_prev_estim = oracle_prev
 
         A = self.A
         b = self.partial_b
@@ -255,13 +274,14 @@ class NsquaredEquationsCAP(CAPContingencyTableQ):
 
 class SebastianiCAP(ClassifierAccuracyPrediction):
 
-    def __init__(self, h, acc_fn, q_class, n_val_samples=500, alpha=0.3):
+    def __init__(self, h, acc_fn, q_class, n_val_samples=500, alpha=0.3, predict_train_prev=True):
         self.h = h
         self.acc = acc_fn
         self.q = q_class(h)
         self.n_val_samples = n_val_samples
         self.alpha = alpha
         self.sample_size = qp.environ['SAMPLE_SIZE']
+        self.predict_train_prev = predict_train_prev
 
     def fit(self, val: LabelledCollection):
         v2, v1 = val.split_stratified(train_prop=0.5)
@@ -272,11 +292,17 @@ class SebastianiCAP(ClassifierAccuracyPrediction):
         self.sigma_acc = [self.true_acc(sigma_i) for sigma_i in gen_samples()]
 
         # precompute prevalence predictions on samples
-        gen_samples.on_preclassified_instances(self.q.classify(v2.X), in_place=True)
-        self.sigma_pred_prevs = [self.q.aggregate(sigma_i.X) for sigma_i in gen_samples()]
+        if self.predict_train_prev:
+            gen_samples.on_preclassified_instances(self.q.classify(v2.X), in_place=True)
+            self.sigma_pred_prevs = [self.q.aggregate(sigma_i.X) for sigma_i in gen_samples()]
+        else:
+            self.sigma_pred_prevs = [sigma_i.prevalence() for sigma_i in gen_samples()]
 
-    def predict(self, X):
-        test_pred_prev = self.q.quantify(X)
+    def predict(self, X, oracle_prev=None):
+        if oracle_prev is None:
+            test_pred_prev = self.q.quantify(X)
+        else:
+            test_pred_prev = oracle_prev
 
         if self.alpha > 0:
             # select samples from V2 with predicted prevalence close to the predicted prevalence for U
@@ -316,8 +342,11 @@ class PabloCAP(ClassifierAccuracyPrediction):
         label_predictions = self.h.predict(val.X)
         self.pre_classified = LabelledCollection(instances=label_predictions, labels=val.labels)
 
-    def predict(self, X):
-        pred_prev = F.smooth(self.q.quantify(X))
+    def predict(self, X, oracle_prev=None):
+        if oracle_prev is None:
+            pred_prev = F.smooth(self.q.quantify(X))
+        else:
+            pred_prev = oracle_prev
         X_size = X.shape[0]
         acc_estim = []
         for _ in range(self.n_val_samples):
@@ -334,25 +363,83 @@ class PabloCAP(ClassifierAccuracyPrediction):
             raise ValueError('unknown aggregation function')
 
 
+def get_posteriors_from_h(h, X):
+    if hasattr(h, 'predict_proba'):
+        P = h.predict_proba(X)
+    else:
+        n_classes = len(h.classes_)
+        dec_scores = h.decision_function(X)
+        if n_classes == 2:
+            dec_scores = np.vstack([-dec_scores, dec_scores]).T
+        P = scipy.special.softmax(dec_scores, axis=1)
+    return P
+
+
+def max_conf(P, keepdims=False):
+    mc = P.max(axis=1, keepdims=keepdims)
+    return mc
+
+
+def neg_entropy(P, keepdims=False):
+    ne = -scipy.stats.entropy(P, axis=1)
+    if keepdims:
+        ne = ne.reshape(-1, 1)
+    return ne
+
+
 class QuAcc:
+
     def _get_X_dot(self, X):
         h = self.h
-        if hasattr(h, 'predict_proba'):
-            P = h.predict_proba(X)[:, 1:]
-        else:
-            n_classes = len(h.classes_)
-            P = h.decision_function(X).reshape(-1, n_classes)
-        X_dot = safehstack(X, P)
+        P = get_posteriors_from_h(h, X)
+
+        add_covs = []
+
+        if self.add_posteriors:
+            add_covs.append(P[:, 1:])
+
+        if self.add_maxconf:
+            mc = max_conf(P, keepdims=True)
+            add_covs.append(mc)
+
+        if self.add_negentropy:
+            ne = neg_entropy(P, keepdims=True)
+            add_covs.append(ne)
+
+        if self.add_maxinfsoft:
+            lgP = np.log(P)
+            mis = np.max(lgP - lgP.mean(axis=1, keepdims=True), axis=1, keepdims=True)
+            add_covs.append(mis)
+
+        if len(add_covs)>0:
+            X_dot = np.hstack(add_covs)
+
+        if self.add_X:
+            X_dot = safehstack(X, X_dot)
+
         return X_dot
 
 
 class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
 
-    def __init__(self, h: BaseEstimator, acc: callable, q_class: AggregativeQuantifier):
+    def __init__(self,
+                 h: BaseEstimator,
+                 acc: callable,
+                 q_class: AggregativeQuantifier,
+                 add_X=True,
+                 add_posteriors=True,
+                 add_maxconf=False,
+                 add_negentropy=False,
+                 add_maxinfsoft=False):
         self.h = h
         self.acc = acc
         self.q = EmptySaveQuantifier(q_class)
+        self.add_X = add_X
+        self.add_posteriors = add_posteriors
+        self.add_maxconf = add_maxconf
+        self.add_negentropy = add_negentropy
+        self.add_maxinfsoft = add_maxinfsoft
 
     def fit(self, val: LabelledCollection):
         pred_labels = self.h.predict(val.X)
@@ -367,17 +454,30 @@ class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
         val_dot = LabelledCollection(X_dot, y_dot, classes=classes_dot)
         self.q.fit(val_dot)
 
-    def predict_ct(self, X):
+    def predict_ct(self, X, oracle_prev=None):
         X_dot = self._get_X_dot(X)
         return self.q.quantify(X_dot)
 
 
 class QuAccNxN(CAPContingencyTableQ, QuAcc):
 
-    def __init__(self, h: BaseEstimator, acc: callable, q_class: AggregativeQuantifier):
+    def __init__(self,
+                 h: BaseEstimator,
+                 acc: callable,
+                 q_class: AggregativeQuantifier,
+                 add_X=True,
+                 add_posteriors=True,
+                 add_maxconf=False,
+                 add_negentropy=False,
+                 add_maxinfsoft=False):
         self.h = h
         self.acc = acc
         self.q_class = q_class
+        self.add_X = add_X
+        self.add_posteriors = add_posteriors
+        self.add_maxconf = add_maxconf
+        self.add_negentropy = add_negentropy
+        self.add_maxinfsoft = add_maxinfsoft
 
     def fit(self, val: LabelledCollection):
         pred_labels = self.h.predict(val.X)
@@ -394,7 +494,7 @@ class QuAccNxN(CAPContingencyTableQ, QuAcc):
             q_i.fit(data_i)
             self.q.append(q_i)
 
-    def predict_ct(self, X):
+    def predict_ct(self, X, oracle_prev=None):
         classes = self.h.classes_
         pred_labels = self.h.predict(X)
         X_dot = self._get_X_dot(X)
@@ -449,3 +549,194 @@ class EmptySaveQuantifier(BaseQuantifier):
 
     def num_non_empty_classes(self):
         return len(self.old_class_idx)
+
+
+# Baselines:
+class ATC(ClassifierAccuracyPrediction):
+
+    VALID_FUNCTIONS = {'maxconf', 'neg_entropy'}
+
+    def __init__(self, h, acc_fn, scoring_fn='maxconf'):
+        assert scoring_fn in ATC.VALID_FUNCTIONS, \
+            f'unknown scoring function, use any from {ATC.VALID_FUNCTIONS}'
+        #assert acc_fn == 'vanilla_accuracy', \
+        #    'use acc_fn=="vanilla_accuracy"; other metrics are not yet tested in ATC'
+        self.h = h
+        self.acc_fn = acc_fn
+        self.scoring_fn = scoring_fn
+
+    def get_scores(self, P):
+        if self.scoring_fn == 'maxconf':
+            scores = max_conf(P)
+        else:
+            scores = neg_entropy(P)
+        return scores
+
+    def fit(self, val: LabelledCollection):
+        P = get_posteriors_from_h(self.h, val.X)
+        pred_labels = np.argmax(P, axis=1)
+        true_labels = val.y
+        scores = self.get_scores(P)
+        _, self.threshold = self.__find_ATC_threshold(scores=scores, labels=(pred_labels==true_labels))
+
+    def predict(self, X, oracle_prev=None):
+        P = get_posteriors_from_h(self.h, X)
+        scores = self.get_scores(P)
+        #assert self.acc_fn == 'vanilla_accuracy', \
+        #    'use acc_fn=="vanilla_accuracy"; other metrics are not yet tested in ATC'
+        return self.__get_ATC_acc(self.threshold, scores)
+
+    def __find_ATC_threshold(self, scores, labels):
+        # code copy-pasted from https://github.com/saurabhgarg1996/ATC_code/blob/master/ATC_helper.py
+        sorted_idx = np.argsort(scores)
+
+        sorted_scores = scores[sorted_idx]
+        sorted_labels = labels[sorted_idx]
+
+        fp = np.sum(labels == 0)
+        fn = 0.0
+
+        min_fp_fn = np.abs(fp - fn)
+        thres = 0.0
+        for i in range(len(labels)):
+            if sorted_labels[i] == 0:
+                fp -= 1
+            else:
+                fn += 1
+
+            if np.abs(fp - fn) < min_fp_fn:
+                min_fp_fn = np.abs(fp - fn)
+                thres = sorted_scores[i]
+
+        return min_fp_fn, thres
+
+    def __get_ATC_acc(self, thres, scores):
+        # code copy-pasted from https://github.com/saurabhgarg1996/ATC_code/blob/master/ATC_helper.py
+        return np.mean(scores >= thres)
+
+
+class DoC(ClassifierAccuracyPrediction):
+
+    def __init__(self, h, sample_size, num_samples=100):
+        self.h = h
+        self.sample_size = sample_size
+        self.num_samples = num_samples
+
+    def _get_post_stats(self, X, y):
+        P = get_posteriors_from_h(self.h, X)
+        mc = max_conf(P)
+        pred_labels = np.argmax(P, axis=-1)
+        acc = (y == pred_labels).mean()
+        return mc, acc
+
+    def _doc(self, mc1, mc2):
+        return mc2.mean() - mc1.mean()
+
+    def train_regression(self, v2_mcs, v2_accs):
+        docs = [self._doc(self.v1_mc, v2_mc_i) for v2_mc_i in v2_mcs]
+        target = [self.v1_acc - v2_acc_i for v2_acc_i in v2_accs]
+        docs = np.asarray(docs).reshape(-1,1)
+        target = np.asarray(target)
+        lin_reg = LinearRegression()
+        return lin_reg.fit(docs, target)
+
+    def predict_regression(self, test_mc):
+        docs = np.asarray([self._doc(self.v1_mc, test_mc)]).reshape(-1, 1)
+        pred_acc = self.reg_model.predict(docs)
+        return self.v1_acc - pred_acc
+
+    def fit(self, val: LabelledCollection):
+        v1, v2 = val.split_stratified(train_prop=0.5, random_state=0)
+
+        self.v1_mc, self.v1_acc = self._get_post_stats(*v1.Xy)
+
+        v2_prot = UPP(v2, sample_size=self.sample_size, repeats=self.num_samples, return_type='labelled_collection')
+        v2_stats = [self._get_post_stats(*sample.Xy) for sample in v2_prot()]
+        v2_mcs, v2_accs = list(zip(*v2_stats))
+
+        self.reg_model = self.train_regression(v2_mcs, v2_accs)
+
+    def predict(self, X, oracle_prev=None):
+        P = get_posteriors_from_h(self.h, X)
+        mc = max_conf(P)
+        acc_pred = self.predict_regression(mc)[0]
+        return acc_pred
+
+    """
+    def doc(self,
+            c_model: BaseEstimator,
+            validation: LabelledCollection,
+            protocol: AbstractStochasticSeededProtocol,
+            predict_method="predict_proba"):
+
+        c_model_predict = getattr(c_model, predict_method)
+        f1_average = "binary" if validation.n_classes == 2 else "macro"
+
+        val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
+        val1_probs = c_model_predict(val1.X)
+        val1_mc = np.max(val1_probs, axis=-1)
+        val1_preds = np.argmax(val1_probs, axis=-1)
+        val1_acc = metrics.accuracy_score(val1.y, val1_preds)
+        val1_f1 = metrics.f1_score(val1.y, val1_preds, average=f1_average)
+        val2_protocol = APP(
+            val2,
+            n_prevalences=21,
+            repeats=100,
+            return_type="labelled_collection",
+        )
+        val2_prot_mc = []
+        val2_prot_preds = []
+        val2_prot_y = []
+        for v2 in val2_protocol():
+            _probs = c_model_predict(v2.X)
+            _mc = np.max(_probs, axis=-1)
+            _preds = np.argmax(_probs, axis=-1)
+            val2_prot_mc.append(_mc)
+            val2_prot_preds.append(_preds)
+            val2_prot_y.append(v2.y)
+
+        val_scores = np.array([doclib.get_doc(val1_mc, v2_mc) for v2_mc in val2_prot_mc])
+        val_targets_acc = np.array(
+            [
+                val1_acc - metrics.accuracy_score(v2_y, v2_preds)
+                for v2_y, v2_preds in zip(val2_prot_y, val2_prot_preds)
+            ]
+        )
+        reg_acc = LinearRegression().fit(val_scores[:, np.newaxis], val_targets_acc)
+        val_targets_f1 = np.array(
+            [
+                val1_f1 - metrics.f1_score(v2_y, v2_preds, average=f1_average)
+                for v2_y, v2_preds in zip(val2_prot_y, val2_prot_preds)
+            ]
+        )
+        reg_f1 = LinearRegression().fit(val_scores[:, np.newaxis], val_targets_f1)
+
+        report = EvaluationReport(name="doc")
+        for test in protocol():
+            test_probs = c_model_predict(test.X)
+            test_preds = np.argmax(test_probs, axis=-1)
+            test_mc = np.max(test_probs, axis=-1)
+            acc_score = (
+                val1_acc
+                - reg_acc.predict(np.array([[doclib.get_doc(val1_mc, test_mc)]]))[0]
+            )
+            f1_score = (
+                val1_f1 - reg_f1.predict(np.array([[doclib.get_doc(val1_mc, test_mc)]]))[0]
+            )
+            meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
+            meta_f1 = abs(
+                f1_score - metrics.f1_score(test.y, test_preds, average=f1_average)
+            )
+            report.append_row(
+                test.prevalence(),
+                acc=meta_acc,
+                acc_score=acc_score,
+                f1=meta_f1,
+                f1_score=f1_score,
+            )
+
+        return report
+
+    def get_doc(probs1, probs2):
+        return np.mean(probs2) - np.mean(probs1)
+    """
+
diff --git a/ClassifierAccuracy/notes.md b/ClassifierAccuracy/notes.md
index a515559..83864df 100644
--- a/ClassifierAccuracy/notes.md
+++ b/ClassifierAccuracy/notes.md
@@ -14,4 +14,19 @@ A Classifier Accuracy Prediction (CAP) method is method tha receives as input:
 And implements:
 - fit: trains the CAP
 - predict: predicts the evaluation measure on unseen data (provided, calls predict_ct and acc_func)
-- predict_ct: predicts the contingency table
\ No newline at end of file
+- predict_ct: predicts the contingency table
+
+Important:
+- When the quantifiers' hyperparameters are optimized, we should make sure that the
+  classifier is not being reused, or that the hyperparameters do not include any from
+  the underlying classifier
+
+TODO:
+- Add additional covariates [done, check]
+- Add model selection for CAP
+- Add DoC
+- Add ATC
+- Add APP in training and adapt plots and tables
+- Add plots: error by drift, etc
+- Add characterization of classifiers in terms of accuracy and use this as a variable
+  for analyzing results
\ No newline at end of file
diff --git a/ClassifierAccuracy/commons.py b/ClassifierAccuracy/util/commons.py
similarity index 79%
rename from ClassifierAccuracy/commons.py
rename to ClassifierAccuracy/util/commons.py
index 983d827..1846208 100644
--- a/ClassifierAccuracy/commons.py
+++ b/ClassifierAccuracy/util/commons.py
@@ -3,23 +3,30 @@ import json
 import os
 from collections import defaultdict
 from glob import glob
-from os import makedirs
-from os.path import join
 from pathlib import Path
 from time import time
 
+import numpy as np
+
 
-import matplotlib.pyplot as plt
 from sklearn.datasets import fetch_rcv1
+from sklearn.model_selection import GridSearchCV
+
+from ClassifierAccuracy.models_multiclass import *
+from quapy.method.aggregative import EMQ, ACC, KDEyML
-from quapy.method.aggregative import EMQ, ACC
-from models_multiclass import *
 from quapy.data import LabelledCollection
 from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS
 from quapy.data.datasets import fetch_reviews
 
 
 def gen_classifiers():
+    param_grid = {
+        'C': np.logspace(-4, 4, 9),
+        'class_weight': ['balanced', None]
+    }
+
     yield 'LR', LogisticRegression()
+    #yield 'LR-opt', GridSearchCV(LogisticRegression(), param_grid, cv=5, n_jobs=-1)
     #yield 'NB', GaussianNB()
     #yield 'SVM(rbf)', SVC()
     #yield 'SVM(linear)', LinearSVC()
@@ -27,6 +34,8 @@ def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
     for dataset_name in UCI_MULTICLASS_DATASETS:
+        if dataset_name == 'wine-quality':
+            continue
         if only_names:
             yield dataset_name, None
         else:
@@ -56,21 +65,31 @@ def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledColle
         yield cat, (L, V, U)
 
 
-def gen_CAP(h, acc_fn)->[str, ClassifierAccuracyPrediction]:
+def gen_CAP(h, acc_fn, with_oracle=False)->[str, ClassifierAccuracyPrediction]:
     #yield 'SebCAP', SebastianiCAP(h, acc_fn, ACC)
-    yield 'SebCAP-SLD', SebastianiCAP(h, acc_fn, EMQ)
+    yield 'SebCAP-SLD', SebastianiCAP(h, acc_fn, EMQ, predict_train_prev=not with_oracle)
+    #yield 'SebCAP-KDE', SebastianiCAP(h, acc_fn, KDEyML)
     #yield 'SebCAPweight', SebastianiCAP(h, acc_fn, ACC, alpha=0)
     #yield 'PabCAP', PabloCAP(h, acc_fn, ACC)
-    yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
+    #yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
+    yield 'ATC-MC', ATC(h, acc_fn, scoring_fn='maxconf')
+    #yield 'ATC-NE', ATC(h, acc_fn, scoring_fn='neg_entropy')
+    yield 'DoC', DoC(h, sample_size=qp.environ['SAMPLE_SIZE'])
 
 
 def gen_CAP_cont_table(h)->[str,CAPContingencyTable]:
     acc_fn = None
     yield 'Naive', NaiveCAP(h, acc_fn)
-    yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
-    yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
+    #yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
+    #yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
+    yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
+    #yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
+    #yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
+    #yield 'QuAcc(EMQ)nxn-MC', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxconf=True)
+    yield 'QuAcc(EMQ)nxn-NE', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_negentropy=True)
+    #yield 'QuAcc(EMQ)nxn-MIS', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxinfsoft=True)
+    #yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
     #yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
-    yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
     #yield 'CT-PPSh-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()), reuse_h=True)
     #yield 'Equations-ACCh', NsquaredEquationsCAP(h, acc_fn, ACC, reuse_h=True)
     # yield 'Equations-ACC', NsquaredEquationsCAP(h, acc_fn, ACC)
@@ -100,17 +119,23 @@ def fit_method(method, V):
     return method, t_train
 
 
-def predictionsCAP(method, test_prot):
+def predictionsCAP(method, test_prot, oracle=False):
     tinit = time()
-    estim_accs = [method.predict(Ui.X) for Ui in test_prot()]
+    if not oracle:
+        estim_accs = [method.predict(Ui.X) for Ui in test_prot()]
+    else:
+        estim_accs = [method.predict(Ui.X, oracle_prev=Ui.prevalence()) for Ui in test_prot()]
     t_test_ave = (time() - tinit) / test_prot.total()
     return estim_accs, t_test_ave
 
 
-def predictionsCAPcont_table(method, test_prot, gen_acc_measure):
+def predictionsCAPcont_table(method, test_prot, gen_acc_measure, oracle=False):
     estim_accs_dict = {}
     tinit = time()
-    estim_tables = [method.predict_ct(Ui.X) for Ui in test_prot()]
+    if not oracle:
+        estim_tables = [method.predict_ct(Ui.X) for Ui in test_prot()]
+    else:
+        estim_tables = [method.predict_ct(Ui.X, oracle_prev=Ui.prevalence()) for Ui in test_prot()]
     for acc_name, acc_fn in gen_acc_measure():
        estim_accs_dict[acc_name] = [acc_fn(cont_table) for cont_table in estim_tables]
     t_test_ave = (time() - tinit) / test_prot.total()
@@ -184,35 +209,6 @@ def cap_errors(true_acc, estim_acc):
     return np.abs(true_acc - estim_acc)
 
 
-def plot_diagonal(cls_name, measure_name, results, base_dir='plots'):
-
-    makedirs(base_dir, exist_ok=True)
-    makedirs(join(base_dir, measure_name), exist_ok=True)
-
-    # Create scatter plot
-    plt.figure(figsize=(10, 10))
-    plt.xlim(0, 1)
-    plt.ylim(0, 1)
-    plt.plot([0, 1], [0, 1], color='black', linestyle='--')
-
-    for method_name in results.keys():
-        xs = results[method_name]['true_acc']
-        ys = results[method_name]['estim_acc']
-        err = cap_errors(xs, ys).mean()
-        #pear_cor, _ = 0, 0 #pearsonr(xs, ys)
-        plt.scatter(xs, ys, label=f'{method_name} {err:.3f}', alpha=0.6)
-
-    plt.legend()
-
-    # Add labels and title
-    plt.xlabel(f'True {measure_name}')
-    plt.ylabel(f'Estimated {measure_name}')
-
-    # Display the plot
-    # plt.show()
-    plt.savefig(join(base_dir, measure_name, 'diagonal_'+cls_name+'.png'))
-
-
 def getpath(basedir, cls_name, acc_name, dataset_name, method_name):
     return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json"
 
@@ -275,7 +271,7 @@ def gen_tables(basedir, datasets):
     classifiers = [classifier for classifier, _ in gen_classifiers()]
     measures = [measure for measure, _ in gen_acc_measure()]
 
-    os.makedirs('tables', exist_ok=True)
+    os.makedirs('./tables', exist_ok=True)
 
     tex_doc = """
     \\documentclass[10pt,a4paper]{article}
diff --git a/ClassifierAccuracy/tabular.py b/ClassifierAccuracy/util/tabular.py
similarity index 100%
rename from ClassifierAccuracy/tabular.py
rename to ClassifierAccuracy/util/tabular.py
diff --git a/quapy/data/base.py b/quapy/data/base.py
index eb41c44..cb695be 100644
--- a/quapy/data/base.py
+++ b/quapy/data/base.py
@@ -151,6 +151,8 @@ class LabelledCollection:
         indexes_sample = []
         for class_, n_requested in n_requests.items():
             n_candidates = len(self.index[class_])
+            #print(n_candidates)
+            #print(n_requested, 'rq')
             index_sample = self.index[class_][
                 np.random.choice(n_candidates, size=n_requested, replace=True)
             ] if n_requested > 0 else []
diff --git a/quapy/model_selection.py b/quapy/model_selection.py
index 12b3386..ff4d03b 100644
--- a/quapy/model_selection.py
+++ b/quapy/model_selection.py
@@ -211,8 +211,9 @@ class GridSearchQ(BaseQuantifier):
             self._sout(f'error={status}')
 
     def fit(self, training: LabelledCollection):
-        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
-        the error metric.
+        """
+        Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
+        the error metric.
 
         :param training: the training set on which to optimize the hyperparameters
         :return: self