ref baseline fixed
This commit is contained in:
parent
b96432f87b
commit
14326b2122
|
@ -65,11 +65,10 @@ def ref(
|
|||
validation: LabelledCollection,
|
||||
protocol: AbstractStochasticSeededProtocol,
|
||||
):
|
||||
c_model_predict = getattr(c_model, "predict_proba")
|
||||
c_model_predict = getattr(c_model, "predict")
|
||||
report = EvaluationReport(name="ref")
|
||||
for test in protocol():
|
||||
test_probs = c_model_predict(test.X)
|
||||
test_preds = np.argmax(test_probs, axis=-1)
|
||||
test_preds = c_model_predict(test.X)
|
||||
report.append_row(
|
||||
test.prevalence(),
|
||||
acc_score=metrics.accuracy_score(test.y, test_preds),
|
||||
|
|
Loading…
Reference in New Issue