added stats_table

This commit is contained in:
Lorenzo Volpi 2023-11-27 03:26:41 +01:00
parent c393ce4b20
commit d7cbde7522
1 changed files with 8 additions and 22 deletions

View File

@ -4,30 +4,13 @@ from pathlib import Path
import panel as pn import panel as pn
from quacc.evaluation.comp import CE from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import ttest_rel
_plot_sizing_mode = "stretch_both" _plot_sizing_mode = "stretch_both"
valid_plot_modes = defaultdict( valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
lambda: [ valid_plot_modes["avg"] = DatasetReport._default_dr_modes
"delta_train",
"stdev_train",
"train_table",
"shift",
"shift_table",
"diagonal",
]
)
valid_plot_modes["avg"] = [
"delta_train",
"stdev_train",
"train_table",
"shift",
"shift_table",
"delta_test",
"stdev_test",
"test_table",
]
def create_plots( def create_plots(
@ -60,6 +43,9 @@ def create_plots(
.mean() .mean()
) )
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", "stats_table"):
_data = ttest_rel(dr, metric=metric, estimators=estimators)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", _ as plot_mode): case ("avg", _ as plot_mode):
_plot = dr.get_plots( _plot = dr.get_plots(
mode=mode, mode=mode,