From 5d0ecfda390c76cd441d53f08de6cb22ca5dd43f Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 21 Dec 2023 16:47:20 +0100 Subject: [PATCH] f1 updated --- baselines/atc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/baselines/atc.py b/baselines/atc.py index be93d24..689f8f4 100644 --- a/baselines/atc.py +++ b/baselines/atc.py @@ -38,7 +38,7 @@ def get_ATC_acc(thres, scores): return np.mean(scores >= thres) -def get_ATC_f1(thres, scores, probs): +def get_ATC_f1(thres, scores, probs, average="binary"): preds = np.argmax(probs, axis=-1) estim_y = np.abs(1 - (scores >= thres) ^ preds) - return f1_score(estim_y, preds) + return f1_score(estim_y, preds, average=average)