diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index 4c30d40..ba28d6d 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd from quacc import plot -from quacc.environment import env from quacc.utils import fmt_line_md @@ -145,6 +144,14 @@ class CompReport: avg_p.loc["avg", :] = f_data.mean() return avg_p + def shift_table( + self, metric: str = None, estimators: List[str] = None + ) -> pd.DataFrame: + f_data = self.shift_data(metric=metric, estimators=estimators) + avg_p = f_data.groupby(level=0).mean() + avg_p.loc["avg", :] = f_data.mean() + return avg_p + def get_plots( self, mode="delta", @@ -152,6 +159,7 @@ class CompReport: estimators=None, conf="default", return_fig=False, + base_path=None, ) -> List[Tuple[str, Path]]: if mode == "delta": avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) @@ -163,6 +171,7 @@ class CompReport: name=conf, train_prev=self.train_prev, return_fig=return_fig, + base_path=base_path, ) elif mode == "delta_stdev": avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) @@ -176,6 +185,7 @@ class CompReport: train_prev=self.train_prev, stdevs=st_data.T.to_numpy(), return_fig=return_fig, + base_path=base_path, ) elif mode == "diagonal": f_data = self.data(metric=metric + "_score", estimators=estimators) @@ -189,6 +199,7 @@ class CompReport: name=conf, train_prev=self.train_prev, return_fig=return_fig, + base_path=base_path, ) elif mode == "shift": _shift_data = self.shift_data(metric=metric, estimators=estimators) @@ -207,30 +218,44 @@ class CompReport: train_prev=self.train_prev, counts=shift_counts.T.to_numpy(), return_fig=return_fig, + base_path=base_path, ) - def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str: + def to_md( + self, + conf="default", + metric="acc", + estimators=None, + modes=["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"], + plot_path=None, + ) -> str: res = f"## {int(np.around(self.train_prev, decimals=2)[1]*100)}% positives\n" res += fmt_line_md(f"train: {str(self.train_prev)}") res += fmt_line_md(f"validation: {str(self.valid_prev)}") for k, v in self.times.items(): res += fmt_line_md(f"{k}: {v:.3f}s") res += "\n" - res += self.table(metric=metric, estimators=estimators).to_html() + "\n\n" + if "table" in modes: + res += "### table\n" + res += self.table(metric=metric, estimators=estimators).to_html() + "\n\n" + if "shift_table" in modes: + res += "### shift table\n" + res += ( + self.shift_table(metric=metric, estimators=estimators).to_html() + + "\n\n" + ) - plot_modes = np.array(["delta", "diagonal", "shift"], dtype="object") - if stdev: - whd = np.where(plot_modes == "delta")[0] - if len(whd) > 0: - plot_modes = np.insert(plot_modes, whd + 1, "delta_stdev") + plot_modes = [m for m in modes if m not in ["table", "shift_table"]] for mode in plot_modes: + res += f"### {mode}\n" op = self.get_plots( mode=mode, metric=metric, estimators=estimators, conf=conf, + base_path=plot_path, ) - res += f"![plot_{mode}]({op.relative_to(env.OUT_DIR).as_posix()})\n" + res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n" return res @@ -304,6 +329,7 @@ class DatasetReport: estimators=None, conf="default", return_fig=False, + base_path=None, ): if mode == "delta_train": _data = self.data(metric, estimators) if data is None else data @@ -320,6 +346,7 @@ class DatasetReport: train_prev=None, avg="train", return_fig=return_fig, + base_path=base_path, ) elif mode == "stdev_train": _data = self.data(metric, estimators) if data is None else data @@ -338,6 +365,7 @@ class DatasetReport: stdevs=stdev_on_train.T.to_numpy(), avg="train", return_fig=return_fig, + base_path=base_path, ) elif mode == "delta_test": _data = self.data(metric, estimators) if data is None else data @@ -352,6 +380,7 @@ class DatasetReport: train_prev=None, avg="test", return_fig=return_fig, + base_path=base_path, ) elif mode == "stdev_test": _data = self.data(metric, estimators) if data is None else data @@ -368,6 +397,7 @@ class DatasetReport: stdevs=stdev_on_test.T.to_numpy(), avg="test", return_fig=return_fig, + base_path=base_path, ) elif mode == "shift": _shift_data = self.shift_data(metric, estimators) if data is None else data @@ -383,12 +413,37 @@ class DatasetReport: train_prev=None, counts=count_shift.T.to_numpy(), return_fig=return_fig, + base_path=base_path, ) - def to_md(self, conf="default", metric="acc", estimators=[], stdev=False): + def to_md( + self, + conf="default", + metric="acc", + estimators=[], + dr_modes=[ + "delta_train", + "stdev_train", + "delta_test", + "stdev_test", + "shift", + "train_table", + "test_table", + "shift_table", + ], + cr_modes=[ + "delta", + "delta_stdev", + "diagonal", + "shift", + "table", + "shift_table", + ], + plot_path=None, + ): res = f"# {self.name}\n\n" for cr in self.crs: - res += f"{cr.to_md(conf, metric=metric, estimators=estimators, stdev=stdev)}\n\n" + res += f"{cr.to_md(conf, metric=metric, estimators=estimators, modes=cr_modes, plot_path=plot_path)}\n\n" _data = self.data(metric=metric, estimators=estimators) _shift_data = self.shift_data(metric=metric, estimators=estimators) @@ -398,68 +453,81 @@ class DatasetReport: ######################## avg on train ######################## res += "### avg on train\n" - avg_on_train_tbl = _data.groupby(level=1).mean() - avg_on_train_tbl.loc["avg", :] = _data.mean() + if "train_table" in dr_modes: + avg_on_train_tbl = _data.groupby(level=1).mean() + avg_on_train_tbl.loc["avg", :] = _data.mean() + res += avg_on_train_tbl.to_html() + "\n\n" - res += avg_on_train_tbl.to_html() + "\n\n" + if "delta_train" in dr_modes: + delta_op = self.get_plots( + data=_data, + mode="delta_train", + metric=metric, + estimators=estimators, + conf=conf, + base_path=plot_path, + ) + res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" - delta_op = self.get_plots( - data=_data, - mode="delta_train", - metric=metric, - estimators=estimators, - conf=conf, - ) - res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n" - - if stdev: + if "stdev_train" in dr_modes: delta_stdev_op = self.get_plots( data=_data, mode="stdev_train", metric=metric, estimators=estimators, conf=conf, + base_path=plot_path, ) - res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n" + res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" ######################## avg on test ######################## res += "### avg on test\n" - avg_on_test_tbl = _data.groupby(level=0).mean() - avg_on_test_tbl.loc["avg", :] = _data.mean() + if "test_table" in dr_modes: + avg_on_test_tbl = _data.groupby(level=0).mean() + avg_on_test_tbl.loc["avg", :] = _data.mean() + res += avg_on_test_tbl.to_html() + "\n\n" - res += avg_on_test_tbl.to_html() + "\n\n" + if "delta_test" in dr_modes: + delta_op = self.get_plots( + data=_data, + mode="delta_test", + metric=metric, + estimators=estimators, + conf=conf, + base_path=plot_path, + ) + res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" - delta_op = self.get_plots( - data=_data, - mode="delta_test", - metric=metric, - estimators=estimators, - conf=conf, - ) - res += f"![plot_delta]({delta_op.relative_to(env.OUT_DIR).as_posix()})\n" - - if stdev: + if "stdev_test" in dr_modes: delta_stdev_op = self.get_plots( data=_data, mode="stdev_test", metric=metric, estimators=estimators, conf=conf, + base_path=plot_path, ) - res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(env.OUT_DIR).as_posix()})\n" + res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" ######################## avg shift ######################## res += "### avg dataset shift\n" - shift_op = self.get_plots( - data=_shift_data, - mode="shift", - metric=metric, - estimators=estimators, - conf=conf, - ) - res += f"![plot_shift]({shift_op.relative_to(env.OUT_DIR).as_posix()})\n" + if "shift_table" in dr_modes: + shift_on_train_tbl = _shift_data.groupby(level=0).mean() + shift_on_train_tbl.loc["avg", :] = _shift_data.mean() + res += shift_on_train_tbl.to_html() + "\n\n" + + if "shift" in dr_modes: + shift_op = self.get_plots( + data=_shift_data, + mode="shift", + metric=metric, + estimators=estimators, + conf=conf, + base_path=plot_path, + ) + res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n" return res