diff --git a/conf.yaml b/conf.yaml
index f61f032..4a4ffea 100644
--- a/conf.yaml
+++ b/conf.yaml
@@ -71,14 +71,14 @@ test_conf: &test_conf
 main:
   confs: &main_confs
+    - DATASET_NAME: imdb
     - DATASET_NAME: rcv1
      DATASET_TARGET: CCAT
-  other_confs:
-    - DATASET_NAME: imdb
     - DATASET_NAME: rcv1
      DATASET_TARGET: GCAT
     - DATASET_NAME: rcv1
      DATASET_TARGET: MCAT
+  other_confs:
 
 
 sld_lr_conf: &sld_lr_conf
@@ -407,4 +407,4 @@ timing_conf: &timing_conf
   confs: *main_confs
 
 
-exec: *kde_lr_gs_conf
+exec: *baselines_conf
diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py
index 605b5ca..436dd47 100644
--- a/quacc/evaluation/baseline.py
+++ b/quacc/evaluation/baseline.py
@@ -68,6 +68,38 @@ def kfcv(
     return report
 
 
+@baseline
+def naive(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+    predict_method="predict",
+):
+    c_model_predict = getattr(c_model, predict_method)
+    f1_average = "binary" if validation.n_classes == 2 else "macro"
+
+    val_preds = c_model_predict(validation.X)
+    val_acc = metrics.accuracy_score(validation.y, val_preds)
+    val_f1 = metrics.f1_score(validation.y, val_preds, average=f1_average)
+
+    report = EvaluationReport(name="naive")
+    for test in protocol():
+        test_preds = c_model_predict(test.X)
+        acc_score = metrics.accuracy_score(test.y, test_preds)
+        f1_score = metrics.f1_score(test.y, test_preds, average=f1_average)
+        meta_acc = abs(val_acc - acc_score)
+        meta_f1 = abs(val_f1 - f1_score)
+        report.append_row(
+            test.prevalence(),
+            acc_score=acc_score,
+            f1_score=f1_score,
+            acc=meta_acc,
+            f1=meta_f1,
+        )
+
+    return report
+
+
 @baseline
 def ref(
     c_model: BaseEstimator,
@@ -556,4 +588,3 @@ def kdex2(
     report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc)
 
     return report
-
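The new `naive` baseline assumes that the accuracy and F1 measured once on the validation set carry over unchanged to every test sample drawn by the protocol, and it scores that assumption by the absolute gap between the validation figures and the true test figures. Below is a self-contained sketch of the same idea using only scikit-learn and NumPy: a plain list of dicts stands in for quacc's `EvaluationReport`, and a hypothetical `sample_at_prevalence` helper stands in for the QuaPy sampling protocol; the dataset, classifier, sample size, and prevalence grid are all illustrative, not part of the patch.

```python
# Standalone sketch of the "naive" baseline: report the gap between
# validation metrics and the true metrics on prevalence-shifted test samples.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split


def sample_at_prevalence(X, y, prevalence, size, rng):
    """Draw a binary sample of `size` items with the given positive prevalence
    (hypothetical stand-in for the protocol used in the patch)."""
    n_pos = int(round(prevalence * size))
    pos_idx = rng.choice(np.flatnonzero(y == 1), n_pos, replace=True)
    neg_idx = rng.choice(np.flatnonzero(y == 0), size - n_pos, replace=True)
    idx = np.concatenate([pos_idx, neg_idx])
    return X[idx], y[idx]


rng = np.random.default_rng(0)
X, y = make_classification(n_samples=5000, random_state=0)
X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.5, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_rest, y_rest, test_size=0.5, random_state=0)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Validation metrics: the naive baseline's constant "prediction" of test performance.
val_preds = clf.predict(X_val)
val_acc = accuracy_score(y_val, val_preds)
val_f1 = f1_score(y_val, val_preds)

report = []  # stand-in for EvaluationReport(name="naive")
for prev in np.linspace(0.05, 0.95, 19):  # stand-in for iterating protocol()
    Xs, ys = sample_at_prevalence(X_test, y_test, prev, size=500, rng=rng)
    preds = clf.predict(Xs)
    acc, f1 = accuracy_score(ys, preds), f1_score(ys, preds)
    report.append(
        dict(prevalence=prev, acc_score=acc, f1_score=f1,
             acc=abs(val_acc - acc), f1=abs(val_f1 - f1))
    )

print(report[0])
```

In the patched function the same five quantities are what `report.append_row` records per test sample: the true `acc_score` and `f1_score`, and the absolute errors `acc` and `f1` of the naive constant estimate.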