diff --git a/qcpanel/util.py b/qcpanel/util.py index 01d3abc..7a296a5 100644 --- a/qcpanel/util.py +++ b/qcpanel/util.py @@ -2,6 +2,7 @@ import os from collections import defaultdict from pathlib import Path +import numpy as np import panel as pn from quacc.evaluation.estimators import CE @@ -13,6 +14,10 @@ valid_plot_modes = defaultdict(lambda: CompReport._default_modes) valid_plot_modes["avg"] = DatasetReport._default_dr_modes +def _get_prev_str(prev: np.ndarray): + return str(tuple(np.around(prev, decimals=2))) + + def create_plot( dr: DatasetReport, mode="delta", @@ -20,7 +25,7 @@ def create_plot( estimators=None, plot_view=None, ): - _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] + _prevs = [_get_prev_str(cr.train_prev) for cr in dr.crs] estimators = CE.name[estimators] if mode is None: mode = valid_plot_modes[plot_view][0] @@ -34,7 +39,7 @@ def create_plot( save_fig=False, ) case (_, _ as plot_mode): - cr = dr.crs[_prevs.index(int(plot_view))] + cr = dr.crs[_prevs.index(plot_view)] _plot = cr.get_plots( mode=plot_mode, metric=metric, diff --git a/qcpanel/viewer.py b/qcpanel/viewer.py index 76440d4..3aece3c 100644 --- a/qcpanel/viewer.py +++ b/qcpanel/viewer.py @@ -6,7 +6,12 @@ import pandas as pd import panel as pn import param -from qcpanel.util import create_result, explore_datasets, valid_plot_modes +from qcpanel.util import ( + _get_prev_str, + create_result, + explore_datasets, + valid_plot_modes, +) from quacc.evaluation.estimators import CE from quacc.evaluation.report import DatasetReport @@ -308,7 +313,7 @@ class QuaccTestViewer(param.Parameterized): if not self.__get_param_init("estimators"): self.estimators = _new_estimators - l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] + l_valid_views = [_get_prev_str(cr.train_prev) for cr in l_dr.crs] l_valid_views = ["avg"] + l_valid_views _old_view = self.plot_view self.param["plot_view"].objects = l_valid_views