imports fixed, added control variables
This commit is contained in:
parent
5bb66b85c8
commit
b17ae5e45d
|
@ -15,6 +15,7 @@ from quacc.experiments.generators import (
|
||||||
gen_multi_datasets,
|
gen_multi_datasets,
|
||||||
gen_tweet_datasets,
|
gen_tweet_datasets,
|
||||||
)
|
)
|
||||||
|
from quacc.experiments.plotting import save_plot_delta, save_plot_diagonal
|
||||||
from quacc.experiments.report import Report, TestReport
|
from quacc.experiments.report import Report, TestReport
|
||||||
from quacc.experiments.util import (
|
from quacc.experiments.util import (
|
||||||
fit_method,
|
fit_method,
|
||||||
|
@ -27,6 +28,8 @@ from quacc.experiments.util import (
|
||||||
PROBLEM = "binary"
|
PROBLEM = "binary"
|
||||||
ORACLE = False
|
ORACLE = False
|
||||||
basedir = PROBLEM + ("-oracle" if ORACLE else "")
|
basedir = PROBLEM + ("-oracle" if ORACLE else "")
|
||||||
|
EXPERIMENT = True
|
||||||
|
PLOTTING = True
|
||||||
|
|
||||||
|
|
||||||
if PROBLEM == "binary":
|
if PROBLEM == "binary":
|
||||||
|
@ -43,6 +46,7 @@ elif PROBLEM == "tweet":
|
||||||
gen_datasets = gen_tweet_datasets
|
gen_datasets = gen_tweet_datasets
|
||||||
|
|
||||||
|
|
||||||
|
if EXPERIMENT:
|
||||||
for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
|
for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
|
||||||
gen_classifiers(), gen_datasets()
|
gen_classifiers(), gen_datasets()
|
||||||
):
|
):
|
||||||
|
@ -68,7 +72,15 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
|
||||||
for acc_name, acc_fn in gen_acc_measure():
|
for acc_name, acc_fn in gen_acc_measure():
|
||||||
print(f"\tfor measure {acc_name}")
|
print(f"\tfor measure {acc_name}")
|
||||||
for method_name, method in gen_CAP(h, acc_fn, with_oracle=ORACLE):
|
for method_name, method in gen_CAP(h, acc_fn, with_oracle=ORACLE):
|
||||||
report = TestReport(basedir, cls_name, acc_name, dataset_name, method_name)
|
report = TestReport(
|
||||||
|
basedir=basedir,
|
||||||
|
cls_name=cls_name,
|
||||||
|
acc_name=acc_name,
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
method_name=method_name,
|
||||||
|
train_prev=L.prevalence().tolist(),
|
||||||
|
val_prev=V.prevalence().tolist(),
|
||||||
|
)
|
||||||
if os.path.exists(report.path):
|
if os.path.exists(report.path):
|
||||||
print(f"\t\t{method_name}-{acc_name} exists, skipping")
|
print(f"\t\t{method_name}-{acc_name} exists, skipping")
|
||||||
continue
|
continue
|
||||||
|
@ -102,7 +114,9 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
|
||||||
method, test_prot, gen_acc_measure, ORACLE
|
method, test_prot, gen_acc_measure, ORACLE
|
||||||
)
|
)
|
||||||
for acc_name, estim_accs in estim_accs_dict.items():
|
for acc_name, estim_accs in estim_accs_dict.items():
|
||||||
report = TestReport(basedir, cls_name, acc_name, dataset_name, method_name)
|
report = TestReport(
|
||||||
|
basedir, cls_name, acc_name, dataset_name, method_name
|
||||||
|
)
|
||||||
test_prevs = prevs_from_prot(test_prot)
|
test_prevs = prevs_from_prot(test_prot)
|
||||||
report.add_result(
|
report.add_result(
|
||||||
test_prevs=test_prevs,
|
test_prevs=test_prevs,
|
||||||
|
@ -115,17 +129,17 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# generate plots
|
# generate plots
|
||||||
print("generating plots")
|
if PLOTTING:
|
||||||
rep = Report.load_results(basedir)
|
for (cls_name, _), (acc_name, _) in itertools.product(
|
||||||
for rs in rep.results:
|
gen_classifiers(), gen_acc_measure()
|
||||||
print(rs.path)
|
):
|
||||||
|
save_plot_diagonal(basedir, cls_name, acc_name)
|
||||||
# for (cls_name, _), (acc_name, _) in itertools.product(
|
for dataset_name, _ in gen_datasets(only_names=True):
|
||||||
# gen_classifiers(), gen_acc_measure()
|
save_plot_diagonal(basedir, cls_name, acc_name, dataset_name=dataset_name)
|
||||||
# ):
|
save_plot_delta(basedir, cls_name, acc_name, dataset_name=dataset_name)
|
||||||
# plot_diagonal(basedir, cls_name, acc_name)
|
save_plot_delta(
|
||||||
# for dataset_name, _ in gen_datasets(only_names=True):
|
basedir, cls_name, acc_name, dataset_name=dataset_name, stdev=True
|
||||||
# plot_diagonal(basedir, cls_name, acc_name, dataset_name=dataset_name)
|
)
|
||||||
|
|
||||||
# print("generating tables")
|
# print("generating tables")
|
||||||
# gen_tables(basedir, datasets=[d for d, _ in gen_datasets(only_names=True)])
|
# gen_tables(basedir, datasets=[d for d, _ in gen_datasets(only_names=True)])
|
||||||
|
|
Loading…
Reference in New Issue