QuAcc/quacc/experiments/report.py

102 lines
2.7 KiB
Python

import os
from quacc.experiments.util import getpath
from quacc.utils.commons import load_json_file, save_json_file
class TestReport:
    """Results of one accuracy-estimation method on one experimental setting.

    A report is identified by the classifier, accuracy-measure, dataset and
    method names passed to the constructor. The measured results are attached
    afterwards with :meth:`add_result`. A report is persisted with
    :meth:`save_json` and restored with :meth:`load_json`.
    """

    # Attributes set by ``add_result``. All of them must be present before a
    # report can be serialized.
    _RESULT_ATTRS = ("test_prevs", "true_accs", "estim_accs", "t_train", "t_test_ave")

    def __init__(
        self,
        cls_name,
        acc_name,
        dataset_name,
        method_name,
    ):
        self.cls_name = cls_name
        self.acc_name = acc_name
        self.dataset_name = dataset_name
        self.method_name = method_name

    def path(self, basedir):
        """Return the location of this report under *basedir*.

        The path is built by ``getpath`` from *basedir* followed by the
        classifier, accuracy, dataset and method names, in that order.
        """
        return getpath(
            basedir, self.cls_name, self.acc_name, self.dataset_name, self.method_name
        )

    def add_result(self, test_prevs, true_accs, estim_accs, t_train, t_test_ave):
        """Attach measured results to this report and return ``self``.

        All values are stored as given, without conversion.

        Args:
            test_prevs: presumably the prevalences of the test samples.
            true_accs: presumably the accuracies measured on the test samples.
            estim_accs: presumably the accuracies estimated by the method.
            t_train: presumably the training time (units not shown here).
            t_test_ave: presumably the average test time (units not shown here).

        Returns:
            This report, so construction and ``add_result`` can be chained.
        """
        self.test_prevs = test_prevs
        self.true_accs = true_accs
        self.estim_accs = estim_accs
        self.t_train = t_train
        self.t_test_ave = t_test_ave
        return self

    def to_dict(self):
        """Return the mapping that :meth:`save_json` writes to disk.

        Values are included as stored. Whether they are JSON-serializable
        depends on what was passed to :meth:`add_result` and on
        ``save_json_file``.

        Raises:
            AttributeError: if any result attribute is missing, i.e.
                :meth:`add_result` has not been called.
        """
        if not all(hasattr(self, _attr) for _attr in self._RESULT_ATTRS):
            raise AttributeError("Incomplete report cannot be dumped")
        return {
            "cls_name": self.cls_name,
            "acc_name": self.acc_name,
            "dataset_name": self.dataset_name,
            "method_name": self.method_name,
            "test_prevs": self.test_prevs,
            "t_train": self.t_train,
            "t_test_ave": self.t_test_ave,
            "true_accs": self.true_accs,
            "estim_accs": self.estim_accs,
        }

    @classmethod
    def from_dict(cls, data) -> "TestReport":
        """Build a report from a mapping with the keys produced by :meth:`to_dict`.

        A missing ``"test_prevs"`` key is tolerated and stored as ``None``.
        Every other key is required (``KeyError`` otherwise).
        """
        return cls(
            cls_name=data["cls_name"],
            acc_name=data["acc_name"],
            dataset_name=data["dataset_name"],
            method_name=data["method_name"],
        ).add_result(
            test_prevs=data.get("test_prevs"),
            true_accs=data["true_accs"],
            estim_accs=data["estim_accs"],
            t_train=data["t_train"],
            t_test_ave=data["t_test_ave"],
        )

    def save_json(self, basedir):
        """Write this report to ``self.path(basedir)`` via ``save_json_file``.

        Raises:
            AttributeError: if the report is incomplete (see :meth:`to_dict`).
        """
        result = self.to_dict()
        result_path = self.path(basedir)
        save_json_file(result_path, result)

    @classmethod
    def load_json(cls, path) -> "TestReport":
        """Load a report previously written by :meth:`save_json` from *path*."""

        def _test_report_hook(_dict):
            # Assuming load_json_file forwards object_hook to the json module,
            # the hook is called for every decoded JSON object. Only the
            # top-level report carries the identifying keys, so any nested
            # dict is passed through unchanged.
            if "cls_name" not in _dict:
                return _dict
            return cls.from_dict(_dict)

        return load_json_file(path, object_hook=_test_report_hook)
class Report:
    """Collection of ``TestReport`` objects, typically loaded from one directory."""

    def __init__(self, tests: "list[TestReport]"):
        self.tests = tests

    @classmethod
    def load_tests(cls, path) -> "Report":
        """Load every ``*.json`` file located directly in *path*.

        Files are loaded with ``TestReport.load_json`` in sorted filename
        order, so the result does not depend on the file system's listing
        order. Directories whose names end in ``.json`` are skipped.

        Raises:
            ValueError: if *path* is not a directory.
        """
        if not os.path.isdir(path):
            raise ValueError("Cannot load test results: invalid directory")
        _tests = []
        for fname in sorted(os.listdir(path)):
            # os.listdir returns bare names. Join them with *path* so loading
            # does not depend on the current working directory.
            fpath = os.path.join(path, fname)
            if fname.endswith(".json") and os.path.isfile(fpath):
                _tests.append(TestReport.load_json(fpath))
        return cls(_tests)

    # TODO: the methods below are placeholders. They are not implemented yet
    # and return None.
    def _filter_by_dataset(self):
        """Placeholder: not implemented yet."""

    def _filer_by_acc(self):
        """Placeholder: not implemented yet (name kept for existing callers)."""

    # Correctly spelled alias of ``_filer_by_acc``.
    _filter_by_acc = _filer_by_acc

    def _filter_by_methods(self):
        """Placeholder: not implemented yet."""

    def train_table(self):
        """Placeholder: not implemented yet."""

    def test_table(self):
        """Placeholder: not implemented yet."""

    def shift_table(self):
        """Placeholder: not implemented yet."""