diff --git a/quacc/experiments/report.py b/quacc/experiments/report.py index 60824f3..9d38444 100644 --- a/quacc/experiments/report.py +++ b/quacc/experiments/report.py @@ -1,4 +1,10 @@ -import os +import itertools +from collections import defaultdict +from glob import glob +from pathlib import Path + +import numpy as np +import pandas as pd from quacc.utils.commons import get_results_path, load_json_file, save_json_file @@ -11,12 +17,16 @@ class TestReport: acc_name, dataset_name, method_name, + train_prev, + val_prev, ): self.basedir = basedir self.cls_name = cls_name self.acc_name = acc_name self.dataset_name = dataset_name self.method_name = method_name + self.train_prev = train_prev + self.val_prev = val_prev @property def path(self): @@ -46,6 +56,8 @@ class TestReport: "acc_name": self.acc_name, "dataset_name": self.dataset_name, "method_name": self.method_name, + "train_prev": self.train_prev, + "val_prev": self.val_prev, "test_prevs": self.test_prevs, "true_accs": self.true_accs, "estim_accs": self.estim_accs, @@ -64,6 +76,8 @@ class TestReport: acc_name=_dict["acc_name"], dataset_name=_dict["dataset_name"], method_name=_dict["method_name"], + train_prev=_dict["train_prev"], + val_prev=_dict["val_prev"], ).add_result( test_prevs=_dict["test_prevs"], true_accs=_dict["true_accs"], @@ -76,37 +90,27 @@ class TestReport: class Report: - def __init__(self, results: list[TestReport]): + def __init__(self, results: dict[str, list[TestReport]]): self.results = results @classmethod - def load_results(cls, basedir): - def walk_results(path): - results = [] - if not os.path.exists(path): - return results - - for f in os.listdir(path): - n_path = os.path.join(path, f) - if os.path.isdir(n_path): - results += walk_results(n_path) - if os.path.isfile(n_path) and n_path.endswith(".json"): - results.append(TestReport.load_json(n_path)) - - return results - - _path = os.path.join("results", basedir) - _results = walk_results(_path) - return Report(results=_results) - - def _filter_by_dataset(self): - pass - - def _filer_by_acc(self): - pass - - def _filter_by_methods(self): - pass + def load_results( + cls, basedir, cls_name, acc_name, dataset_name="*", method_name="*" + ) -> "Report": + _results = defaultdict(lambda: []) + if isinstance(method_name, str): + method_name = [method_name] + if isinstance(dataset_name, str): + dataset_name = [dataset_name] + for dataset_, method_ in itertools.product(dataset_name, method_name): + path = get_results_path(basedir, cls_name, acc_name, dataset_, method_) + for file in glob(path): + if file.endswith(".json"): + # print(file) + method = Path(file).stem + _res = TestReport.load_json(file) + _results[method].append(_res) + return Report(_results) def train_table(self): pass @@ -116,3 +120,42 @@ class Report: def shift_table(self): pass + + def diagonal_plot_data(self): + methods = [] + true_accs = [] + estim_accs = [] + for _method, _results in self.results.items(): + methods.append(_method) + _true_acc = np.array([_r.true_accs for _r in _results]).flatten() + _estim_acc = np.array([_r.estim_accs for _r in _results]).flatten() + true_accs.append(_true_acc) + estim_accs.append(_estim_acc) + + return methods, true_accs, estim_accs + + def delta_plot_data(self, stdev=False): + methods = [] + prevs = [] + acc_errs = [] + stdevs = None if stdev is None else [] + for _method, _results in self.results.items(): + methods.append(_method) + _prevs = np.array([_r.test_prevs for _r in _results]).flatten() + _true_accs = np.array([_r.true_accs for _r in _results]).flatten() + _estim_accs = np.array([_r.estim_accs for _r in _results]).flatten() + _acc_errs = np.abs(_true_accs - _estim_accs) + df = pd.DataFrame( + np.array([_prevs, _acc_errs]).T, columns=["prevs", "errs"] + ) + df_acc_errs = df.groupby(["prevs"]).mean().reset_index() + prevs.append(df_acc_errs["prevs"].to_numpy()) + acc_errs.append(df_acc_errs["errs"].to_numpy()) + if stdev: + df_stdevs = df.groupby(["prevs"]).std().reset_index() + stdevs.append(df_stdevs["errs"].to_numpy()) + + return methods, prevs, acc_errs, stdevs + + def shift_plot_data(self, dataset_name, method_names, acc_name): + pass