From 531d22573b412c3cf8ba203c395111a97ccdea4e Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Wed, 31 Jan 2024 18:06:58 +0100 Subject: [PATCH] update --- conf.yaml | 10 ++++------ copy_source.sh | 6 ++++-- log | 2 ++ quacc/dataset.py | 6 ++++-- quacc/evaluation/baseline.py | 33 ++++++++++++++++++++++++++++++++- quacc/evaluation/method.py | 7 ++++--- quacc/evaluation/report.py | 8 ++++++++ rates.md | 15 +++++++++++++++ run.py | 4 ++++ 9 files changed, 77 insertions(+), 14 deletions(-) create mode 100644 rates.md diff --git a/conf.yaml b/conf.yaml index feab9ce..e6e7f4b 100644 --- a/conf.yaml +++ b/conf.yaml @@ -71,14 +71,14 @@ test_conf: &test_conf main: confs: &main_confs + - DATASET_NAME: imdb - DATASET_NAME: rcv1 DATASET_TARGET: CCAT - other_confs: - - DATASET_NAME: imdb - DATASET_NAME: rcv1 DATASET_TARGET: GCAT - DATASET_NAME: rcv1 DATASET_TARGET: MCAT + other_confs: sld_lr_conf: &sld_lr_conf @@ -348,9 +348,7 @@ baselines_conf: &baselines_conf COMP_ESTIMATORS: - doc - atc_mc - - mandoline - - rca - - rca_star + - naive N_JOBS: -2 confs: *main_confs @@ -406,4 +404,4 @@ timing_conf: &timing_conf confs: *main_confs -exec: *kde_lr_gs_conf +exec: *baselines_conf diff --git a/copy_source.sh b/copy_source.sh index 8accd61..b52d9a0 100755 --- a/copy_source.sh +++ b/copy_source.sh @@ -1,7 +1,9 @@ #!/bin/bash -CMD="cp" -DEST="~/tesi_docker/" +CMD="scp" +DEST="andreaesuli@edge-nd1.isti.cnr.it:~/raid/lorenzo/" +# CMD="cp" +# DEST="~/tesi_docker/" bash -c "${CMD} -r quacc ${DEST}" bash -c "${CMD} -r baselines ${DEST}" diff --git a/log b/log index 5f9501e..b9948ff 100755 --- a/log +++ b/log @@ -3,6 +3,8 @@ if [[ "${1}" == "r" ]]; then scp volpi@ilona.isti.cnr.it:~/tesi/quacc.log ~/tesi/remote.log &>/dev/null ssh volpi@ilona.isti.cnr.it tail -n 500 -f /home/volpi/tesi/quacc.log | bat -P --language=log +elif [[ "${1}" == "d" ]]; then + ssh andreaesuli@edge-nd1.isti.cnr.it tail -n 500 -f /home/andreaesuli/raid/lorenzo/quacc.log | bat -P --language=log else tail -n 500 -f /home/lorev/tesi/quacc.log | bat --paging=never --language log fi diff --git a/quacc/dataset.py b/quacc/dataset.py index 37f8f8d..9232308 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -126,7 +126,9 @@ class DatasetProvider: # provare min_df=5 def __imdb(self, **kwargs): - return qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=3).train_test + return qp.datasets.fetch_reviews( + "imdb", data_home="./quapy_data", tfidf=True, min_df=3 + ).train_test def __rcv1(self, target, **kwargs): n_train = 23149 @@ -135,7 +137,7 @@ class DatasetProvider: if target is None or target not in available_targets: raise ValueError(f"Invalid target {target}") - dataset = fetch_rcv1() + dataset = fetch_rcv1(data_home="./scikit_learn_data") target_index = np.where(dataset.target_names == target)[0] all_train_d = dataset.data[:n_train, :] test_d = dataset.data[n_train:, :] diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py index 605b5ca..436dd47 100644 --- a/quacc/evaluation/baseline.py +++ b/quacc/evaluation/baseline.py @@ -68,6 +68,38 @@ def kfcv( return report +@baseline +def naive( + c_model: BaseEstimator, + validation: LabelledCollection, + protocol: AbstractStochasticSeededProtocol, + predict_method="predict", +): + c_model_predict = getattr(c_model, predict_method) + f1_average = "binary" if validation.n_classes == 2 else "macro" + + val_preds = c_model_predict(validation.X) + val_acc = metrics.accuracy_score(validation.y, val_preds) + val_f1 = metrics.f1_score(validation.y, val_preds, average=f1_average) + + report = EvaluationReport(name="naive") + for test in protocol(): + test_preds = c_model_predict(test.X) + acc_score = metrics.accuracy_score(test.y, test_preds) + f1_score = metrics.f1_score(test.y, test_preds, average=f1_average) + meta_acc = abs(val_acc - acc_score) + meta_f1 = abs(val_f1 - f1_score) + report.append_row( + test.prevalence(), + acc_score=acc_score, + f1_score=f1_score, + acc=meta_acc, + f1=meta_f1, + ) + + return report + + @baseline def ref( c_model: BaseEstimator, @@ -556,4 +588,3 @@ def kdex2( report.append_row(test.prevalence(), acc=meta_score, acc_score=estim_acc) return report - diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index 86a1fb8..9942b65 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -380,9 +380,9 @@ __kde_lr_set = [ M("mul_kde_lr_a", __kde_lr(), "mul", conf=["max_conf", "entropy", "isoft"], ), M("m3w_kde_lr_a", __kde_lr(), "mul", conf=["max_conf", "entropy", "isoft"], cf=True), # gs kde - G("bin_kde_lr_gs", __kde_lr(), "bin", pg="kde_lr", search="spider" ), - G("mul_kde_lr_gs", __kde_lr(), "mul", pg="kde_lr", search="spider" ), - G("m3w_kde_lr_gs", __kde_lr(), "mul", pg="kde_lr", search="spider", cf=True), + G("bin_kde_lr_gs", __kde_lr(), "bin", pg="kde_lr", search="grid" ), + G("mul_kde_lr_gs", __kde_lr(), "mul", pg="kde_lr", search="grid" ), + G("m3w_kde_lr_gs", __kde_lr(), "mul", pg="kde_lr", search="grid", cf=True), E("kde_lr_gs"), ] @@ -458,6 +458,7 @@ __methods_set = ( + __kde_lr_set + __dense_kde_lr_set + __dense_kde_rbf_set + + [E("QuAcc")] ) _methods = {m.name: m for m in __methods_set} diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index 9f63d9a..bfe394e 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -140,6 +140,14 @@ class CompReport: "mul_kde_lr_gs", "m3w_kde_lr_gs", ], + "QuAcc": [ + "bin_sld_lr_gs", + "mul_sld_lr_gs", + "m3w_sld_lr_gs", + "bin_kde_lr_gs", + "mul_kde_lr_gs", + "m3w_kde_lr_gs", + ], } for name, methods in _mapping.items(): diff --git a/rates.md b/rates.md new file mode 100644 index 0000000..b136fb5 --- /dev/null +++ b/rates.md @@ -0,0 +1,15 @@ +# Additional covariates percentage + +Rate of usage of additional covariates, recalibration and "balanced" class_weight +during grid search: + +| method | av % | recalib % | rebalance % | +| --------------: | :----: | :-------: | :---------: | +| imdb_sld_lr | 81.49% | 77.78% | 59.26% | +| imdb_kde_lr | 71.43% | NA | 88.18% | +| rcv1_CCAT_sld_lr| 62.97% | 70.38% | 77.78% | +| rcv1_CCAT_kde_lr| 78.06% | NA | 84.82% | +| rcv1_GCAT_sld_lr| 76.93% | 61.54% | 65.39% | +| rcv1_GCAT_kde_lr| 71.36% | NA | 78.65% | +| rcv1_MCAT_sld_lr| 62.97% | 48.15% | 74.08% | +| rcv1_MCAT_kde_lr| 71.03% | NA | 68.70% | diff --git a/run.py b/run.py index eddab99..e678ec9 100644 --- a/run.py +++ b/run.py @@ -15,3 +15,7 @@ def run(): run_local() elif args.remote: run_remote(detatch=args.detatch) + + +if __name__ == "__main__": + run()