panel refactored

This commit is contained in:
Lorenzo Volpi 2023-12-01 12:41:51 +01:00
parent e69e8381e3
commit 2aa77d9e19
2 changed files with 88 additions and 65 deletions

View File

@ -13,7 +13,7 @@ valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes valid_plot_modes["avg"] = DatasetReport._default_dr_modes
def create_plots( def create_plot(
dr: DatasetReport, dr: DatasetReport,
mode="delta", mode="delta",
metric="acc", metric="acc",
@ -24,28 +24,7 @@ def create_plots(
estimators = CE.name[estimators] estimators = CE.name[estimators]
if mode is None: if mode is None:
mode = valid_plot_modes[plot_view][0] mode = valid_plot_modes[plot_view][0]
_dpi = 112
match (plot_view, mode): match (plot_view, mode):
case ("avg", "train_table"):
_data = (
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", "test_table"):
_data = (
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", "shift_table"):
_data = (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", "stats_table"):
_data = wilcoxon(dr, metric=metric, estimators=estimators)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", _ as plot_mode): case ("avg", _ as plot_mode):
_plot = dr.get_plots( _plot = dr.get_plots(
mode=mode, mode=mode,
@ -54,36 +33,6 @@ def create_plots(
conf="panel", conf="panel",
save_fig=False, save_fig=False,
) )
return (
pn.pane.Matplotlib(
_plot,
tight=True,
format="png",
# sizing_mode="scale_height",
sizing_mode=_plot_sizing_mode,
# sizing_mode="scale_both",
)
if _plot is not None
else None
)
case (_, "train_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = (
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case (_, "shift_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = (
cr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case (_, "stats_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = wilcoxon(cr, metric=metric, estimators=estimators)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case (_, _ as plot_mode): case (_, _ as plot_mode):
cr = dr.crs[_prevs.index(int(plot_view))] cr = dr.crs[_prevs.index(int(plot_view))]
_plot = cr.get_plots( _plot = cr.get_plots(
@ -93,18 +42,92 @@ def create_plots(
conf="panel", conf="panel",
save_fig=False, save_fig=False,
) )
return ( if _plot is None:
pn.pane.Matplotlib( return None
_plot,
tight=True, return pn.pane.Matplotlib(
format="png", _plot,
sizing_mode=_plot_sizing_mode, tight=True,
# sizing_mode="scale_height", format="png",
# sizing_mode="scale_both", # sizing_mode="scale_height",
) sizing_mode=_plot_sizing_mode,
if _plot is not None styles=dict(margin="0"),
else None # sizing_mode="scale_both",
)
def create_table(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
plot_view=None,
):
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
estimators = CE.name[estimators]
if mode is None:
mode = valid_plot_modes[plot_view][0]
match (plot_view, mode):
case ("avg", "train_table"):
_data = (
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
) )
case ("avg", "test_table"):
_data = (
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
)
case ("avg", "shift_table"):
_data = (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
case ("avg", "stats_table"):
_data = wilcoxon(dr, metric=metric, estimators=estimators)
case (_, "train_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = (
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
)
case (_, "shift_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = (
cr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
case (_, "stats_table"):
cr = dr.crs[_prevs.index(int(plot_view))]
_data = wilcoxon(cr, metric=metric, estimators=estimators)
return (
pn.Column(
pn.pane.DataFrame(
_data,
align="center",
float_format=lambda v: f"{v:6e}",
styles={"font-size-adjust": "0.62"},
),
sizing_mode="stretch_both",
# scroll=True,
)
if not _data.empty
else None
)
def create_result(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
plot_view=None,
):
match mode:
case m if m.endswith("table"):
return create_table(dr, mode, metric, estimators, plot_view)
case _:
return create_plot(dr, mode, metric, estimators, plot_view)
def explore_datasets(root: Path | str): def explore_datasets(root: Path | str):

View File

@ -6,7 +6,7 @@ import pandas as pd
import panel as pn import panel as pn
import param import param
from qcpanel.util import create_plots, explore_datasets, valid_plot_modes from qcpanel.util import create_result, explore_datasets, valid_plot_modes
from quacc.evaluation.estimators import CE from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport from quacc.evaluation.report import DatasetReport
@ -365,7 +365,7 @@ class QuaccTestViewer(param.Parameterized):
self.plot_pane = __svg self.plot_pane = __svg
else: else:
_dr = self.datasets_[self.dataset] _dr = self.datasets_[self.dataset]
__plot = create_plots( __plot = create_result(
_dr, _dr,
mode=self.mode, mode=self.mode,
metric=self.metric, metric=self.metric,