improved binary dataset generator

This commit is contained in:
Lorenzo Volpi 2024-04-05 15:52:29 +02:00
parent f787c4510d
commit 5ee04a2a19
1 changed files with 17 additions and 14 deletions

View File

@ -12,10 +12,10 @@ from sklearn.linear_model import LogisticRegression
from quacc.dataset import DatasetProvider as DP from quacc.dataset import DatasetProvider as DP
from quacc.error import macrof1_fn, vanilla_acc_fn from quacc.error import macrof1_fn, vanilla_acc_fn
from quacc.experiments.util import getpath
from quacc.models.base import ClassifierAccuracyPrediction from quacc.models.base import ClassifierAccuracyPrediction
from quacc.models.baselines import ATC, DoC from quacc.models.baselines import ATC, DoC
from quacc.models.cont_table import CAPContingencyTable, ContTableTransferCAP, NaiveCAP from quacc.models.cont_table import CAPContingencyTable, ContTableTransferCAP, NaiveCAP
from quacc.utils.commons import get_results_path
def gen_classifiers(): def gen_classifiers():
@ -63,17 +63,20 @@ def gen_tweet_datasets(
def gen_bin_datasets( def gen_bin_datasets(
only_names=False, only_names=False,
) -> [str, [LabelledCollection, LabelledCollection, LabelledCollection]]: ) -> [str, [LabelledCollection, LabelledCollection, LabelledCollection]]:
if only_names: _IMDB = [
for dataset_name in ["imdb", "CCAT", "GCAT", "MCAT"]: "imdb",
yield dataset_name, None ]
else: _RCV1 = [
yield "imdb", DP.imdb() # "CCAT",
for rcv1_name in [ # "GCAT",
"CCAT", # "MCAT",
"GCAT", ]
"MCAT", for dn in _IMDB:
]: dval = None if only_names else DP.imdb()
yield rcv1_name, DP.rcv1(rcv1_name) yield dn, dval
for dn in _RCV1:
dval = None if only_names else DP.rcv1(dn)
yield dn, dval
def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]: def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]:
@ -94,7 +97,7 @@ def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]
def gen_CAP_cont_table(h) -> [str, CAPContingencyTable]: def gen_CAP_cont_table(h) -> [str, CAPContingencyTable]:
acc_fn = None acc_fn = None
yield "Naive", NaiveCAP(h, acc_fn) yield "Naive", NaiveCAP(h, acc_fn)
yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression())) # yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
# yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01)) # yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
# yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05)) # yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
# yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False) # yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
@ -125,7 +128,7 @@ def gen_acc_measure():
def any_missing(basedir, cls_name, dataset_name, method_name): def any_missing(basedir, cls_name, dataset_name, method_name):
for acc_name, _ in gen_acc_measure(): for acc_name, _ in gen_acc_measure():
if not os.path.exists( if not os.path.exists(
getpath(basedir, cls_name, acc_name, dataset_name, method_name) get_results_path(basedir, cls_name, acc_name, dataset_name, method_name)
): ):
return True return True
return False return False