improved binary dataset generator
This commit is contained in:
parent
f787c4510d
commit
5ee04a2a19
|
@@ -12,10 +12,10 @@ from sklearn.linear_model import LogisticRegression
|
||||||
|
|
||||||
from quacc.dataset import DatasetProvider as DP
|
from quacc.dataset import DatasetProvider as DP
|
||||||
from quacc.error import macrof1_fn, vanilla_acc_fn
|
from quacc.error import macrof1_fn, vanilla_acc_fn
|
||||||
from quacc.experiments.util import getpath
|
|
||||||
from quacc.models.base import ClassifierAccuracyPrediction
|
from quacc.models.base import ClassifierAccuracyPrediction
|
||||||
from quacc.models.baselines import ATC, DoC
|
from quacc.models.baselines import ATC, DoC
|
||||||
from quacc.models.cont_table import CAPContingencyTable, ContTableTransferCAP, NaiveCAP
|
from quacc.models.cont_table import CAPContingencyTable, ContTableTransferCAP, NaiveCAP
|
||||||
|
from quacc.utils.commons import get_results_path
|
||||||
|
|
||||||
|
|
||||||
def gen_classifiers():
|
def gen_classifiers():
|
||||||
|
@@ -63,17 +63,20 @@ def gen_tweet_datasets(
|
||||||
def gen_bin_datasets(
|
def gen_bin_datasets(
|
||||||
only_names=False,
|
only_names=False,
|
||||||
) -> [str, [LabelledCollection, LabelledCollection, LabelledCollection]]:
|
) -> [str, [LabelledCollection, LabelledCollection, LabelledCollection]]:
|
||||||
if only_names:
|
_IMDB = [
|
||||||
for dataset_name in ["imdb", "CCAT", "GCAT", "MCAT"]:
|
"imdb",
|
||||||
yield dataset_name, None
|
]
|
||||||
else:
|
_RCV1 = [
|
||||||
yield "imdb", DP.imdb()
|
# "CCAT",
|
||||||
for rcv1_name in [
|
# "GCAT",
|
||||||
"CCAT",
|
# "MCAT",
|
||||||
"GCAT",
|
]
|
||||||
"MCAT",
|
for dn in _IMDB:
|
||||||
]:
|
dval = None if only_names else DP.imdb()
|
||||||
yield rcv1_name, DP.rcv1(rcv1_name)
|
yield dn, dval
|
||||||
|
for dn in _RCV1:
|
||||||
|
dval = None if only_names else DP.rcv1(dn)
|
||||||
|
yield dn, dval
|
||||||
|
|
||||||
|
|
||||||
def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]:
|
def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]:
|
||||||
|
@@ -94,7 +97,7 @@ def gen_CAP(h, acc_fn, with_oracle=False) -> [str, ClassifierAccuracyPrediction]
|
||||||
def gen_CAP_cont_table(h) -> [str, CAPContingencyTable]:
|
def gen_CAP_cont_table(h) -> [str, CAPContingencyTable]:
|
||||||
acc_fn = None
|
acc_fn = None
|
||||||
yield "Naive", NaiveCAP(h, acc_fn)
|
yield "Naive", NaiveCAP(h, acc_fn)
|
||||||
yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
|
# yield "CT-PPS-EMQ", ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
|
||||||
# yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
|
# yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
|
||||||
# yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
|
# yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
|
||||||
# yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
|
# yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
|
||||||
|
@@ -125,7 +128,7 @@ def gen_acc_measure():
|
||||||
def any_missing(basedir, cls_name, dataset_name, method_name):
|
def any_missing(basedir, cls_name, dataset_name, method_name):
|
||||||
for acc_name, _ in gen_acc_measure():
|
for acc_name, _ in gen_acc_measure():
|
||||||
if not os.path.exists(
|
if not os.path.exists(
|
||||||
getpath(basedir, cls_name, acc_name, dataset_name, method_name)
|
get_results_path(basedir, cls_name, acc_name, dataset_name, method_name)
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
Loading…
Reference in New Issue