laod results fixed, data added, methods for plot data added
This commit is contained in:
parent
d7bb8bb2b9
commit
5cfd5d87dd
|
@ -1,4 +1,10 @@
|
|||
import os
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from quacc.utils.commons import get_results_path, load_json_file, save_json_file
|
||||
|
||||
|
@ -11,12 +17,16 @@ class TestReport:
|
|||
acc_name,
|
||||
dataset_name,
|
||||
method_name,
|
||||
train_prev,
|
||||
val_prev,
|
||||
):
|
||||
self.basedir = basedir
|
||||
self.cls_name = cls_name
|
||||
self.acc_name = acc_name
|
||||
self.dataset_name = dataset_name
|
||||
self.method_name = method_name
|
||||
self.train_prev = train_prev
|
||||
self.val_prev = val_prev
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
|
@ -46,6 +56,8 @@ class TestReport:
|
|||
"acc_name": self.acc_name,
|
||||
"dataset_name": self.dataset_name,
|
||||
"method_name": self.method_name,
|
||||
"train_prev": self.train_prev,
|
||||
"val_prev": self.val_prev,
|
||||
"test_prevs": self.test_prevs,
|
||||
"true_accs": self.true_accs,
|
||||
"estim_accs": self.estim_accs,
|
||||
|
@ -64,6 +76,8 @@ class TestReport:
|
|||
acc_name=_dict["acc_name"],
|
||||
dataset_name=_dict["dataset_name"],
|
||||
method_name=_dict["method_name"],
|
||||
train_prev=_dict["train_prev"],
|
||||
val_prev=_dict["val_prev"],
|
||||
).add_result(
|
||||
test_prevs=_dict["test_prevs"],
|
||||
true_accs=_dict["true_accs"],
|
||||
|
@ -76,37 +90,27 @@ class TestReport:
|
|||
|
||||
|
||||
class Report:
|
||||
def __init__(self, results: list[TestReport]):
|
||||
def __init__(self, results: dict[str, list[TestReport]]):
|
||||
self.results = results
|
||||
|
||||
@classmethod
|
||||
def load_results(cls, basedir):
|
||||
def walk_results(path):
|
||||
results = []
|
||||
if not os.path.exists(path):
|
||||
return results
|
||||
|
||||
for f in os.listdir(path):
|
||||
n_path = os.path.join(path, f)
|
||||
if os.path.isdir(n_path):
|
||||
results += walk_results(n_path)
|
||||
if os.path.isfile(n_path) and n_path.endswith(".json"):
|
||||
results.append(TestReport.load_json(n_path))
|
||||
|
||||
return results
|
||||
|
||||
_path = os.path.join("results", basedir)
|
||||
_results = walk_results(_path)
|
||||
return Report(results=_results)
|
||||
|
||||
def _filter_by_dataset(self):
|
||||
pass
|
||||
|
||||
def _filer_by_acc(self):
|
||||
pass
|
||||
|
||||
def _filter_by_methods(self):
|
||||
pass
|
||||
def load_results(
|
||||
cls, basedir, cls_name, acc_name, dataset_name="*", method_name="*"
|
||||
) -> "Report":
|
||||
_results = defaultdict(lambda: [])
|
||||
if isinstance(method_name, str):
|
||||
method_name = [method_name]
|
||||
if isinstance(dataset_name, str):
|
||||
dataset_name = [dataset_name]
|
||||
for dataset_, method_ in itertools.product(dataset_name, method_name):
|
||||
path = get_results_path(basedir, cls_name, acc_name, dataset_, method_)
|
||||
for file in glob(path):
|
||||
if file.endswith(".json"):
|
||||
# print(file)
|
||||
method = Path(file).stem
|
||||
_res = TestReport.load_json(file)
|
||||
_results[method].append(_res)
|
||||
return Report(_results)
|
||||
|
||||
def train_table(self):
|
||||
pass
|
||||
|
@ -116,3 +120,42 @@ class Report:
|
|||
|
||||
def shift_table(self):
|
||||
pass
|
||||
|
||||
def diagonal_plot_data(self):
|
||||
methods = []
|
||||
true_accs = []
|
||||
estim_accs = []
|
||||
for _method, _results in self.results.items():
|
||||
methods.append(_method)
|
||||
_true_acc = np.array([_r.true_accs for _r in _results]).flatten()
|
||||
_estim_acc = np.array([_r.estim_accs for _r in _results]).flatten()
|
||||
true_accs.append(_true_acc)
|
||||
estim_accs.append(_estim_acc)
|
||||
|
||||
return methods, true_accs, estim_accs
|
||||
|
||||
def delta_plot_data(self, stdev=False):
|
||||
methods = []
|
||||
prevs = []
|
||||
acc_errs = []
|
||||
stdevs = None if stdev is None else []
|
||||
for _method, _results in self.results.items():
|
||||
methods.append(_method)
|
||||
_prevs = np.array([_r.test_prevs for _r in _results]).flatten()
|
||||
_true_accs = np.array([_r.true_accs for _r in _results]).flatten()
|
||||
_estim_accs = np.array([_r.estim_accs for _r in _results]).flatten()
|
||||
_acc_errs = np.abs(_true_accs - _estim_accs)
|
||||
df = pd.DataFrame(
|
||||
np.array([_prevs, _acc_errs]).T, columns=["prevs", "errs"]
|
||||
)
|
||||
df_acc_errs = df.groupby(["prevs"]).mean().reset_index()
|
||||
prevs.append(df_acc_errs["prevs"].to_numpy())
|
||||
acc_errs.append(df_acc_errs["errs"].to_numpy())
|
||||
if stdev:
|
||||
df_stdevs = df.groupby(["prevs"]).std().reset_index()
|
||||
stdevs.append(df_stdevs["errs"].to_numpy())
|
||||
|
||||
return methods, prevs, acc_errs, stdevs
|
||||
|
||||
def shift_plot_data(self, dataset_name, method_names, acc_name):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue