imports fixed, added control variables

This commit is contained in:
Lorenzo Volpi 2024-04-08 17:56:41 +02:00
parent 5bb66b85c8
commit b17ae5e45d
1 changed file with 86 additions and 72 deletions

View File

@@ -15,6 +15,7 @@ from quacc.experiments.generators import (
gen_multi_datasets, gen_multi_datasets,
gen_tweet_datasets, gen_tweet_datasets,
) )
from quacc.experiments.plotting import save_plot_delta, save_plot_diagonal
from quacc.experiments.report import Report, TestReport from quacc.experiments.report import Report, TestReport
from quacc.experiments.util import ( from quacc.experiments.util import (
fit_method, fit_method,
@@ -27,6 +28,8 @@ from quacc.experiments.util import (
PROBLEM = "binary" PROBLEM = "binary"
ORACLE = False ORACLE = False
basedir = PROBLEM + ("-oracle" if ORACLE else "") basedir = PROBLEM + ("-oracle" if ORACLE else "")
EXPERIMENT = True
PLOTTING = True
if PROBLEM == "binary": if PROBLEM == "binary":
@@ -43,6 +46,7 @@ elif PROBLEM == "tweet":
gen_datasets = gen_tweet_datasets gen_datasets = gen_tweet_datasets
if EXPERIMENT:
for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product( for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
gen_classifiers(), gen_datasets() gen_classifiers(), gen_datasets()
): ):
@@ -68,7 +72,15 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
for acc_name, acc_fn in gen_acc_measure(): for acc_name, acc_fn in gen_acc_measure():
print(f"\tfor measure {acc_name}") print(f"\tfor measure {acc_name}")
for method_name, method in gen_CAP(h, acc_fn, with_oracle=ORACLE): for method_name, method in gen_CAP(h, acc_fn, with_oracle=ORACLE):
report = TestReport(basedir, cls_name, acc_name, dataset_name, method_name) report = TestReport(
basedir=basedir,
cls_name=cls_name,
acc_name=acc_name,
dataset_name=dataset_name,
method_name=method_name,
train_prev=L.prevalence().tolist(),
val_prev=V.prevalence().tolist(),
)
if os.path.exists(report.path): if os.path.exists(report.path):
print(f"\t\t{method_name}-{acc_name} exists, skipping") print(f"\t\t{method_name}-{acc_name} exists, skipping")
continue continue
@@ -102,7 +114,9 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
method, test_prot, gen_acc_measure, ORACLE method, test_prot, gen_acc_measure, ORACLE
) )
for acc_name, estim_accs in estim_accs_dict.items(): for acc_name, estim_accs in estim_accs_dict.items():
report = TestReport(basedir, cls_name, acc_name, dataset_name, method_name) report = TestReport(
basedir, cls_name, acc_name, dataset_name, method_name
)
test_prevs = prevs_from_prot(test_prot) test_prevs = prevs_from_prot(test_prot)
report.add_result( report.add_result(
test_prevs=test_prevs, test_prevs=test_prevs,
@@ -115,17 +129,17 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
print() print()
# generate plots # generate plots
print("generating plots") if PLOTTING:
rep = Report.load_results(basedir) for (cls_name, _), (acc_name, _) in itertools.product(
for rs in rep.results: gen_classifiers(), gen_acc_measure()
print(rs.path) ):
save_plot_diagonal(basedir, cls_name, acc_name)
# for (cls_name, _), (acc_name, _) in itertools.product( for dataset_name, _ in gen_datasets(only_names=True):
# gen_classifiers(), gen_acc_measure() save_plot_diagonal(basedir, cls_name, acc_name, dataset_name=dataset_name)
# ): save_plot_delta(basedir, cls_name, acc_name, dataset_name=dataset_name)
# plot_diagonal(basedir, cls_name, acc_name) save_plot_delta(
# for dataset_name, _ in gen_datasets(only_names=True): basedir, cls_name, acc_name, dataset_name=dataset_name, stdev=True
# plot_diagonal(basedir, cls_name, acc_name, dataset_name=dataset_name) )
# print("generating tables") # print("generating tables")
# gen_tables(basedir, datasets=[d for d, _ in gen_datasets(only_names=True)]) # gen_tables(basedir, datasets=[d for d, _ in gen_datasets(only_names=True)])