naive baseline added

This commit is contained in:
Lorenzo Volpi 2024-01-31 18:13:09 +01:00
parent ababeb426a
commit 921caaf426
2 changed files with 35 additions and 4 deletions

View File

@ -71,14 +71,14 @@ test_conf: &test_conf
main:
confs: &main_confs
- DATASET_NAME: imdb
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
other_confs:
- DATASET_NAME: imdb
- DATASET_NAME: rcv1
DATASET_TARGET: GCAT
- DATASET_NAME: rcv1
DATASET_TARGET: MCAT
other_confs:
sld_lr_conf: &sld_lr_conf
@ -407,4 +407,4 @@ timing_conf: &timing_conf
confs: *main_confs
exec: *kde_lr_gs_conf
exec: *baselines_conf

View File

@ -68,6 +68,38 @@ def kfcv(
return report
@baseline
def naive(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
predict_method="predict",
):
c_model_predict = getattr(c_model, predict_method)
f1_average = "binary" if validation.n_classes == 2 else "macro"
val_preds = c_model_predict(validation.X)
val_acc = metrics.accuracy_score(validation.y, val_preds)
val_f1 = metrics.f1_score(validation.y, val_preds, average=f1_average)
report = EvaluationReport(name="naive")
for test in protocol():
test_preds = c_model_predict(test.X)
acc_score = metrics.accuracy_score(test.y, test_preds)
f1_score = metrics.f1_score(test.y, test_preds, average=f1_average)
meta_acc = abs(val_acc - acc_score)
meta_f1 = abs(val_f1 - f1_score)
report.append_row(
test.prevalence(),
acc_score=acc_score,
f1_score=f1_score,
acc=meta_acc,
f1=meta_f1,
)
return report
@baseline
def ref(
c_model: BaseEstimator,
@ -556,4 +588,3 @@ def kdex2(
report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc)
return report