diff --git a/quacc/experiments/report.py b/quacc/experiments/report.py index 9d38444..4f90745 100644 --- a/quacc/experiments/report.py +++ b/quacc/experiments/report.py @@ -6,9 +6,17 @@ from pathlib import Path import numpy as np import pandas as pd +from quacc.error import nae from quacc.utils.commons import get_results_path, load_json_file, save_json_file +def _get_shift(index: np.ndarray, train_prev: np.ndarray): + index = np.array([np.array(tp) for tp in index]) + train_prevs = np.tile(train_prev, (index.shape[0], 1)) + _shift = nae(index, train_prevs) + return np.around(_shift, decimals=2) + + class TestReport: def __init__( self, @@ -141,7 +149,9 @@ class Report: stdevs = None if stdev is None else [] for _method, _results in self.results.items(): methods.append(_method) - _prevs = np.array([_r.test_prevs for _r in _results]).flatten() + _prevs = np.array( + [_r.test_prevs for _r in _results] + ).flatten() # should not be flattened, check this _true_accs = np.array([_r.true_accs for _r in _results]).flatten() _estim_accs = np.array([_r.estim_accs for _r in _results]).flatten() _acc_errs = np.abs(_true_accs - _estim_accs) @@ -157,5 +167,5 @@ class Report: return methods, prevs, acc_errs, stdevs - def shift_plot_data(self, dataset_name, method_names, acc_name): + def shift_plot_data(self): pass