From 322f060a13f7ab88128279c1a4f473816d2e555d Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Fri, 5 Apr 2024 15:49:57 +0200 Subject: [PATCH] qcdash updated for plotting refactoring --- qcdash/app.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/qcdash/app.py b/qcdash/app.py index c8124bc..a3df930 100644 --- a/qcdash/app.py +++ b/qcdash/app.py @@ -76,57 +76,65 @@ def get_datasets(root: str | Path) -> List[DatasetReport]: return {str(drp.parent): load_dataset(drp) for drp in dr_paths} -def get_fig(rep: Report, dataset, metric, estimators, view, mode): +def get_fig(rep: Report, cls_name, acc_name, dataset_name, estimators, view, mode): match (view, mode): case ("avg", "diagonal"): true_accs, estim_accs = rep.diagonal_plot_data( - dataset_name=dataset, + dataset_name=dataset_name, method_names=estimators, - acc_name=metric, + acc_name=acc_name, ) return plot_diagonal( method_names=estimators, true_accs=true_accs, estim_accs=estim_accs, - measure_name=metric, + cls_name=cls_name, + acc_name=acc_name, + dataset_name=dataset_name, ) case ("avg", "delta_train"): prevs, acc_errs = rep.delta_train_plot_data( - dataset_name=dataset, + dataset_name=dataset_name, method_names=estimators, - acc_name=metric, + acc_name=acc_name, ) return plot_delta( method_names=estimators, prevs=prevs, acc_errs=acc_errs, - measure_name=metric, + cls_name=cls_name, + acc_mame=acc_name, + dataset_name=dataset_name, prev_name="Test", ) case ("avg", "stdev_train"): prevs, acc_errs, stdevs = rep.delta_train_plot_data( - dataset_name=dataset, + dataset_name=dataset_name, method_names=estimators, - acc_name=metric, + acc_name=acc_name, stdev=True, ) return plot_delta( method_names=estimators, prevs=prevs, acc_errs=acc_errs, - measure_name=metric, + cls_name=cls_name, + acc_mame=acc_name, + dataset_name=dataset_name, prev_name="Test", stdevs=stdevs, ) case ("avg", "shift"): prevs, acc_errs, counts = rep.shift_plot_data( - dataset_name=dataset, method_names=estimators, acc_name=metric + dataset_name=dataset_name, method_names=estimators, acc_name=acc_name ) return plot_shift( method_names=estimators, prevs=prevs, acc_errs=acc_errs, - measure_name=metric, + cls_name=cls_name, + acc_name=acc_name, + dataset_name=dataset_name, counts=counts, ) case (_, _): @@ -543,7 +551,7 @@ def update_content(dataset, metric, estimators, view, mode, root): case _: fig = get_fig( dr=dr, - metric=metric, + acc_name=metric, estimators=estimators, view=view, mode=mode,