From d7bb8bb2b994aa0cc5fbce8552fb7f8164441d8d Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Mon, 8 Apr 2024 17:57:25 +0200 Subject: [PATCH] plot saving methods added --- quacc/experiments/plotting.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 quacc/experiments/plotting.py diff --git a/quacc/experiments/plotting.py b/quacc/experiments/plotting.py new file mode 100644 index 0000000..1334bd4 --- /dev/null +++ b/quacc/experiments/plotting.py @@ -0,0 +1,60 @@ +import numpy as np + +from quacc.experiments.generators import get_method_names +from quacc.experiments.report import Report +from quacc.plot.matplotlib import plot_delta, plot_diagonal + + +def save_plot_diagonal( + basedir, cls_name, acc_name, dataset_name="*", report: Report = None +): + methods = get_method_names() + report = ( + Report.load_results( + basedir, + cls_name, + acc_name, + dataset_name=dataset_name, + method_name=methods, + ) + if report is None + else report + ) + _methods, _true_accs, _estim_accs = report.diagonal_plot_data() + plot_diagonal( + method_names=_methods, + true_accs=_true_accs, + estim_accs=_estim_accs, + cls_name=cls_name, + acc_name=acc_name, + dataset_name=dataset_name, + basedir=basedir, + ) + + +def save_plot_delta( + basedir, cls_name, acc_name, dataset_name="*", stdev=False, report: Report = None +): + methods = get_method_names() + report = ( + Report.load_results( + basedir, + cls_name, + acc_name, + dataset_name=dataset_name, + method_name=methods, + ) + if report is None + else report + ) + _methods, _prevs, _acc_errs, _stdevs = report.delta_plot_data(stdev=stdev) + plot_delta( + method_names=_methods, + prevs=_prevs, + acc_errs=_acc_errs, + cls_name=cls_name, + acc_name=acc_name, + dataset_name=dataset_name, + basedir=basedir, + stdevs=_stdevs, + )