f1 updated
This commit is contained in:
parent
e01006e663
commit
5d0ecfda39
|
@ -38,7 +38,7 @@ def get_ATC_acc(thres, scores):
|
||||||
return np.mean(scores >= thres)
|
return np.mean(scores >= thres)
|
||||||
|
|
||||||
|
|
||||||
def get_ATC_f1(thres, scores, probs):
|
def get_ATC_f1(thres, scores, probs, average="binary"):
|
||||||
preds = np.argmax(probs, axis=-1)
|
preds = np.argmax(probs, axis=-1)
|
||||||
estim_y = np.abs(1 - (scores >= thres) ^ preds)
|
estim_y = np.abs(1 - (scores >= thres) ^ preds)
|
||||||
return f1_score(estim_y, preds)
|
return f1_score(estim_y, preds, average=average)
|
||||||
|
|
Loading…
Reference in New Issue