"""Plot four quantification methods trained on a prevalence-shifted IMDb sample.

The script fits CC, ACC, DMy (labelled 'HDy') and DMx (labelled 'HDx') on a
5000-document training sample, evaluates each of them with the APP protocol
on the test set, and saves a diagonal plot and an error-by-drift plot under
``./plots_cacm/``.
"""
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.model_selection import GridSearchCV
import quapy as qp
# NOTE(review): the three imports below are not qualified with ``quapy.`` and
# therefore only resolve when the package directory itself is on sys.path;
# presumably they correspond to quapy.data, quapy.method.non_aggregative and
# quapy.protocol -- confirm before changing them.
from data import LabelledCollection
from method.non_aggregative import DMx
from protocol import APP
from quapy.method.aggregative import CC, DMy, ACC
from sklearn.svm import LinearSVC
import numpy as np
from tqdm import tqdm

# Global QuaPy setting; presumably the size of every sample drawn by the
# evaluation protocol -- confirm against the QuaPy documentation.
qp.environ['SAMPLE_SIZE'] = 500


def cls():
    """Return a fresh, unfitted classifier shared by the aggregative methods."""
    return LogisticRegressionCV(n_jobs=-1, Cs=10)


def gen_methods():
    """Yield ``(quantifier, plot_label)`` pairs for the methods to compare.

    The labels are used verbatim as legend entries, so they may contain
    LaTeX markup.
    """
    # Raw string: '\%' is not a valid Python escape sequence and triggers a
    # SyntaxWarning on recent interpreters. The resulting text is unchanged.
    yield CC(cls()), r'CC$_{10\%}$'
    yield ACC(cls()), 'ACC'
    yield DMy(cls(), val_split=10, nbins=10, n_jobs=-1), 'HDy'
    yield DMx(nbins=10, n_jobs=-1), 'HDx'


def _reduce_for_hdx(train_sample, test):
    """Project both collections onto 5 SVD components fitted on the training sample.

    Returns the pair ``(train_dense, test_dense)`` of new LabelledCollection
    objects holding the reduced covariates and the original labels.
    """
    X, y = train_sample.Xy
    svd = TruncatedSVD(n_components=5, random_state=0)
    train_dense = LabelledCollection(svd.fit_transform(X), y)
    # The test set is only transformed, never used to fit the projection.
    X, y = test.Xy
    test_dense = LabelledCollection(svd.transform(X), y)
    return train_dense, test_dense


def gen_data():
    """Train and evaluate every method from :func:`gen_methods`.

    Returns:
        An iterator (``zip``) over four parallel tuples, one entry per
        method: method names, true prevalences, estimated prevalences and
        the prevalence of the training sample.
    """
    train, test = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5).train_test

    method_data = []
    training_prevalence = 0.1
    training_size = 5000
    # The problem is binary, so it suffices to specify the negative-class
    # prevalence; the positive-class prevalence is then determined.
    train_sample = train.sampling(training_size, 1 - training_prevalence, random_state=0)

    # Materialise the generator so that tqdm can infer the total on its own
    # instead of relying on a hand-maintained count.
    methods = list(gen_methods())
    for model, method_name in tqdm(methods):
        # Same seed for every method, so that all of them are run under
        # identical random conditions.
        with qp.util.temp_seed(1):
            if method_name == 'HDx':
                # HDx is fitted and evaluated on the SVD-reduced covariates
                # rather than on the tf-idf vectors.
                fit_on, eval_on = _reduce_for_hdx(train_sample, test)
            else:
                fit_on, eval_on = train_sample, test
            model.fit(fit_on)
            true_prev, estim_prev = qp.evaluation.prediction(
                model, APP(eval_on, repeats=100, random_state=0)
            )
            method_data.append((method_name, true_prev, estim_prev, train_sample.prevalence()))

    return zip(*method_data)


method_names, true_prevs, estim_prevs, tr_prevs = gen_data()

qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, savepath='./plots_cacm/bin_diag_4methods.pdf')
qp.plot.error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=10,
                       savepath='./plots_cacm/err_drift_4methods.pdf', title='', show_density=False, show_std=True)