diff --git a/.coverage b/.coverage
index c9d78c7..c3df57d 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/guillory21_doc/__pycache__/doc.cpython-311.pyc b/guillory21_doc/__pycache__/doc.cpython-311.pyc
new file mode 100644
index 0000000..a98676f
Binary files /dev/null and b/guillory21_doc/__pycache__/doc.cpython-311.pyc differ
diff --git a/guillory21_doc/doc.py b/guillory21_doc/doc.py
new file mode 100644
index 0000000..9b59883
--- /dev/null
+++ b/guillory21_doc/doc.py
@@ -0,0 +1,6 @@
+import numpy as np
+
+
+def get_doc(probs1, probs2):
+    """Difference of Confidences (DoC): mean confidence on probs2 minus mean confidence on probs1."""
+    return np.mean(probs2) - np.mean(probs1)
diff --git a/quacc/baseline.py b/quacc/baseline.py
index 32dcacc..c508dce 100644
--- a/quacc/baseline.py
+++ b/quacc/baseline.py
@@ -1,16 +1,12 @@
 from statistics import mean
 from typing import Dict
 from sklearn.base import BaseEstimator
 from sklearn.model_selection import cross_validate
 from quapy.data import LabelledCollection
-from garg22_ATC.ATC_helper import (
-    find_ATC_threshold,
-    get_ATC_acc,
-    get_entropy,
-    get_max_conf,
-)
+import garg22_ATC.ATC_helper as atc
 import numpy as np
-from jiang18_trustscore.trustscore import TrustScore
+import jiang18_trustscore.trustscore as trustscore
+import guillory21_doc.doc as doc
 
 
 def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
@@ -19,7 +15,7 @@ def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
     return {"f1_score": mean(scores["test_f1_macro"])}
 
 
-def ATC_MC(
+def atc_mc(
     c_model: BaseEstimator,
     validation: LabelledCollection,
     test: LabelledCollection,
@@ -34,21 +30,21 @@
     test_probs = c_model_predict(test.X)
 
     ## score function, e.g., negative entropy or argmax confidence
-    val_scores = get_max_conf(val_probs)
+    val_scores = atc.get_max_conf(val_probs)
     val_preds = np.argmax(val_probs, axis=-1)
 
-    test_scores = get_max_conf(test_probs)
+    test_scores = atc.get_max_conf(test_probs)
 
-    _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds)
-    ATC_accuracy = get_ATC_acc(ATC_thres, test_scores)
+    _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
+    atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
 
     return {
         "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
-        "pred_acc": ATC_accuracy,
+        "pred_acc": atc_accuracy,
     }
 
 
-def ATC_NE(
+def atc_ne(
     c_model: BaseEstimator,
     validation: LabelledCollection,
     test: LabelledCollection,
@@ -63,17 +59,17 @@
     test_probs = c_model_predict(test.X)
 
     ## score function, e.g., negative entropy or argmax confidence
-    val_scores = get_entropy(val_probs)
+    val_scores = atc.get_entropy(val_probs)
     val_preds = np.argmax(val_probs, axis=-1)
 
-    test_scores = get_entropy(test_probs)
+    test_scores = atc.get_entropy(test_probs)
 
-    _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds)
-    ATC_accuracy = get_ATC_acc(ATC_thres, test_scores)
+    _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
+    atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
 
     return {
         "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
-        "pred_acc": ATC_accuracy,
+        "pred_acc": atc_accuracy,
     }
 
 
@@ -81,13 +77,32 @@ def trust_score(
     c_model: BaseEstimator,
     validation: LabelledCollection,
     test: LabelledCollection,
     predict_method="predict_proba",
 ):
     c_model_predict = getattr(c_model, predict_method)
 
     test_pred = c_model_predict(test.X)
 
-    trust_model = TrustScore()
+    trust_model = trustscore.TrustScore()
     trust_model.fit(validation.X, validation.y)
 
     return trust_model.get_score(test.X, test_pred)
+
+
+def doc_feat(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    test: LabelledCollection,
+    predict_method="predict_proba",
+):
+    c_model_predict = getattr(c_model, predict_method)
+
+    val_probs, val_labels = c_model_predict(validation.X), validation.y
+    test_probs = c_model_predict(test.X)
+    val_scores = np.max(val_probs, axis=-1)
+    test_scores = np.max(test_probs, axis=-1)
+    val_preds = np.argmax(val_probs, axis=-1)
+
+    # DoC-feat: correct the validation accuracy by the shift in mean confidence
+    v1acc = np.mean(val_preds == val_labels) * 100
+    return v1acc + doc.get_doc(val_scores, test_scores)
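
Usage sketch (outside the patch): a minimal example of how the renamed and newly added baselines might be exercised. It assumes the quacc.baseline module as patched above, a probabilistic scikit-learn classifier, and quapy's LabelledCollection.split_stratified helper; the synthetic dataset and split proportions are illustrative only.

from quapy.data import LabelledCollection
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from quacc.baseline import atc_mc, atc_ne, doc_feat, kfcv, trust_score

# Synthetic binary task; any estimator exposing predict_proba works here.
X, y = make_classification(n_samples=3000, random_state=0)
train, rest = LabelledCollection(X, y).split_stratified(train_prop=0.5)
validation, test = rest.split_stratified(train_prop=0.5)

c_model = LogisticRegression().fit(train.X, train.y)

print(kfcv(c_model, validation))            # {'f1_score': ...}
print(atc_mc(c_model, validation, test))    # {'true_acc': ..., 'pred_acc': ...}
print(atc_ne(c_model, validation, test))    # same keys, entropy-based score
print(doc_feat(c_model, validation, test))  # scalar: predicted accuracy in percent
print(trust_score(c_model, validation, test)[:5])  # per-sample trust scores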