improved binary dataset generator
This commit is contained in:
parent
f787c4510d
commit
5ee04a2a19
|
@ -12,10 +12,10 @@ from sklearn.linear_model import LogisticRegression
|
|||
|
||||
from quacc.dataset import DatasetProvider as DP
|
||||
from quacc.error import macrof1_fn, vanilla_acc_fn
|
||||
from quacc.experiments.util import getpath
|
||||
from quacc.models.base import ClassifierAccuracyPrediction
|
||||
from quacc.models.baselines import ATC, DoC
|
||||
from quacc.models.cont_table import CAPContingencyTable, ContTableTransferCAP, NaiveCAP
|
||||
from quacc.utils.commons import get_results_path
|
||||
|
||||
|
||||
def gen_classifiers():
|
||||
|
@ -63,17 +63,20 @@ def gen_tweet_datasets(
|
|||
def gen_bin_datasets(
    only_names=False,
) -> "[str, [LabelledCollection, LabelledCollection, LabelledCollection]]":
    """Yield ``(name, dataset)`` pairs for the enabled binary datasets.

    The return annotation is kept as a string: it is an informal description
    of the yielded pairs rather than a real type, so it should not be
    evaluated when the function is defined.

    Args:
        only_names: When true, no dataset is loaded and ``None`` is yielded in
            place of each dataset.

    Yields:
        ``(dataset_name, dataset)``, where ``dataset`` is whatever
        ``DP.imdb()`` or ``DP.rcv1(dataset_name)`` returns, or ``None`` when
        ``only_names`` is true.
    """
    _IMDB = [
        "imdb",
    ]
    # The RCV1 categories are currently switched off; uncomment to re-enable.
    _RCV1 = [
        # "CCAT",
        # "GCAT",
        # "MCAT",
    ]
    for dn in _IMDB:
        # Only touch the dataset provider when the data is actually wanted.
        dval = None if only_names else DP.imdb()
        yield dn, dval
    for dn in _RCV1:
        dval = None if only_names else DP.rcv1(dn)
        yield dn, dval
|
||||
|
||||
|
||||
def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]:
|
||||
|
@ -94,7 +97,7 @@ def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]
|
|||
def gen_CAP_cont_table(h) -> [str, CAPContingencyTable]:
|
||||
acc_fn = None
|
||||
yield "Naive", NaiveCAP(h, acc_fn)
|
||||
yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
|
||||
# yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
|
||||
# yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
|
||||
# yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
|
||||
# yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
|
||||
|
@ -125,7 +128,7 @@ def gen_acc_measure():
|
|||
def any_missing(basedir, cls_name, dataset_name, method_name):
    """Return True if a results file is missing for any accuracy measure.

    For each measure name produced by ``gen_acc_measure()``, the path built
    by ``get_results_path(basedir, cls_name, acc_name, dataset_name,
    method_name)`` is checked with ``os.path.exists``.

    Returns:
        ``True`` as soon as one path does not exist, ``False`` if all exist.
    """
    for acc_name, _ in gen_acc_measure():
        results_path = get_results_path(
            basedir, cls_name, acc_name, dataset_name, method_name
        )
        if not os.path.exists(results_path):
            return True
    return False
|
||||
|
|
Loading…
Reference in New Issue