diff --git a/qcdash/app.py b/qcdash/app.py index cf4f21f..c8124bc 100644 --- a/qcdash/app.py +++ b/qcdash/app.py @@ -12,10 +12,12 @@ import numpy as np from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html from dash.dash_table.Format import Align, Format, Scheme -from quacc import plot -from quacc.evaluation.estimators import CE, _renames -from quacc.evaluation.report import CompReport, DatasetReport -from quacc.evaluation.stats import wilcoxon +from quacc.experiments.report import Report +from quacc.experiments.util import get_acc_name +from quacc.legacy.evaluation.estimators import CE, _renames +from quacc.legacy.evaluation.report import CompReport, DatasetReport +from quacc.legacy.evaluation.stats import wilcoxon +from quacc.plot.plotly import plot_delta, plot_diagonal, plot_shift valid_plot_modes = defaultdict(lambda: CompReport._default_modes) valid_plot_modes["avg"] = DatasetReport._default_dr_modes @@ -74,29 +76,61 @@ def get_datasets(root: str | Path) -> List[DatasetReport]: return {str(drp.parent): load_dataset(drp) for drp in dr_paths} -def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None): - _backend = backend or plot.get_backend("plotly") - estimators = CE.name[estimators] +def get_fig(rep: Report, dataset, metric, estimators, view, mode): match (view, mode): - case ("avg", _): - return dr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="plotly", - save_fig=False, - backend=_backend, + case ("avg", "diagonal"): + true_accs, estim_accs = rep.diagonal_plot_data( + dataset_name=dataset, + method_names=estimators, + acc_name=metric, + ) + return plot_diagonal( + method_names=estimators, + true_accs=true_accs, + estim_accs=estim_accs, + measure_name=metric, + ) + case ("avg", "delta_train"): + prevs, acc_errs = rep.delta_train_plot_data( + dataset_name=dataset, + method_names=estimators, + acc_name=metric, + ) + return plot_delta( + method_names=estimators, + prevs=prevs, + acc_errs=acc_errs, + measure_name=metric, + prev_name="Test", + ) + case ("avg", "stdev_train"): + prevs, acc_errs, stdevs = rep.delta_train_plot_data( + dataset_name=dataset, + method_names=estimators, + acc_name=metric, + stdev=True, + ) + return plot_delta( + method_names=estimators, + prevs=prevs, + acc_errs=acc_errs, + measure_name=metric, + prev_name="Test", + stdevs=stdevs, + ) + case ("avg", "shift"): + prevs, acc_errs, counts = rep.shift_plot_data( + dataset_name=dataset, method_names=estimators, acc_name=metric + ) + return plot_shift( + method_names=estimators, + prevs=prevs, + acc_errs=acc_errs, + measure_name=metric, + counts=counts, ) case (_, _): - cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)] - return cr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="plotly", - save_fig=False, - backend=_backend, - ) + return None def get_table(dr: DatasetReport, metric, estimators, view, mode):