first test on quantification for accuracy
This commit is contained in:
parent
2911cdc458
commit
d22fce9050
|
@ -0,0 +1,3 @@
|
||||||
|
*.code-workspace
|
||||||
|
quavenv/*
|
||||||
|
*.pdf
|
|
@ -0,0 +1,36 @@
|
||||||
|
abstention==0.1.3.1
|
||||||
|
astroid==2.15.4
|
||||||
|
contourpy==1.0.7
|
||||||
|
cycler==0.11.0
|
||||||
|
dill==0.3.6
|
||||||
|
docstring-to-markdown==0.12
|
||||||
|
fonttools==4.39.3
|
||||||
|
joblib==1.2.0
|
||||||
|
kiwisolver==1.4.4
|
||||||
|
lazy-object-proxy==1.9.0
|
||||||
|
matplotlib==3.7.1
|
||||||
|
numpy==1.24.3
|
||||||
|
packaging==23.1
|
||||||
|
pandas==2.0.1
|
||||||
|
parso==0.8.3
|
||||||
|
Pillow==9.5.0
|
||||||
|
platformdirs==3.5.0
|
||||||
|
pluggy==1.0.0
|
||||||
|
pyparsing==3.0.9
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
pytoolconfig==1.2.5
|
||||||
|
pytz==2023.3
|
||||||
|
QuaPy==0.1.7
|
||||||
|
scikit-learn==1.2.2
|
||||||
|
scipy==1.10.1
|
||||||
|
six==1.16.0
|
||||||
|
snowballstemmer==2.2.0
|
||||||
|
threadpoolctl==3.1.0
|
||||||
|
toml==0.10.2
|
||||||
|
tomlkit==0.11.8
|
||||||
|
tqdm==4.65.0
|
||||||
|
tzdata==2023.3
|
||||||
|
ujson==5.7.0
|
||||||
|
whatthepatch==1.0.5
|
||||||
|
wrapt==1.15.0
|
||||||
|
xlrd==2.0.1
|
|
@ -0,0 +1,146 @@
|
||||||
|
import numpy as np
|
||||||
|
import quapy as qp
|
||||||
|
import scipy.sparse as sp
|
||||||
|
from quapy.data import LabelledCollection
|
||||||
|
from quapy.protocol import APP, AbstractStochasticSeededProtocol
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.model_selection import cross_val_predict
|
||||||
|
|
||||||
|
|
||||||
|
# Extended classes
|
||||||
|
#
|
||||||
|
# 0 ~ True 0
|
||||||
|
# 1 ~ False 1
|
||||||
|
# 2 ~ False 0
|
||||||
|
# 3 ~ True 1
|
||||||
|
# _____________________
|
||||||
|
# | | |
|
||||||
|
# | True 0 | False 1 |
|
||||||
|
# |__________|__________|
|
||||||
|
# | | |
|
||||||
|
# | False 0 | True 1 |
|
||||||
|
# |__________|__________|
|
||||||
|
#
|
||||||
|
def get_ex_class(classes, true_class, pred_class):
|
||||||
|
return true_class * classes + pred_class
|
||||||
|
|
||||||
|
|
||||||
|
def extend_collection(coll, pred_prob):
|
||||||
|
n_classes = coll.n_classes
|
||||||
|
|
||||||
|
# n_X = [ X | predicted probs. ]
|
||||||
|
if isinstance(coll.X, sp.csr_matrix):
|
||||||
|
pred_prob_csr = sp.csr_matrix(pred_prob)
|
||||||
|
n_x = sp.hstack([coll.X, pred_prob_csr])
|
||||||
|
elif isinstance(coll.X, np.ndarray):
|
||||||
|
n_x = np.concatenate((coll.X, pred_prob), axis=1)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported matrix format")
|
||||||
|
|
||||||
|
# n_y = (exptected y, predicted y)
|
||||||
|
n_y = []
|
||||||
|
for i, true_class in enumerate(coll.y):
|
||||||
|
pred_class = pred_prob[i].argmax(axis=0)
|
||||||
|
n_y.append(get_ex_class(n_classes, true_class, pred_class))
|
||||||
|
|
||||||
|
return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)])
|
||||||
|
|
||||||
|
|
||||||
|
def qf1e_binary(prev):
|
||||||
|
recall = prev[0] / (prev[0] + prev[1])
|
||||||
|
precision = prev[0] / (prev[0] + prev[2])
|
||||||
|
|
||||||
|
return 1 - 2 * (precision * recall) / (precision + recall)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_errors(true_prev, estim_prev, n_instances):
|
||||||
|
errors = {}
|
||||||
|
_eps = 1 / (2 * n_instances)
|
||||||
|
errors = {
|
||||||
|
"mae": qp.error.mae(true_prev, estim_prev),
|
||||||
|
"rae": qp.error.rae(true_prev, estim_prev, eps=_eps),
|
||||||
|
"mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps),
|
||||||
|
"kld": qp.error.kld(true_prev, estim_prev, eps=_eps),
|
||||||
|
"nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps),
|
||||||
|
"true_f1e": qf1e_binary(true_prev),
|
||||||
|
"estim_f1e": qf1e_binary(estim_prev),
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def extend_and_quantify(
|
||||||
|
model,
|
||||||
|
q_model,
|
||||||
|
train,
|
||||||
|
test: LabelledCollection | AbstractStochasticSeededProtocol,
|
||||||
|
):
|
||||||
|
model.fit(*train.Xy)
|
||||||
|
|
||||||
|
pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba")
|
||||||
|
_train = extend_collection(train, pred_prob_train)
|
||||||
|
|
||||||
|
q_model.fit(_train)
|
||||||
|
|
||||||
|
def quantify_extended(test):
|
||||||
|
pred_prob_test = model.predict_proba(test.X)
|
||||||
|
_test = extend_collection(test, pred_prob_test)
|
||||||
|
return _test.prevalence(), q_model.quantify(_test.instances)
|
||||||
|
|
||||||
|
if isinstance(test, LabelledCollection):
|
||||||
|
_orig_prev, _true_prev, _estim_prev = quantify_extended(test)
|
||||||
|
_errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0])
|
||||||
|
return ([_orig_prev], [_true_prev], [_estim_prev], [_errors])
|
||||||
|
|
||||||
|
elif isinstance(test, AbstractStochasticSeededProtocol):
|
||||||
|
orig_prevs, true_prevs, estim_prevs, errors = [], [], [], []
|
||||||
|
for index in test.samples_parameters():
|
||||||
|
sample = test.sample(index)
|
||||||
|
_true_prev, _estim_prev = quantify_extended(sample)
|
||||||
|
|
||||||
|
orig_prevs.append(sample.prevalence())
|
||||||
|
true_prevs.append(_true_prev)
|
||||||
|
estim_prevs.append(_estim_prev)
|
||||||
|
errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0]))
|
||||||
|
|
||||||
|
return orig_prevs, true_prevs, estim_prevs, errors
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(name):
|
||||||
|
datasets = {
|
||||||
|
"spambase": lambda: qp.datasets.fetch_UCIDataset(
|
||||||
|
"spambase", verbose=False
|
||||||
|
).train_test,
|
||||||
|
"hp": lambda: qp.datasets.fetch_reviews("hp", tfidf=True).train_test,
|
||||||
|
"imdb": lambda: qp.datasets.fetch_reviews("imdb", tfidf=True).train_test,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
return datasets[name]()
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(f"{name} is not available as a dataset")
|
||||||
|
|
||||||
|
|
||||||
|
def test_1():
|
||||||
|
train, test = get_dataset("spambase")
|
||||||
|
|
||||||
|
orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify(
|
||||||
|
LogisticRegression(),
|
||||||
|
qp.method.aggregative.SLD(LogisticRegression()),
|
||||||
|
train,
|
||||||
|
APP(test, sample_size=100, n_prevalences=11, repeats=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for orig_prev, true_prev, estim_prev, _errors in zip(
|
||||||
|
orig_prevs, true_prevs, estim_prevs, errors
|
||||||
|
):
|
||||||
|
print(f"original prevalence:\t{orig_prev}")
|
||||||
|
print(f"true prevalence:\t{true_prev}")
|
||||||
|
print(f"estimated prevalence:\t{estim_prev}")
|
||||||
|
for name, err in _errors.items():
|
||||||
|
print(f"{name}={err:.3f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_1()
|
Loading…
Reference in New Issue