Comments fixed
This commit is contained in:
parent
f537ecb5e4
commit
623f6f96f3
|
@ -32,14 +32,11 @@ def atc_mc(
|
||||||
|
|
||||||
## score function, e.g., negative entropy or argmax confidence
|
## score function, e.g., negative entropy or argmax confidence
|
||||||
val_scores = atc.get_max_conf(val_probs)
|
val_scores = atc.get_max_conf(val_probs)
|
||||||
#pred_idxv1 #calib_probsv1/probsv1
|
|
||||||
val_preds = np.argmax(val_probs, axis=-1)
|
val_preds = np.argmax(val_probs, axis=-1)
|
||||||
#pred_probs_new #probs_new
|
|
||||||
test_scores = atc.get_max_conf(test_probs)
|
test_scores = atc.get_max_conf(test_probs)
|
||||||
#pred_probsv1 #labelsv1 #pred_idxv1
|
|
||||||
_, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
|
_, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
|
||||||
#calib_thres_balance #pred_probs_new
|
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
|
||||||
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
|
"true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
|
||||||
|
@ -106,5 +103,5 @@ def doc_feat(
|
||||||
test_scores = np.max(test_probs, axis=-1)
|
test_scores = np.max(test_probs, axis=-1)
|
||||||
val_preds = np.argmax(val_probs, axis=-1)
|
val_preds = np.argmax(val_probs, axis=-1)
|
||||||
|
|
||||||
v1acc = np.mean(val_preds == val_labels)*100
|
v1acc = np.mean(val_preds == val_labels) * 100
|
||||||
return v1acc + doc.get_doc(val_scores, test_scores)
|
return v1acc + doc.get_doc(val_scores, test_scores)
|
||||||
|
|
Loading…
Reference in New Issue