bug found in delta_plot_data, get_shift method added

This commit is contained in:
Lorenzo Volpi 2024-04-08 18:16:25 +02:00
parent e0fec320cf
commit bd0c15b178
1 changed files with 12 additions and 2 deletions

View File

@ -6,9 +6,17 @@ from pathlib import Path
import numpy as np
import pandas as pd
from quacc.error import nae
from quacc.utils.commons import get_results_path, load_json_file, save_json_file
def _get_shift(index: np.ndarray, train_prev: np.ndarray):
index = np.array([np.array(tp) for tp in index])
train_prevs = np.tile(train_prev, (index.shape[0], 1))
_shift = nae(index, train_prevs)
return np.around(_shift, decimals=2)
class TestReport:
def __init__(
self,
@ -141,7 +149,9 @@ class Report:
stdevs = None if stdev is None else []
for _method, _results in self.results.items():
methods.append(_method)
_prevs = np.array([_r.test_prevs for _r in _results]).flatten()
_prevs = np.array(
[_r.test_prevs for _r in _results]
).flatten() # should not be flattened, check this
_true_accs = np.array([_r.true_accs for _r in _results]).flatten()
_estim_accs = np.array([_r.estim_accs for _r in _results]).flatten()
_acc_errs = np.abs(_true_accs - _estim_accs)
@ -157,5 +167,5 @@ class Report:
return methods, prevs, acc_errs, stdevs
def shift_plot_data(self, dataset_name, method_names, acc_name):
def shift_plot_data(self):
pass