bug found in delta_plot_data, get_shift method added
This commit is contained in:
parent
e0fec320cf
commit
bd0c15b178
|
@ -6,9 +6,17 @@ from pathlib import Path
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from quacc.error import nae
|
||||
from quacc.utils.commons import get_results_path, load_json_file, save_json_file
|
||||
|
||||
|
||||
def _get_shift(index: np.ndarray, train_prev: np.ndarray):
|
||||
index = np.array([np.array(tp) for tp in index])
|
||||
train_prevs = np.tile(train_prev, (index.shape[0], 1))
|
||||
_shift = nae(index, train_prevs)
|
||||
return np.around(_shift, decimals=2)
|
||||
|
||||
|
||||
class TestReport:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -141,7 +149,9 @@ class Report:
|
|||
stdevs = None if stdev is None else []
|
||||
for _method, _results in self.results.items():
|
||||
methods.append(_method)
|
||||
_prevs = np.array([_r.test_prevs for _r in _results]).flatten()
|
||||
_prevs = np.array(
|
||||
[_r.test_prevs for _r in _results]
|
||||
).flatten() # should not be flattened, check this
|
||||
_true_accs = np.array([_r.true_accs for _r in _results]).flatten()
|
||||
_estim_accs = np.array([_r.estim_accs for _r in _results]).flatten()
|
||||
_acc_errs = np.abs(_true_accs - _estim_accs)
|
||||
|
@ -157,5 +167,5 @@ class Report:
|
|||
|
||||
return methods, prevs, acc_errs, stdevs
|
||||
|
||||
def shift_plot_data(self, dataset_name, method_names, acc_name):
|
||||
def shift_plot_data(self):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue