naive baseline added
This commit is contained in:
parent
ababeb426a
commit
921caaf426
|
@ -71,14 +71,14 @@ test_conf: &test_conf
|
|||
|
||||
main:
|
||||
confs: &main_confs
|
||||
- DATASET_NAME: imdb
|
||||
- DATASET_NAME: rcv1
|
||||
DATASET_TARGET: CCAT
|
||||
other_confs:
|
||||
- DATASET_NAME: imdb
|
||||
- DATASET_NAME: rcv1
|
||||
DATASET_TARGET: GCAT
|
||||
- DATASET_NAME: rcv1
|
||||
DATASET_TARGET: MCAT
|
||||
other_confs:
|
||||
|
||||
sld_lr_conf: &sld_lr_conf
|
||||
|
||||
|
@ -407,4 +407,4 @@ timing_conf: &timing_conf
|
|||
|
||||
confs: *main_confs
|
||||
|
||||
exec: *kde_lr_gs_conf
|
||||
exec: *baselines_conf
|
||||
|
|
|
@ -68,6 +68,38 @@ def kfcv(
|
|||
return report
|
||||
|
||||
|
||||
@baseline
|
||||
def naive(
|
||||
c_model: BaseEstimator,
|
||||
validation: LabelledCollection,
|
||||
protocol: AbstractStochasticSeededProtocol,
|
||||
predict_method="predict",
|
||||
):
|
||||
c_model_predict = getattr(c_model, predict_method)
|
||||
f1_average = "binary" if validation.n_classes == 2 else "macro"
|
||||
|
||||
val_preds = c_model_predict(validation.X)
|
||||
val_acc = metrics.accuracy_score(validation.y, val_preds)
|
||||
val_f1 = metrics.f1_score(validation.y, val_preds, average=f1_average)
|
||||
|
||||
report = EvaluationReport(name="naive")
|
||||
for test in protocol():
|
||||
test_preds = c_model_predict(test.X)
|
||||
acc_score = metrics.accuracy_score(test.y, test_preds)
|
||||
f1_score = metrics.f1_score(test.y, test_preds, average=f1_average)
|
||||
meta_acc = abs(val_acc - acc_score)
|
||||
meta_f1 = abs(val_f1 - f1_score)
|
||||
report.append_row(
|
||||
test.prevalence(),
|
||||
acc_score=acc_score,
|
||||
f1_score=f1_score,
|
||||
acc=meta_acc,
|
||||
f1=meta_f1,
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
@baseline
|
||||
def ref(
|
||||
c_model: BaseEstimator,
|
||||
|
@ -556,4 +588,3 @@ def kdex2(
|
|||
report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc)
|
||||
|
||||
return report
|
||||
|
||||
|
|
Loading…
Reference in New Issue