panel updated

This commit is contained in:
Lorenzo Volpi 2023-12-21 16:47:13 +01:00
parent a5c54a93b7
commit e01006e663
2 changed files with 14 additions and 4 deletions

View File

@ -2,6 +2,7 @@ import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
import numpy as np
import panel as pn import panel as pn
from quacc.evaluation.estimators import CE from quacc.evaluation.estimators import CE
@ -13,6 +14,10 @@ valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes valid_plot_modes["avg"] = DatasetReport._default_dr_modes
def _get_prev_str(prev: np.ndarray):
return str(tuple(np.around(prev, decimals=2)))
def create_plot( def create_plot(
dr: DatasetReport, dr: DatasetReport,
mode="delta", mode="delta",
@ -20,7 +25,7 @@ def create_plot(
estimators=None, estimators=None,
plot_view=None, plot_view=None,
): ):
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] _prevs = [_get_prev_str(cr.train_prev) for cr in dr.crs]
estimators = CE.name[estimators] estimators = CE.name[estimators]
if mode is None: if mode is None:
mode = valid_plot_modes[plot_view][0] mode = valid_plot_modes[plot_view][0]
@ -34,7 +39,7 @@ def create_plot(
save_fig=False, save_fig=False,
) )
case (_, _ as plot_mode): case (_, _ as plot_mode):
cr = dr.crs[_prevs.index(int(plot_view))] cr = dr.crs[_prevs.index(plot_view)]
_plot = cr.get_plots( _plot = cr.get_plots(
mode=plot_mode, mode=plot_mode,
metric=metric, metric=metric,

View File

@ -6,7 +6,12 @@ import pandas as pd
import panel as pn import panel as pn
import param import param
from qcpanel.util import create_result, explore_datasets, valid_plot_modes from qcpanel.util import (
_get_prev_str,
create_result,
explore_datasets,
valid_plot_modes,
)
from quacc.evaluation.estimators import CE from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport from quacc.evaluation.report import DatasetReport
@ -308,7 +313,7 @@ class QuaccTestViewer(param.Parameterized):
if not self.__get_param_init("estimators"): if not self.__get_param_init("estimators"):
self.estimators = _new_estimators self.estimators = _new_estimators
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] l_valid_views = [_get_prev_str(cr.train_prev) for cr in l_dr.crs]
l_valid_views = ["avg"] + l_valid_views l_valid_views = ["avg"] + l_valid_views
_old_view = self.plot_view _old_view = self.plot_view
self.param["plot_view"].objects = l_valid_views self.param["plot_view"].objects = l_valid_views