QuAcc/quacc/evaluation/baseline.py

299 lines
9.6 KiB
Python
Raw Normal View History

from functools import wraps
2023-09-13 00:11:20 +02:00
from statistics import mean
2023-09-22 01:40:36 +02:00
import numpy as np
import sklearn.metrics as metrics
from quapy.data import LabelledCollection
from quapy.protocol import AbstractStochasticSeededProtocol
from scipy.sparse import issparse
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate
import baselines.atc as atc
import baselines.doc as doc
import baselines.impweight as iw
import baselines.rca as rcalib
2023-09-13 00:11:20 +02:00
from .report import EvaluationReport
# Registry of baseline methods, keyed by the decorated function's name.
# It is filled in as a side effect of applying ``@baseline``.
_baselines = {}


def baseline(func):
    """Register ``func`` as a baseline and return the registered wrapper.

    The wrapper is stored in ``_baselines`` under ``func.__name__`` and is
    also what the decorated name is bound to.

    Args:
        func: callable taking ``(c_model, validation, protocol)`` plus
            optional keyword arguments.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``) that
        calls ``func`` with the three positional arguments and forwards any
        extra keyword arguments unchanged.
    """

    @wraps(func)
    def wrapper(c_model, validation, protocol, **kwargs):
        # Forward keyword options (e.g. ``predict_method``) so the defaults
        # declared by the decorated function can actually be overridden
        # through the registered callable.
        return func(c_model, validation, protocol, **kwargs)

    _baselines[func.__name__] = wrapper
    return wrapper
@baseline
def kfcv(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict",
):
    """Cross-validation baseline.

    Accuracy and macro-F1 are estimated once by running ``cross_validate``
    on ``c_model`` with ``validation.X`` / ``validation.y``; the mean of the
    per-fold scores is the estimate reported for every sample.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it.
        validation: labelled data used for the cross-validation.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the ``c_model`` method used to obtain the
            predictions on each test sample.

    Returns:
        An ``EvaluationReport`` named ``"kfcv"`` with one row per sample
        holding the estimates (``acc_score``, ``f1_score``) and their
        absolute errors w.r.t. the true scores (``acc``, ``f1``).
    """
    c_model_predict = getattr(c_model, predict_method)

    # sklearn's f1_score defaults to average="binary", which raises
    # ValueError on multiclass targets.  Keep the binary average for the
    # binary case and use the macro average otherwise.
    # NOTE(review): the estimate is always "f1_macro", so in the binary case
    # a macro estimate is compared with a binary true F1 -- confirm intended.
    f1_average = "macro" if validation.n_classes > 2 else "binary"

    scoring = ["accuracy", "f1_macro"]
    scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring)
    acc_score = mean(scores["test_accuracy"])
    f1_score = mean(scores["test_f1_macro"])

    report = EvaluationReport(name="kfcv")
    for test in protocol():
        test_preds = c_model_predict(test.X)
        true_acc = metrics.accuracy_score(test.y, test_preds)
        true_f1 = metrics.f1_score(test.y, test_preds, average=f1_average)
        report.append_row(
            test.prevalence(),
            acc_score=acc_score,
            f1_score=f1_score,
            acc=abs(acc_score - true_acc),
            f1=abs(f1_score - true_f1),
        )

    return report
@baseline
def ref(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
):
    """Reference scores: true accuracy and F1 of ``c_model`` on each sample.

    Predictions are the argmax of ``c_model.predict_proba`` over the last
    axis.

    Args:
        c_model: classifier exposing ``predict_proba``.
        validation: labelled data; only its ``n_classes`` is read, to choose
            the F1 averaging mode.
        protocol: callable yielding the test samples to evaluate on.

    Returns:
        An ``EvaluationReport`` named ``"ref"`` with one row per sample
        holding ``acc_score`` and ``f1_score``.
    """
    c_model_predict = getattr(c_model, "predict_proba")

    # sklearn's f1_score defaults to average="binary", which raises
    # ValueError on multiclass targets; fall back to the macro average there
    # and keep the binary average otherwise.
    f1_average = "macro" if validation.n_classes > 2 else "binary"

    report = EvaluationReport(name="ref")
    for test in protocol():
        test_probs = c_model_predict(test.X)
        test_preds = np.argmax(test_probs, axis=-1)
        report.append_row(
            test.prevalence(),
            acc_score=metrics.accuracy_score(test.y, test_preds),
            f1_score=metrics.f1_score(test.y, test_preds, average=f1_average),
        )

    return report
2023-09-24 02:21:18 +02:00
@baseline
def atc_mc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict_proba",
):
    """ATC baseline scored with ``atc.get_max_conf``.

    A threshold is fitted once via ``atc.find_ATC_threshold`` from the
    validation scores and the element-wise correctness of the argmax
    predictions.  For every sample yielded by ``protocol()`` the threshold is
    applied to the sample scores to estimate accuracy and F1.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it.
        validation: labelled data used to fit the threshold.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the ``c_model`` method returning the outputs
            that are scored and arg-maxed.

    Returns:
        An ``EvaluationReport`` named ``"atc_mc"`` with one row per sample
        holding the estimates (``acc_score``, ``f1_score``) and their
        absolute errors w.r.t. the true scores (``acc``, ``f1``).
    """
    predict = getattr(c_model, predict_method)

    # Fit the threshold on the validation data.
    probs_val = predict(validation.X)
    scores_val = atc.get_max_conf(probs_val)
    correct_val = validation.y == np.argmax(probs_val, axis=-1)
    _, threshold = atc.find_ATC_threshold(scores_val, correct_val)

    report = EvaluationReport(name="atc_mc")
    for sample in protocol():
        probs = predict(sample.X)
        labels_hat = np.argmax(probs, axis=-1)
        sample_scores = atc.get_max_conf(probs)

        est_acc = atc.get_ATC_acc(threshold, sample_scores)
        acc_err = abs(est_acc - metrics.accuracy_score(sample.y, labels_hat))
        est_f1 = atc.get_ATC_f1(threshold, sample_scores, probs)
        # NOTE(review): f1_score is called with sklearn's default averaging;
        # presumably only binary targets are expected here -- confirm.
        f1_err = abs(est_f1 - metrics.f1_score(sample.y, labels_hat))

        report.append_row(
            sample.prevalence(),
            acc=acc_err,
            acc_score=est_acc,
            f1_score=est_f1,
            f1=f1_err,
        )

    return report
2023-09-14 01:52:19 +02:00
2023-09-16 01:59:49 +02:00
@baseline
def atc_ne(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict_proba",
):
    """ATC baseline scored with ``atc.get_entropy``.

    A threshold is fitted once via ``atc.find_ATC_threshold`` from the
    validation scores and the element-wise correctness of the argmax
    predictions.  For every sample yielded by ``protocol()`` the threshold is
    applied to the sample scores to estimate accuracy and F1.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it.
        validation: labelled data used to fit the threshold.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the ``c_model`` method returning the outputs
            that are scored and arg-maxed.

    Returns:
        An ``EvaluationReport`` named ``"atc_ne"`` with one row per sample
        holding the estimates (``acc_score``, ``f1_score``) and their
        absolute errors w.r.t. the true scores (``acc``, ``f1``).
    """
    predict = getattr(c_model, predict_method)

    # Fit the threshold on the validation data.
    probs_val = predict(validation.X)
    scores_val = atc.get_entropy(probs_val)
    correct_val = validation.y == np.argmax(probs_val, axis=-1)
    _, threshold = atc.find_ATC_threshold(scores_val, correct_val)

    report = EvaluationReport(name="atc_ne")
    for sample in protocol():
        probs = predict(sample.X)
        labels_hat = np.argmax(probs, axis=-1)
        sample_scores = atc.get_entropy(probs)

        est_acc = atc.get_ATC_acc(threshold, sample_scores)
        acc_err = abs(est_acc - metrics.accuracy_score(sample.y, labels_hat))
        est_f1 = atc.get_ATC_f1(threshold, sample_scores, probs)
        # NOTE(review): f1_score is called with sklearn's default averaging;
        # presumably only binary targets are expected here -- confirm.
        f1_err = abs(est_f1 - metrics.f1_score(sample.y, labels_hat))

        report.append_row(
            sample.prevalence(),
            acc=acc_err,
            acc_score=est_acc,
            f1_score=est_f1,
            f1=f1_err,
        )

    return report
2023-09-14 01:52:19 +02:00
2023-09-16 01:59:49 +02:00
@baseline
def doc_feat(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict_proba",
):
    """DoC baseline built on the maximum of the model outputs.

    The estimated accuracy of a sample is the validation accuracy (as a
    percentage) plus ``doc.get_doc(validation_scores, sample_scores)``,
    divided by 100, where the scores are the per-row maxima of the model
    outputs.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it.
        validation: labelled data providing the reference scores/accuracy.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the ``c_model`` method returning the outputs
            that are maxed and arg-maxed.

    Returns:
        An ``EvaluationReport`` named ``"doc_feat"`` with one row per sample
        holding the estimate (``acc_score``) and its absolute error w.r.t.
        the true accuracy (``acc``).
    """
    predict = getattr(c_model, predict_method)

    probs_val = predict(validation.X)
    conf_val = np.max(probs_val, axis=-1)
    labels_hat_val = np.argmax(probs_val, axis=-1)
    # Validation accuracy expressed as a percentage.
    val_acc_pct = np.mean(labels_hat_val == validation.y) * 100

    report = EvaluationReport(name="doc_feat")
    for sample in protocol():
        probs = predict(sample.X)
        labels_hat = np.argmax(probs, axis=-1)
        conf = np.max(probs, axis=-1)

        # Back from percentage points to a [0, 1]-style score.
        est_acc = (val_acc_pct + doc.get_doc(conf_val, conf)) / 100.0
        acc_err = abs(est_acc - metrics.accuracy_score(sample.y, labels_hat))
        report.append_row(sample.prevalence(), acc=acc_err, acc_score=est_acc)

    return report
2023-09-18 18:19:13 +02:00
@baseline
def rca(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict",
):
    """RCA baseline (original docstring tag: elsahar19).

    For each sample, a second model is obtained with
    ``rcalib.clone_fit(c_model, sample.X, <c_model predictions on sample>)``.
    The estimate is ``1 - rcalib.get_score(...)`` computed from the
    predictions of both models on the validation data and the validation
    labels.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it and on
            the model returned by ``rcalib.clone_fit``.
        validation: labelled data both models are compared on.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the prediction method to use.

    Returns:
        An ``EvaluationReport`` named ``"rca"`` with one row per sample
        holding the estimate (``acc_score``) and its absolute error w.r.t.
        the true accuracy (``acc``).  When a ``ValueError`` is raised while
        processing a sample, both values are NaN for that row.
    """
    predict = getattr(c_model, predict_method)
    base_val_preds = predict(validation.X)

    report = EvaluationReport(name="rca")
    for sample in protocol():
        try:
            sample_preds = predict(sample.X)
            refit_model = rcalib.clone_fit(c_model, sample.X, sample_preds)
            refit_val_preds = getattr(refit_model, predict_method)(validation.X)
            est_acc = 1.0 - rcalib.get_score(
                base_val_preds, refit_val_preds, validation.y
            )
            acc_err = abs(est_acc - metrics.accuracy_score(sample.y, sample_preds))
            report.append_row(sample.prevalence(), acc=acc_err, acc_score=est_acc)
        except ValueError:
            # Keep one row per sample even when the sample cannot be scored.
            report.append_row(
                sample.prevalence(), acc=float("nan"), acc_score=float("nan")
            )

    return report
2023-09-24 02:21:18 +02:00
2023-09-18 18:19:13 +02:00
@baseline
def rca_star(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict",
):
    """RCA* baseline (original docstring tag: elsahar19).

    ``validation`` is split in two stratified halves (``random_state=0``).
    A first model is obtained with ``rcalib.clone_fit`` from the first half
    and ``c_model``'s predictions on it.  For each sample a second model is
    obtained the same way from the sample.  The estimate is
    ``1 - rcalib.get_score(...)`` computed from the predictions of both
    models on the second half and that half's labels.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it and on
            the models returned by ``rcalib.clone_fit``.
        validation: labelled data to be split into the two halves.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the prediction method to use.

    Returns:
        An ``EvaluationReport`` named ``"rca_star"`` with one row per sample
        holding the estimate (``acc_score``) and its absolute error w.r.t.
        the true accuracy (``acc``).  When a ``ValueError`` is raised while
        processing a sample, both values are NaN for that row.
    """
    predict = getattr(c_model, predict_method)

    half_fit, half_eval = validation.split_stratified(train_prop=0.5, random_state=0)

    # Model fitted on the first half, labelled by c_model's own predictions.
    half_fit_preds = predict(half_fit.X)
    val_model = rcalib.clone_fit(c_model, half_fit.X, half_fit_preds)
    val_model_preds = getattr(val_model, predict_method)(half_eval.X)

    report = EvaluationReport(name="rca_star")
    for sample in protocol():
        try:
            sample_preds = predict(sample.X)
            sample_model = rcalib.clone_fit(c_model, sample.X, sample_preds)
            sample_model_preds = getattr(sample_model, predict_method)(half_eval.X)
            est_acc = 1.0 - rcalib.get_score(
                val_model_preds, sample_model_preds, half_eval.y
            )
            acc_err = abs(est_acc - metrics.accuracy_score(sample.y, sample_preds))
            report.append_row(sample.prevalence(), acc=acc_err, acc_score=est_acc)
        except ValueError:
            # Keep one row per sample even when the sample cannot be scored.
            report.append_row(
                sample.prevalence(), acc=float("nan"), acc_score=float("nan")
            )

    return report
@baseline
def logreg(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict",
):
    """Baseline estimating accuracy via ``iw.logreg`` and ``iw.get_acc``.

    For each sample, ``iw.logreg(validation.X, validation.y, sample.X)``
    produces a weight object (presumably per-instance importance weights --
    confirm against ``baselines.impweight``) that ``iw.get_acc`` combines
    with the validation predictions and labels into the accuracy estimate.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it.
        validation: labelled data whose predictions are re-weighted.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the ``c_model`` prediction method.

    Returns:
        An ``EvaluationReport`` named ``"logreg"`` with one row per sample
        holding the estimate (``acc_score``) and its absolute error w.r.t.
        the true accuracy (``acc``).
    """
    predict = getattr(c_model, predict_method)
    # Validation predictions do not depend on the sample: compute them once.
    preds_val = predict(validation.X)

    report = EvaluationReport(name="logreg")
    for sample in protocol():
        weights = iw.logreg(validation.X, validation.y, sample.X)
        sample_preds = predict(sample.X)
        est_acc = iw.get_acc(preds_val, validation.y, weights)
        actual_acc = metrics.accuracy_score(sample.y, sample_preds)
        report.append_row(
            sample.prevalence(),
            acc=abs(est_acc - actual_acc),
            acc_score=est_acc,
        )

    return report
@baseline
def kdex2(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict",
):
    """Baseline estimating accuracy via ``iw.kdex2_weights`` and ``iw.get_acc``.

    ``iw.kdex2_lltr`` is evaluated once on ``validation.X`` as given.  For
    each sample, ``iw.kdex2_weights`` receives dense versions of the
    validation and sample features together with that value, and its result
    is passed to ``iw.get_acc`` along with the validation predictions and
    labels.

    Args:
        c_model: classifier; ``predict_method`` is looked up on it.
        validation: labelled data whose predictions are re-weighted.
        protocol: callable yielding the test samples to evaluate on.
        predict_method: name of the ``c_model`` prediction method.

    Returns:
        An ``EvaluationReport`` named ``"kdex2"`` with one row per sample
        holding the estimate (``acc_score``) and its absolute error w.r.t.
        the true accuracy (``acc``).
    """
    predict = getattr(c_model, predict_method)

    preds_val = predict(validation.X)
    # NOTE(review): kdex2_lltr gets validation.X before densification, while
    # kdex2_weights below gets the dense copy -- confirm this is intended.
    val_log_likelihood = iw.kdex2_lltr(validation.X)

    dense_val = validation.X
    if issparse(dense_val):
        dense_val = dense_val.toarray()

    report = EvaluationReport(name="kdex2")
    for sample in protocol():
        dense_sample = sample.X
        if issparse(dense_sample):
            dense_sample = dense_sample.toarray()

        weights = iw.kdex2_weights(dense_val, dense_sample, val_log_likelihood)
        # The sample predictions are taken on the dense features.
        sample_preds = predict(dense_sample)
        est_acc = iw.get_acc(preds_val, validation.y, weights)
        actual_acc = metrics.accuracy_score(sample.y, sample_preds)
        report.append_row(
            sample.prevalence(),
            acc=abs(est_acc - actual_acc),
            acc_score=est_acc,
        )

    return report