From bd0c15b178672d053eaad0bceb48a8b88a748b56 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Mon, 8 Apr 2024 18:16:25 +0200 Subject: [PATCH] bug found in delta_plot_data, get_shift method added --- quacc/experiments/report.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/quacc/experiments/report.py b/quacc/experiments/report.py index 9d38444..4f90745 100644 --- a/quacc/experiments/report.py +++ b/quacc/experiments/report.py @@ -6,9 +6,17 @@ from pathlib import Path import numpy as np import pandas as pd +from quacc.error import nae from quacc.utils.commons import get_results_path, load_json_file, save_json_file +def _get_shift(index: np.ndarray, train_prev: np.ndarray): + index = np.array([np.array(tp) for tp in index]) + train_prevs = np.tile(train_prev, (index.shape[0], 1)) + _shift = nae(index, train_prevs) + return np.around(_shift, decimals=2) + + class TestReport: def __init__( self, @@ -141,7 +149,9 @@ class Report: stdevs = None if stdev is None else [] for _method, _results in self.results.items(): methods.append(_method) - _prevs = np.array([_r.test_prevs for _r in _results]).flatten() + _prevs = np.array( + [_r.test_prevs for _r in _results] + ).flatten() # should not be flattened, check this _true_accs = np.array([_r.true_accs for _r in _results]).flatten() _estim_accs = np.array([_r.estim_accs for _r in _results]).flatten() _acc_errs = np.abs(_true_accs - _estim_accs) @@ -157,5 +167,5 @@ class Report: return methods, prevs, acc_errs, stdevs - def shift_plot_data(self, dataset_name, method_names, acc_name): + def shift_plot_data(self): pass