naive baseline added
This commit is contained in:
parent
ababeb426a
commit
921caaf426
|
@ -71,14 +71,14 @@ test_conf: &test_conf
|
||||||
|
|
||||||
main:
|
main:
|
||||||
confs: &main_confs
|
confs: &main_confs
|
||||||
|
- DATASET_NAME: imdb
|
||||||
- DATASET_NAME: rcv1
|
- DATASET_NAME: rcv1
|
||||||
DATASET_TARGET: CCAT
|
DATASET_TARGET: CCAT
|
||||||
other_confs:
|
|
||||||
- DATASET_NAME: imdb
|
|
||||||
- DATASET_NAME: rcv1
|
- DATASET_NAME: rcv1
|
||||||
DATASET_TARGET: GCAT
|
DATASET_TARGET: GCAT
|
||||||
- DATASET_NAME: rcv1
|
- DATASET_NAME: rcv1
|
||||||
DATASET_TARGET: MCAT
|
DATASET_TARGET: MCAT
|
||||||
|
other_confs:
|
||||||
|
|
||||||
sld_lr_conf: &sld_lr_conf
|
sld_lr_conf: &sld_lr_conf
|
||||||
|
|
||||||
|
@ -407,4 +407,4 @@ timing_conf: &timing_conf
|
||||||
|
|
||||||
confs: *main_confs
|
confs: *main_confs
|
||||||
|
|
||||||
exec: *kde_lr_gs_conf
|
exec: *baselines_conf
|
||||||
|
|
|
@ -68,6 +68,38 @@ def kfcv(
|
||||||
return report
|
return report
|
||||||
|
|
||||||
|
|
||||||
|
@baseline
|
||||||
|
def naive(
|
||||||
|
c_model: BaseEstimator,
|
||||||
|
validation: LabelledCollection,
|
||||||
|
protocol: AbstractStochasticSeededProtocol,
|
||||||
|
predict_method="predict",
|
||||||
|
):
|
||||||
|
c_model_predict = getattr(c_model, predict_method)
|
||||||
|
f1_average = "binary" if validation.n_classes == 2 else "macro"
|
||||||
|
|
||||||
|
val_preds = c_model_predict(validation.X)
|
||||||
|
val_acc = metrics.accuracy_score(validation.y, val_preds)
|
||||||
|
val_f1 = metrics.f1_score(validation.y, val_preds, average=f1_average)
|
||||||
|
|
||||||
|
report = EvaluationReport(name="naive")
|
||||||
|
for test in protocol():
|
||||||
|
test_preds = c_model_predict(test.X)
|
||||||
|
acc_score = metrics.accuracy_score(test.y, test_preds)
|
||||||
|
f1_score = metrics.f1_score(test.y, test_preds, average=f1_average)
|
||||||
|
meta_acc = abs(val_acc - acc_score)
|
||||||
|
meta_f1 = abs(val_f1 - f1_score)
|
||||||
|
report.append_row(
|
||||||
|
test.prevalence(),
|
||||||
|
acc_score=acc_score,
|
||||||
|
f1_score=f1_score,
|
||||||
|
acc=meta_acc,
|
||||||
|
f1=meta_f1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
|
||||||
@baseline
|
@baseline
|
||||||
def ref(
|
def ref(
|
||||||
c_model: BaseEstimator,
|
c_model: BaseEstimator,
|
||||||
|
@ -556,4 +588,3 @@ def kdex2(
|
||||||
report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc)
|
report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc)
|
||||||
|
|
||||||
return report
|
return report
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue