improved binary dataset generator

This commit is contained in:
Lorenzo Volpi 2024-04-05 15:52:29 +02:00
parent f787c4510d
commit 5ee04a2a19
1 changed file with 17 additions and 14 deletions

View File

@@ -12,10 +12,10 @@ from sklearn.linear_model import LogisticRegression
from quacc.dataset import DatasetProvider as DP
from quacc.error import macrof1_fn, vanilla_acc_fn
from quacc.experiments.util import getpath
from quacc.models.base import ClassifierAccuracyPrediction
from quacc.models.baselines import ATC, DoC
from quacc.models.cont_table import CAPContingencyTable, ContTableTransferCAP, NaiveCAP
from quacc.utils.commons import get_results_path
def gen_classifiers():
@@ -63,17 +63,20 @@ def gen_tweet_datasets(
def gen_bin_datasets(
    only_names=False,
) -> [str, [LabelledCollection, LabelledCollection, LabelledCollection]]:
    """Yield ``(name, dataset)`` pairs for the enabled binary datasets.

    Args:
        only_names: when true, no dataset is loaded and ``None`` is yielded
            in place of the dataset, so callers can enumerate the names
            without invoking the ``DP`` loaders.

    Yields:
        A ``(dataset_name, dataset_or_None)`` tuple per enabled dataset:
        first the names listed in ``_IMDB`` (loaded with ``DP.imdb()``),
        then the names listed in ``_RCV1`` (loaded with ``DP.rcv1(name)``).
    """
    _IMDB = [
        "imdb",
    ]
    # The RCV1 categories are currently switched off; uncomment an entry to
    # have it yielded (and loaded through DP.rcv1) again.
    _RCV1 = [
        # "CCAT",
        # "GCAT",
        # "MCAT",
    ]

    for dn in _IMDB:
        # Loading is skipped entirely when only the names are requested.
        dval = None if only_names else DP.imdb()
        yield dn, dval

    for dn in _RCV1:
        dval = None if only_names else DP.rcv1(dn)
        yield dn, dval
def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]:
@@ -94,7 +97,7 @@ def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]
def gen_CAP_cont_table(h) -> [str, CAPContingencyTable]:
    """Yield ``(name, method)`` pairs for the contingency-table CAP methods.

    Args:
        h: the classifier handed to each method's constructor.

    Yields:
        ``("Naive", NaiveCAP(h, acc_fn))`` with ``acc_fn`` set to ``None``.
        The other candidate methods below are commented out and therefore
        not produced.
    """
    acc_fn = None
    yield "Naive", NaiveCAP(h, acc_fn)
    # Disabled methods, kept for reference:
    # yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
    # yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
    # yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
    # yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
@@ -125,7 +128,7 @@ def gen_acc_measure():
def any_missing(basedir, cls_name, dataset_name, method_name):
    """Return whether any result path for this configuration does not exist.

    For every accuracy-measure name produced by ``gen_acc_measure()``, the
    path is obtained from ``get_results_path(basedir, cls_name, acc_name,
    dataset_name, method_name)`` and checked with ``os.path.exists``.

    Returns:
        ``True`` as soon as one such path is missing, ``False`` when all of
        them exist (or when ``gen_acc_measure()`` yields nothing).
    """
    return any(
        not os.path.exists(
            get_results_path(basedir, cls_name, acc_name, dataset_name, method_name)
        )
        for acc_name, _ in gen_acc_measure()
    )