diff --git a/quacc/experiments/generators.py b/quacc/experiments/generators.py index 17f550f..50e5cc9 100644 --- a/quacc/experiments/generators.py +++ b/quacc/experiments/generators.py @@ -12,10 +12,10 @@ from sklearn.linear_model import LogisticRegression from quacc.dataset import DatasetProvider as DP from quacc.error import macrof1_fn, vanilla_acc_fn -from quacc.experiments.util import getpath from quacc.models.base import ClassifierAccuracyPrediction from quacc.models.baselines import ATC, DoC from quacc.models.cont_table import CAPContingencyTable, ContTableTransferCAP, NaiveCAP +from quacc.utils.commons import get_results_path def gen_classifiers(): @@ -63,17 +63,20 @@ def gen_tweet_datasets( def gen_bin_datasets( only_names=False, ) -> [str, [LabelledCollection, LabelledCollection, LabelledCollection]]: - if only_names: - for dataset_name in ["imdb", "CCAT", "GCAT", "MCAT"]: - yield dataset_name, None - else: - yield "imdb", DP.imdb() - for rcv1_name in [ - "CCAT", - "GCAT", - "MCAT", - ]: - yield rcv1_name, DP.rcv1(rcv1_name) + _IMDB = [ + "imdb", + ] + _RCV1 = [ + # "CCAT", + # "GCAT", + # "MCAT", + ] + for dn in _IMDB: + dval = None if only_names else DP.imdb() + yield dn, dval + for dn in _RCV1: + dval = None if only_names else DP.rcv1(dn) + yield dn, dval def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]: @@ -94,7 +97,7 @@ def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction] def gen_CAP_cont_table(h) -> [str, CAPContingencyTable]: acc_fn = None yield "Naive", NaiveCAP(h, acc_fn) - yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression())) + # yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression())) # yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01)) # yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05)) # yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False) @@ -125,7 +128,7 @@ def gen_acc_measure(): def any_missing(basedir, cls_name, dataset_name, method_name): for acc_name, _ in gen_acc_measure(): if not os.path.exists( - getpath(basedir, cls_name, acc_name, dataset_name, method_name) + get_results_path(basedir, cls_name, acc_name, dataset_name, method_name) ): return True return False