refactor folders
This commit is contained in:
parent
8c063afd1e
commit
1eebfbc709
|
|
@ -6,7 +6,7 @@ import quapy as qp
|
|||
from data import LabelledCollection
|
||||
from method.non_aggregative import DMx
|
||||
from protocol import APP
|
||||
from quapy.method.aggregative import CC, DMy, ACC
|
||||
from quapy.method.aggregative import CC, DMy, ACC, EMQ
|
||||
from sklearn.svm import LinearSVC
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
|
@ -15,12 +15,15 @@ qp.environ['SAMPLE_SIZE'] = 500
|
|||
|
||||
def cls():
|
||||
return LogisticRegressionCV(n_jobs=-1,Cs=10)
|
||||
# return LogisticRegression(C=.1)
|
||||
|
||||
def gen_methods():
|
||||
yield CC(cls()), 'CC$_{10' + '\%}$'
|
||||
yield CC(cls()), r'CC$_{10' + r'\%}$'
|
||||
yield ACC(cls()), 'ACC'
|
||||
yield DMy(cls(), val_split=10, nbins=10, n_jobs=-1), 'HDy'
|
||||
yield DMx(nbins=10, n_jobs=-1), 'HDx'
|
||||
yield EMQ(cls()), 'SLD'
|
||||
# yield EMQ(cls(), calib='vs'), 'SLD-VS'
|
||||
|
||||
def gen_data():
|
||||
|
||||
|
|
@ -31,6 +34,7 @@ def gen_data():
|
|||
training_size = 5000
|
||||
# since the problem is binary, it suffices to specify the negative prevalence, since the positive is constrained
|
||||
train_sample = train.sampling(training_size, 1-training_prevalence, random_state=0)
|
||||
# train_sample = train
|
||||
|
||||
for model, method_name in tqdm(gen_methods(), total=4):
|
||||
with qp.util.temp_seed(1):
|
||||
|
|
@ -43,10 +47,10 @@ def gen_data():
|
|||
X, y = test.Xy
|
||||
test_dense = LabelledCollection(svd.transform(X), y)
|
||||
|
||||
model.fit(train_sample_dense)
|
||||
model.fit(*train_sample_dense.Xy)
|
||||
true_prev, estim_prev = qp.evaluation.prediction(model, APP(test_dense, repeats=100, random_state=0))
|
||||
else:
|
||||
model.fit(train_sample)
|
||||
model.fit(*train_sample.Xy)
|
||||
true_prev, estim_prev = qp.evaluation.prediction(model, APP(test, repeats=100, random_state=0))
|
||||
method_data.append((method_name, true_prev, estim_prev, train_sample.prevalence()))
|
||||
|
||||
|
|
@ -55,5 +59,5 @@ def gen_data():
|
|||
|
||||
method_names, true_prevs, estim_prevs, tr_prevs = gen_data()
|
||||
|
||||
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, savepath='./plots_cacm/bin_diag_4methods.pdf')
|
||||
qp.plot.error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=10, savepath='./plots_cacm/err_drift_4methods.pdf', title='', show_density=False, show_std=True)
|
||||
# qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, savepath='./plots_ieee/bin_diag_4methods.pdf')
|
||||
qp.plot.error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=10, savepath='./plots_ieee/err_drift_4methods.pdf', title='', show_density=False, show_std=True)
|
||||
|
|
|
|||
|
|
@ -314,7 +314,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
for p,ind in enumerate(range(len(bins))):
|
||||
selected = inds==ind
|
||||
if selected.sum() > 0:
|
||||
xs.append(ind*binwidth-binwidth/2)
|
||||
xs.append(ind*binwidth)
|
||||
ys.append(np.mean(method_drifts[selected]))
|
||||
ystds.append(np.std(method_drifts[selected]))
|
||||
npoints[p] += len(method_drifts[selected])
|
||||
|
|
|
|||
Loading…
Reference in New Issue