diff --git a/qcpanel/util.py b/qcpanel/util.py index 3286c2f..01d3abc 100644 --- a/qcpanel/util.py +++ b/qcpanel/util.py @@ -13,7 +13,7 @@ valid_plot_modes = defaultdict(lambda: CompReport._default_modes) valid_plot_modes["avg"] = DatasetReport._default_dr_modes -def create_plots( +def create_plot( dr: DatasetReport, mode="delta", metric="acc", @@ -24,28 +24,7 @@ def create_plots( estimators = CE.name[estimators] if mode is None: mode = valid_plot_modes[plot_view][0] - _dpi = 112 match (plot_view, mode): - case ("avg", "train_table"): - _data = ( - dr.data(metric=metric, estimators=estimators).groupby(level=1).mean() - ) - return pn.pane.DataFrame(_data, align="center") if not _data.empty else None - case ("avg", "test_table"): - _data = ( - dr.data(metric=metric, estimators=estimators).groupby(level=0).mean() - ) - return pn.pane.DataFrame(_data, align="center") if not _data.empty else None - case ("avg", "shift_table"): - _data = ( - dr.shift_data(metric=metric, estimators=estimators) - .groupby(level=0) - .mean() - ) - return pn.pane.DataFrame(_data, align="center") if not _data.empty else None - case ("avg", "stats_table"): - _data = wilcoxon(dr, metric=metric, estimators=estimators) - return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case ("avg", _ as plot_mode): _plot = dr.get_plots( mode=mode, @@ -54,36 +33,6 @@ def create_plots( conf="panel", save_fig=False, ) - return ( - pn.pane.Matplotlib( - _plot, - tight=True, - format="png", - # sizing_mode="scale_height", - sizing_mode=_plot_sizing_mode, - # sizing_mode="scale_both", - ) - if _plot is not None - else None - ) - case (_, "train_table"): - cr = dr.crs[_prevs.index(int(plot_view))] - _data = ( - cr.data(metric=metric, estimators=estimators).groupby(level=0).mean() - ) - return pn.pane.DataFrame(_data, align="center") if not _data.empty else None - case (_, "shift_table"): - cr = dr.crs[_prevs.index(int(plot_view))] - _data = ( - cr.shift_data(metric=metric, estimators=estimators) - .groupby(level=0) - .mean() - ) - return pn.pane.DataFrame(_data, align="center") if not _data.empty else None - case (_, "stats_table"): - cr = dr.crs[_prevs.index(int(plot_view))] - _data = wilcoxon(cr, metric=metric, estimators=estimators) - return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case (_, _ as plot_mode): cr = dr.crs[_prevs.index(int(plot_view))] _plot = cr.get_plots( @@ -93,18 +42,92 @@ def create_plots( conf="panel", save_fig=False, ) - return ( - pn.pane.Matplotlib( - _plot, - tight=True, - format="png", - sizing_mode=_plot_sizing_mode, - # sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - if _plot is not None - else None + if _plot is None: + return None + + return pn.pane.Matplotlib( + _plot, + tight=True, + format="png", + # sizing_mode="scale_height", + sizing_mode=_plot_sizing_mode, + styles=dict(margin="0"), + # sizing_mode="scale_both", + ) + + +def create_table( + dr: DatasetReport, + mode="delta", + metric="acc", + estimators=None, + plot_view=None, +): + _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] + estimators = CE.name[estimators] + if mode is None: + mode = valid_plot_modes[plot_view][0] + match (plot_view, mode): + case ("avg", "train_table"): + _data = ( + dr.data(metric=metric, estimators=estimators).groupby(level=1).mean() ) + case ("avg", "test_table"): + _data = ( + dr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + ) + case ("avg", "shift_table"): + _data = ( + dr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + case ("avg", "stats_table"): + _data = wilcoxon(dr, metric=metric, estimators=estimators) + case (_, "train_table"): + cr = dr.crs[_prevs.index(int(plot_view))] + _data = ( + cr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + ) + case (_, "shift_table"): + cr = dr.crs[_prevs.index(int(plot_view))] + _data = ( + cr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + case (_, "stats_table"): + cr = dr.crs[_prevs.index(int(plot_view))] + _data = wilcoxon(cr, metric=metric, estimators=estimators) + + return ( + pn.Column( + pn.pane.DataFrame( + _data, + align="center", + float_format=lambda v: f"{v:6e}", + styles={"font-size-adjust": "0.62"}, + ), + sizing_mode="stretch_both", + # scroll=True, + ) + if not _data.empty + else None + ) + + +def create_result( + dr: DatasetReport, + mode="delta", + metric="acc", + estimators=None, + plot_view=None, +): + match mode: + case m if m.endswith("table"): + return create_table(dr, mode, metric, estimators, plot_view) + case _: + return create_plot(dr, mode, metric, estimators, plot_view) def explore_datasets(root: Path | str): diff --git a/qcpanel/viewer.py b/qcpanel/viewer.py index 6d7abfa..76440d4 100644 --- a/qcpanel/viewer.py +++ b/qcpanel/viewer.py @@ -6,7 +6,7 @@ import pandas as pd import panel as pn import param -from qcpanel.util import create_plots, explore_datasets, valid_plot_modes +from qcpanel.util import create_result, explore_datasets, valid_plot_modes from quacc.evaluation.estimators import CE from quacc.evaluation.report import DatasetReport @@ -365,7 +365,7 @@ class QuaccTestViewer(param.Parameterized): self.plot_pane = __svg else: _dr = self.datasets_[self.dataset] - __plot = create_plots( + __plot = create_result( _dr, mode=self.mode, metric=self.metric,