qcpanel updated and refactored
This commit is contained in:
parent
f8ae408a4f
commit
a9b926717e
583
qcpanel/run.py
583
qcpanel/run.py
|
@ -1,563 +1,13 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import panel as pn
|
import panel as pn
|
||||||
import param
|
|
||||||
|
|
||||||
from quacc import utils
|
from qcpanel.viewer import QuaccTestViewer
|
||||||
from quacc.evaluation.comp import CE
|
|
||||||
from quacc.evaluation.report import DatasetReport
|
|
||||||
|
|
||||||
pn.config.design = pn.theme.Bootstrap
|
# pn.config.design = pn.theme.Bootstrap
|
||||||
pn.config.theme = "dark"
|
# pn.config.theme = "dark"
|
||||||
pn.config.notifications = True
|
pn.config.notifications = True
|
||||||
|
|
||||||
valid_plot_modes = defaultdict(
|
|
||||||
lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"]
|
|
||||||
)
|
|
||||||
valid_plot_modes["avg"] = [
|
|
||||||
"delta_train",
|
|
||||||
"stdev_train",
|
|
||||||
"delta_test",
|
|
||||||
"stdev_test",
|
|
||||||
"shift",
|
|
||||||
"train_table",
|
|
||||||
"test_table",
|
|
||||||
"shift_table",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def create_cr_plots(
|
|
||||||
dr: DatasetReport,
|
|
||||||
mode="delta",
|
|
||||||
metric="acc",
|
|
||||||
estimators=None,
|
|
||||||
prev=None,
|
|
||||||
):
|
|
||||||
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
|
||||||
idx = _prevs.index(prev)
|
|
||||||
cr = dr.crs[idx]
|
|
||||||
estimators = CE.name[estimators]
|
|
||||||
if mode is None:
|
|
||||||
mode = valid_plot_modes[str(prev)][0]
|
|
||||||
_dpi = 112
|
|
||||||
if mode == "table":
|
|
||||||
return pn.pane.DataFrame(
|
|
||||||
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
elif mode == "shift_table":
|
|
||||||
return pn.pane.DataFrame(
|
|
||||||
cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return pn.pane.Matplotlib(
|
|
||||||
cr.get_plots(
|
|
||||||
mode=mode,
|
|
||||||
metric=metric,
|
|
||||||
estimators=estimators,
|
|
||||||
conf="panel",
|
|
||||||
return_fig=True,
|
|
||||||
),
|
|
||||||
tight=True,
|
|
||||||
format="png",
|
|
||||||
sizing_mode="scale_height",
|
|
||||||
# sizing_mode="scale_both",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_avg_plots(
|
|
||||||
dr: DatasetReport,
|
|
||||||
mode="delta",
|
|
||||||
metric="acc",
|
|
||||||
estimators=None,
|
|
||||||
):
|
|
||||||
estimators = CE.name[estimators]
|
|
||||||
if mode is None:
|
|
||||||
mode = valid_plot_modes["avg"][0]
|
|
||||||
|
|
||||||
if mode == "train_table":
|
|
||||||
return pn.pane.DataFrame(
|
|
||||||
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(),
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
elif mode == "test_table":
|
|
||||||
return pn.pane.DataFrame(
|
|
||||||
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
elif mode == "shift_table":
|
|
||||||
return pn.pane.DataFrame(
|
|
||||||
dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
return pn.pane.Matplotlib(
|
|
||||||
dr.get_plots(
|
|
||||||
mode=mode,
|
|
||||||
metric=metric,
|
|
||||||
estimators=estimators,
|
|
||||||
conf="panel",
|
|
||||||
return_fig=True,
|
|
||||||
),
|
|
||||||
tight=True,
|
|
||||||
format="png",
|
|
||||||
sizing_mode="scale_height",
|
|
||||||
# sizing_mode="scale_both",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_widgets(datasets: Dict[str, DatasetReport]):
|
|
||||||
available_datasets = list(datasets.keys())
|
|
||||||
dataset_widget = pn.widgets.Select(
|
|
||||||
name="dataset",
|
|
||||||
options=available_datasets,
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
_dr = datasets[dataset_widget.value]
|
|
||||||
_data = _dr.data()
|
|
||||||
_metrics = _data.columns.unique(0)
|
|
||||||
_estimators = _data.columns.unique(1)
|
|
||||||
|
|
||||||
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
|
||||||
metric_widget = pn.widgets.Select(
|
|
||||||
name="metric",
|
|
||||||
value="acc",
|
|
||||||
options=valid_metrics,
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_estimators = [e for e in _estimators if e != "ref"]
|
|
||||||
estimators_widget = pn.widgets.CheckButtonGroup(
|
|
||||||
name="estimators",
|
|
||||||
options=valid_estimators,
|
|
||||||
value=valid_estimators,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_views = [str(round(cr.train_prev[1] * 100)) for cr in _dr.crs]
|
|
||||||
view_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="view",
|
|
||||||
options=valid_views + ["avg"],
|
|
||||||
value="avg",
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
)
|
|
||||||
|
|
||||||
@pn.depends(dataset_widget.param.value, watch=True)
|
|
||||||
def _update_from_dataset(_dataset):
|
|
||||||
l_dr = datasets[dataset_widget.value]
|
|
||||||
l_data = l_dr.data()
|
|
||||||
l_metrics = l_data.columns.unique(0)
|
|
||||||
l_estimators = l_data.columns.unique(1)
|
|
||||||
|
|
||||||
l_valid_estimators = [e for e in l_estimators if e != "ref"]
|
|
||||||
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
|
|
||||||
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
|
|
||||||
|
|
||||||
metric_widget.options = l_valid_metrics
|
|
||||||
metric_widget.value = l_valid_metrics[0]
|
|
||||||
|
|
||||||
estimators_widget.options = l_valid_estimators
|
|
||||||
estimators_widget.value = l_valid_estimators
|
|
||||||
|
|
||||||
view_widget.options = l_valid_views + ["avg"]
|
|
||||||
view_widget.value = "avg"
|
|
||||||
|
|
||||||
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="mode",
|
|
||||||
value=valid_plot_modes["avg"][0],
|
|
||||||
options=valid_plot_modes["avg"],
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
@pn.depends(view_widget.param.value, watch=True)
|
|
||||||
def _update_from_view(_view):
|
|
||||||
_modes = valid_plot_modes[_view]
|
|
||||||
plot_mode_widget.options = _modes
|
|
||||||
plot_mode_widget.value = _modes[0]
|
|
||||||
|
|
||||||
widget_pane = pn.Column(
|
|
||||||
dataset_widget,
|
|
||||||
metric_widget,
|
|
||||||
pn.Row(
|
|
||||||
view_widget,
|
|
||||||
plot_mode_widget,
|
|
||||||
),
|
|
||||||
estimators_widget,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
widget_pane,
|
|
||||||
{
|
|
||||||
"dataset": dataset_widget,
|
|
||||||
"metric": metric_widget,
|
|
||||||
"view": view_widget,
|
|
||||||
"plot_mode": plot_mode_widget,
|
|
||||||
"estimators": estimators_widget,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_plot(
|
|
||||||
datasets: Dict[str, DatasetReport],
|
|
||||||
dst: str,
|
|
||||||
metric: str,
|
|
||||||
estimators: List[str],
|
|
||||||
view: str,
|
|
||||||
mode: str,
|
|
||||||
):
|
|
||||||
_dr = datasets[dst]
|
|
||||||
if view == "avg":
|
|
||||||
return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators)
|
|
||||||
else:
|
|
||||||
prev = int(view)
|
|
||||||
return create_cr_plots(
|
|
||||||
_dr, mode=mode, metric=metric, estimators=estimators, prev=prev
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_modal(datasets, dst, metric):
|
|
||||||
return pn.pane.Str(f"{dst}_{metric}")
|
|
||||||
|
|
||||||
|
|
||||||
def build_save_pane(datasets: Dict[str, DatasetReport], dst: str, metric: str):
|
|
||||||
return pn.pane.Str(f"{datasets[dst]}_{metric}")
|
|
||||||
|
|
||||||
|
|
||||||
def explore_datasets(root: Path | str):
|
|
||||||
if isinstance(root, str):
|
|
||||||
root = Path(root)
|
|
||||||
|
|
||||||
if root.name == "plot":
|
|
||||||
return []
|
|
||||||
|
|
||||||
if not root.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
drs = []
|
|
||||||
for f in os.listdir(root):
|
|
||||||
if (root / f).is_dir():
|
|
||||||
drs += explore_datasets(root / f)
|
|
||||||
elif f == f"{root.name}.pickle":
|
|
||||||
drs.append(root / f)
|
|
||||||
# drs.append((str(root),))
|
|
||||||
|
|
||||||
return drs
|
|
||||||
|
|
||||||
|
|
||||||
class QuaccTestViewer(param.Parameterized):
|
|
||||||
dataset = param.Selector()
|
|
||||||
metric = param.Selector()
|
|
||||||
estimators = param.ListSelector()
|
|
||||||
plot_view = param.Selector()
|
|
||||||
mode = param.Selector()
|
|
||||||
|
|
||||||
modal_estimators = param.ListSelector()
|
|
||||||
modal_plot_view = param.ListSelector()
|
|
||||||
modal_mode_prev = param.ListSelector(
|
|
||||||
objects=valid_plot_modes[0], default=valid_plot_modes[0]
|
|
||||||
)
|
|
||||||
modal_mode_avg = param.ListSelector(
|
|
||||||
objects=valid_plot_modes["avg"], default=valid_plot_modes["avg"]
|
|
||||||
)
|
|
||||||
|
|
||||||
param_pane = param.Parameter()
|
|
||||||
plot_pane = param.Parameter()
|
|
||||||
modal_pane = param.Parameter()
|
|
||||||
|
|
||||||
def __init__(self, **params):
|
|
||||||
super().__init__(**params)
|
|
||||||
|
|
||||||
self.__setup_watchers()
|
|
||||||
self.__import_datasets()
|
|
||||||
# self._update_on_dataset()
|
|
||||||
self.__create_param_pane()
|
|
||||||
self.__create_modal_pane()
|
|
||||||
|
|
||||||
def __save_callback(self, event):
|
|
||||||
_home = utils.get_quacc_home()
|
|
||||||
_save_input_val = self.save_input.value_input
|
|
||||||
_config = "default" if len(_save_input_val) == 0 else _save_input_val
|
|
||||||
base_path = _home / "output" / self.dataset / _config
|
|
||||||
os.makedirs(base_path, exist_ok=True)
|
|
||||||
base_plot = base_path / "plot"
|
|
||||||
os.makedirs(base_plot, exist_ok=True)
|
|
||||||
|
|
||||||
l_dr = self.datasets_[self.dataset]
|
|
||||||
res = l_dr.to_md(
|
|
||||||
conf=_config,
|
|
||||||
metric=self.metric,
|
|
||||||
estimators=CE.name[self.modal_estimators],
|
|
||||||
dr_modes=self.modal_mode_avg,
|
|
||||||
cr_modes=self.modal_mode_prev,
|
|
||||||
plot_path=base_plot,
|
|
||||||
)
|
|
||||||
with open(base_path / f"{self.metric}.md", "w") as f:
|
|
||||||
f.write(res)
|
|
||||||
|
|
||||||
pn.state.notifications.success(f'"{_config}" successfully saved')
|
|
||||||
|
|
||||||
def __create_param_pane(self):
|
|
||||||
self.dataset_widget = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["dataset"],
|
|
||||||
widgets={"dataset": {"widget_type": pn.widgets.Select}},
|
|
||||||
)
|
|
||||||
self.metric_widget = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["metric"],
|
|
||||||
widgets={"metric": {"widget_type": pn.widgets.Select}},
|
|
||||||
)
|
|
||||||
self.estimators_widgets = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["estimators"],
|
|
||||||
widgets={
|
|
||||||
"estimators": {
|
|
||||||
"widget_type": pn.widgets.CheckButtonGroup,
|
|
||||||
"orientation": "vertical",
|
|
||||||
"sizing_mode": "scale_width",
|
|
||||||
"button_type": "primary",
|
|
||||||
"button_style": "outline",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.plot_view_widget = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["plot_view"],
|
|
||||||
widgets={
|
|
||||||
"plot_view": {
|
|
||||||
"widget_type": pn.widgets.RadioButtonGroup,
|
|
||||||
"orientation": "vertical",
|
|
||||||
"button_type": "primary",
|
|
||||||
"button_style": "outline",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.mode_widget = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["mode"],
|
|
||||||
widgets={
|
|
||||||
"mode": {
|
|
||||||
"widget_type": pn.widgets.RadioButtonGroup,
|
|
||||||
"orientation": "vertical",
|
|
||||||
"sizing_mode": "scale_width",
|
|
||||||
"button_type": "primary",
|
|
||||||
"button_style": "outline",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
self.param_pane = pn.Column(
|
|
||||||
self.dataset_widget,
|
|
||||||
self.metric_widget,
|
|
||||||
pn.Row(
|
|
||||||
self.plot_view_widget,
|
|
||||||
self.mode_widget,
|
|
||||||
),
|
|
||||||
self.estimators_widgets,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __create_modal_pane(self):
|
|
||||||
self.modal_estimators_widgets = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["modal_estimators"],
|
|
||||||
widgets={
|
|
||||||
"modal_estimators": {
|
|
||||||
"widget_type": pn.widgets.CheckButtonGroup,
|
|
||||||
"orientation": "vertical",
|
|
||||||
"sizing_mode": "scale_width",
|
|
||||||
"button_type": "primary",
|
|
||||||
"button_style": "outline",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.modal_plot_view_widget = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["modal_plot_view"],
|
|
||||||
widgets={
|
|
||||||
"modal_plot_view": {
|
|
||||||
"widget_type": pn.widgets.CheckButtonGroup,
|
|
||||||
"orientation": "vertical",
|
|
||||||
"button_type": "primary",
|
|
||||||
"button_style": "outline",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.modal_mode_prev_widget = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["modal_mode_prev"],
|
|
||||||
widgets={
|
|
||||||
"modal_mode_prev": {
|
|
||||||
"widget_type": pn.widgets.CheckButtonGroup,
|
|
||||||
"orientation": "vertical",
|
|
||||||
"sizing_mode": "scale_width",
|
|
||||||
"button_type": "primary",
|
|
||||||
"button_style": "outline",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
self.modal_mode_avg_widget = pn.Param(
|
|
||||||
self,
|
|
||||||
show_name=False,
|
|
||||||
parameters=["modal_mode_avg"],
|
|
||||||
widgets={
|
|
||||||
"modal_mode_avg": {
|
|
||||||
"widget_type": pn.widgets.CheckButtonGroup,
|
|
||||||
"orientation": "vertical",
|
|
||||||
"sizing_mode": "scale_width",
|
|
||||||
"button_type": "primary",
|
|
||||||
"button_style": "outline",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.save_input = pn.widgets.TextInput(
|
|
||||||
name="Configuration Name", placeholder="default", sizing_mode="scale_width"
|
|
||||||
)
|
|
||||||
self.save_button = pn.widgets.Button(
|
|
||||||
name="Saverrr",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
button_style="solid",
|
|
||||||
button_type="success",
|
|
||||||
)
|
|
||||||
self.save_button.on_click(self.__save_callback)
|
|
||||||
|
|
||||||
_title_styles = {
|
|
||||||
"font-size": "14pt",
|
|
||||||
"font-weight": "bold",
|
|
||||||
}
|
|
||||||
self.modal_pane = pn.Column(
|
|
||||||
pn.Column(
|
|
||||||
pn.pane.Str("Avg. configuration", styles=_title_styles),
|
|
||||||
self.modal_mode_avg_widget,
|
|
||||||
pn.pane.Str("Train prevs. configuration", styles=_title_styles),
|
|
||||||
pn.Row(
|
|
||||||
self.modal_plot_view_widget,
|
|
||||||
self.modal_mode_prev_widget,
|
|
||||||
),
|
|
||||||
pn.pane.Str("Estimators configuration", styles=_title_styles),
|
|
||||||
self.modal_estimators_widgets,
|
|
||||||
self.save_input,
|
|
||||||
self.save_button,
|
|
||||||
width=450,
|
|
||||||
align="center",
|
|
||||||
scroll=True,
|
|
||||||
),
|
|
||||||
sizing_mode="stretch_both",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __import_datasets(self):
|
|
||||||
__base_path = "output"
|
|
||||||
dataset_paths = sorted(
|
|
||||||
explore_datasets(__base_path), key=lambda t: (-len(t.parts), t)
|
|
||||||
)
|
|
||||||
self.datasets_ = {
|
|
||||||
str(dp.parent.relative_to(Path(__base_path))): DatasetReport.unpickle(dp)
|
|
||||||
for dp in dataset_paths
|
|
||||||
}
|
|
||||||
|
|
||||||
self.available_datasets = list(self.datasets_.keys())
|
|
||||||
self.param["dataset"].objects = self.available_datasets
|
|
||||||
self.dataset = self.available_datasets[0]
|
|
||||||
|
|
||||||
def __setup_watchers(self):
|
|
||||||
self.param.watch(
|
|
||||||
self._update_on_dataset,
|
|
||||||
["dataset"],
|
|
||||||
queued=True,
|
|
||||||
precedence=0,
|
|
||||||
)
|
|
||||||
self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
|
|
||||||
self.param.watch(
|
|
||||||
self._update_plot,
|
|
||||||
["dataset", "metric", "estimators", "plot_view", "mode"],
|
|
||||||
# ["metric", "estimators", "mode"],
|
|
||||||
onlychanged=False,
|
|
||||||
precedence=2,
|
|
||||||
)
|
|
||||||
self.param.watch(
|
|
||||||
self._update_on_estimators,
|
|
||||||
["estimators"],
|
|
||||||
queued=True,
|
|
||||||
precedence=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_on_dataset(self, *events):
|
|
||||||
l_dr = self.datasets_[self.dataset]
|
|
||||||
l_data = l_dr.data()
|
|
||||||
l_metrics = l_data.columns.unique(0)
|
|
||||||
l_estimators = l_data.columns.unique(1)
|
|
||||||
|
|
||||||
l_valid_estimators = [e for e in l_estimators if e != "ref"]
|
|
||||||
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
|
|
||||||
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
|
|
||||||
|
|
||||||
self.param["metric"].objects = l_valid_metrics
|
|
||||||
self.metric = l_valid_metrics[0]
|
|
||||||
|
|
||||||
self.param["estimators"].objects = l_valid_estimators
|
|
||||||
self.estimators = l_valid_estimators
|
|
||||||
|
|
||||||
self.param["plot_view"].objects = ["avg"] + l_valid_views
|
|
||||||
self.plot_view = "avg"
|
|
||||||
|
|
||||||
self.param["mode"].objects = valid_plot_modes["avg"]
|
|
||||||
self.mode = valid_plot_modes["avg"][0]
|
|
||||||
|
|
||||||
self.param["modal_estimators"].objects = l_valid_estimators
|
|
||||||
self.modal_estimators = []
|
|
||||||
|
|
||||||
self.param["modal_plot_view"].objects = l_valid_views
|
|
||||||
self.modal_plot_view = l_valid_views.copy()
|
|
||||||
|
|
||||||
def _update_on_view(self, *events):
|
|
||||||
self.param["mode"].objects = valid_plot_modes[self.plot_view]
|
|
||||||
self.mode = valid_plot_modes[self.plot_view][0]
|
|
||||||
|
|
||||||
def _update_on_estimators(self, *events):
|
|
||||||
self.modal_estimators = self.estimators.copy()
|
|
||||||
|
|
||||||
def _update_plot(self, *events):
|
|
||||||
self.plot_pane = build_plot(
|
|
||||||
datasets=self.datasets_,
|
|
||||||
dst=self.dataset,
|
|
||||||
metric=self.metric,
|
|
||||||
estimators=self.estimators,
|
|
||||||
view=self.plot_view,
|
|
||||||
mode=self.mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_plot(self):
|
|
||||||
return self.plot_pane
|
|
||||||
|
|
||||||
def get_param_pane(self):
|
|
||||||
return self.param_pane
|
|
||||||
|
|
||||||
|
|
||||||
def serve(address="localhost"):
|
def serve(address="localhost"):
|
||||||
qtv = QuaccTestViewer()
|
qtv = QuaccTestViewer()
|
||||||
|
@ -565,19 +15,36 @@ def serve(address="localhost"):
|
||||||
def save_callback(event):
|
def save_callback(event):
|
||||||
app.open_modal()
|
app.open_modal()
|
||||||
|
|
||||||
|
def refresh_callback(event):
|
||||||
|
qtv.update_datasets()
|
||||||
|
|
||||||
save_button = pn.widgets.Button(
|
save_button = pn.widgets.Button(
|
||||||
name="Save",
|
# name="Save",
|
||||||
sizing_mode="scale_width",
|
icon="device-floppy",
|
||||||
|
icon_size="16px",
|
||||||
|
# sizing_mode="scale_width",
|
||||||
button_style="solid",
|
button_style="solid",
|
||||||
button_type="success",
|
button_type="success",
|
||||||
)
|
)
|
||||||
save_button.on_click(save_callback)
|
save_button.on_click(save_callback)
|
||||||
|
|
||||||
app = pn.template.MaterialTemplate(
|
refresh_button = pn.widgets.Button(
|
||||||
|
icon="refresh",
|
||||||
|
icon_size="16px",
|
||||||
|
button_style="solid",
|
||||||
|
)
|
||||||
|
refresh_button.on_click(refresh_callback)
|
||||||
|
|
||||||
|
app = pn.template.FastListTemplate(
|
||||||
title="quacc tests",
|
title="quacc tests",
|
||||||
sidebar=[save_button, qtv.get_param_pane],
|
sidebar=[
|
||||||
main=[qtv.get_plot],
|
pn.FlexBox(save_button, refresh_button, flex_direction="row-reverse"),
|
||||||
|
qtv.get_param_pane,
|
||||||
|
],
|
||||||
|
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
|
||||||
modal=[qtv.modal_pane],
|
modal=[qtv.modal_pane],
|
||||||
|
theme=pn.theme.DarkTheme,
|
||||||
|
theme_toggle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.servable()
|
app.servable()
|
||||||
|
|
|
@ -0,0 +1,248 @@
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import panel as pn
|
||||||
|
|
||||||
|
from quacc.evaluation.comp import CE
|
||||||
|
from quacc.evaluation.report import DatasetReport
|
||||||
|
|
||||||
|
_plot_sizing_mode = "stretch_both"
|
||||||
|
valid_plot_modes = defaultdict(
|
||||||
|
lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"]
|
||||||
|
)
|
||||||
|
valid_plot_modes["avg"] = [
|
||||||
|
"delta_train",
|
||||||
|
"stdev_train",
|
||||||
|
"delta_test",
|
||||||
|
"stdev_test",
|
||||||
|
"shift",
|
||||||
|
"train_table",
|
||||||
|
"test_table",
|
||||||
|
"shift_table",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_cr_plots(
|
||||||
|
dr: DatasetReport,
|
||||||
|
mode="delta",
|
||||||
|
metric="acc",
|
||||||
|
estimators=None,
|
||||||
|
prev=None,
|
||||||
|
):
|
||||||
|
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
||||||
|
idx = _prevs.index(prev)
|
||||||
|
cr = dr.crs[idx]
|
||||||
|
estimators = CE.name[estimators]
|
||||||
|
if mode is None:
|
||||||
|
mode = valid_plot_modes[str(prev)][0]
|
||||||
|
_dpi = 112
|
||||||
|
if mode == "table":
|
||||||
|
return pn.pane.DataFrame(
|
||||||
|
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
elif mode == "shift_table":
|
||||||
|
return pn.pane.DataFrame(
|
||||||
|
cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return pn.pane.Matplotlib(
|
||||||
|
cr.get_plots(
|
||||||
|
mode=mode,
|
||||||
|
metric=metric,
|
||||||
|
estimators=estimators,
|
||||||
|
conf="panel",
|
||||||
|
return_fig=True,
|
||||||
|
),
|
||||||
|
tight=True,
|
||||||
|
format="png",
|
||||||
|
sizing_mode=_plot_sizing_mode,
|
||||||
|
# sizing_mode="scale_height",
|
||||||
|
# sizing_mode="scale_both",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_avg_plots(
|
||||||
|
dr: DatasetReport,
|
||||||
|
mode="delta",
|
||||||
|
metric="acc",
|
||||||
|
estimators=None,
|
||||||
|
):
|
||||||
|
estimators = CE.name[estimators]
|
||||||
|
if mode is None:
|
||||||
|
mode = valid_plot_modes["avg"][0]
|
||||||
|
|
||||||
|
if mode == "train_table":
|
||||||
|
return pn.pane.DataFrame(
|
||||||
|
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(),
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
elif mode == "test_table":
|
||||||
|
return pn.pane.DataFrame(
|
||||||
|
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
elif mode == "shift_table":
|
||||||
|
return pn.pane.DataFrame(
|
||||||
|
dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
return pn.pane.Matplotlib(
|
||||||
|
dr.get_plots(
|
||||||
|
mode=mode,
|
||||||
|
metric=metric,
|
||||||
|
estimators=estimators,
|
||||||
|
conf="panel",
|
||||||
|
return_fig=True,
|
||||||
|
),
|
||||||
|
tight=True,
|
||||||
|
format="png",
|
||||||
|
# sizing_mode="scale_height",
|
||||||
|
sizing_mode=_plot_sizing_mode,
|
||||||
|
# sizing_mode="scale_both",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_widgets(datasets: Dict[str, DatasetReport]):
|
||||||
|
available_datasets = list(datasets.keys())
|
||||||
|
dataset_widget = pn.widgets.Select(
|
||||||
|
name="dataset",
|
||||||
|
options=available_datasets,
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
|
||||||
|
_dr = datasets[dataset_widget.value]
|
||||||
|
_data = _dr.data()
|
||||||
|
_metrics = _data.columns.unique(0)
|
||||||
|
_estimators = _data.columns.unique(1)
|
||||||
|
|
||||||
|
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
||||||
|
metric_widget = pn.widgets.Select(
|
||||||
|
name="metric",
|
||||||
|
value="acc",
|
||||||
|
options=valid_metrics,
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_estimators = [e for e in _estimators if e != "ref"]
|
||||||
|
estimators_widget = pn.widgets.CheckButtonGroup(
|
||||||
|
name="estimators",
|
||||||
|
options=valid_estimators,
|
||||||
|
value=valid_estimators,
|
||||||
|
button_style="outline",
|
||||||
|
button_type="primary",
|
||||||
|
align="center",
|
||||||
|
orientation="vertical",
|
||||||
|
sizing_mode="scale_width",
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_views = [str(round(cr.train_prev[1] * 100)) for cr in _dr.crs]
|
||||||
|
view_widget = pn.widgets.RadioButtonGroup(
|
||||||
|
name="view",
|
||||||
|
options=valid_views + ["avg"],
|
||||||
|
value="avg",
|
||||||
|
button_style="outline",
|
||||||
|
button_type="primary",
|
||||||
|
align="center",
|
||||||
|
orientation="vertical",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pn.depends(dataset_widget.param.value, watch=True)
|
||||||
|
def _update_from_dataset(_dataset):
|
||||||
|
l_dr = datasets[dataset_widget.value]
|
||||||
|
l_data = l_dr.data()
|
||||||
|
l_metrics = l_data.columns.unique(0)
|
||||||
|
l_estimators = l_data.columns.unique(1)
|
||||||
|
|
||||||
|
l_valid_estimators = [e for e in l_estimators if e != "ref"]
|
||||||
|
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
|
||||||
|
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
|
||||||
|
|
||||||
|
metric_widget.options = l_valid_metrics
|
||||||
|
metric_widget.value = l_valid_metrics[0]
|
||||||
|
|
||||||
|
estimators_widget.options = l_valid_estimators
|
||||||
|
estimators_widget.value = l_valid_estimators
|
||||||
|
|
||||||
|
view_widget.options = l_valid_views + ["avg"]
|
||||||
|
view_widget.value = "avg"
|
||||||
|
|
||||||
|
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
||||||
|
name="mode",
|
||||||
|
value=valid_plot_modes["avg"][0],
|
||||||
|
options=valid_plot_modes["avg"],
|
||||||
|
button_style="outline",
|
||||||
|
button_type="primary",
|
||||||
|
align="center",
|
||||||
|
orientation="vertical",
|
||||||
|
sizing_mode="scale_width",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pn.depends(view_widget.param.value, watch=True)
|
||||||
|
def _update_from_view(_view):
|
||||||
|
_modes = valid_plot_modes[_view]
|
||||||
|
plot_mode_widget.options = _modes
|
||||||
|
plot_mode_widget.value = _modes[0]
|
||||||
|
|
||||||
|
widget_pane = pn.Column(
|
||||||
|
dataset_widget,
|
||||||
|
metric_widget,
|
||||||
|
pn.Row(
|
||||||
|
view_widget,
|
||||||
|
plot_mode_widget,
|
||||||
|
),
|
||||||
|
estimators_widget,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
widget_pane,
|
||||||
|
{
|
||||||
|
"dataset": dataset_widget,
|
||||||
|
"metric": metric_widget,
|
||||||
|
"view": view_widget,
|
||||||
|
"plot_mode": plot_mode_widget,
|
||||||
|
"estimators": estimators_widget,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_plot(
|
||||||
|
datasets: Dict[str, DatasetReport],
|
||||||
|
dst: str,
|
||||||
|
metric: str,
|
||||||
|
estimators: List[str],
|
||||||
|
view: str,
|
||||||
|
mode: str,
|
||||||
|
):
|
||||||
|
_dr = datasets[dst]
|
||||||
|
if view == "avg":
|
||||||
|
return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators)
|
||||||
|
else:
|
||||||
|
prev = int(view)
|
||||||
|
return create_cr_plots(
|
||||||
|
_dr, mode=mode, metric=metric, estimators=estimators, prev=prev
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def explore_datasets(root: Path | str):
|
||||||
|
if isinstance(root, str):
|
||||||
|
root = Path(root)
|
||||||
|
|
||||||
|
if root.name == "plot":
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not root.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
drs = []
|
||||||
|
for f in os.listdir(root):
|
||||||
|
if (root / f).is_dir():
|
||||||
|
drs += explore_datasets(root / f)
|
||||||
|
elif f == f"{root.name}.pickle":
|
||||||
|
drs.append(root / f)
|
||||||
|
# drs.append((str(root),))
|
||||||
|
|
||||||
|
return drs
|
|
@ -0,0 +1,310 @@
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import panel as pn
|
||||||
|
import param
|
||||||
|
|
||||||
|
from qcpanel.util import build_plot, explore_datasets, valid_plot_modes
|
||||||
|
from quacc.evaluation.comp import CE
|
||||||
|
from quacc.evaluation.report import DatasetReport
|
||||||
|
|
||||||
|
|
||||||
|
class QuaccTestViewer(param.Parameterized):
|
||||||
|
dataset = param.Selector()
|
||||||
|
metric = param.Selector()
|
||||||
|
estimators = param.ListSelector()
|
||||||
|
plot_view = param.Selector()
|
||||||
|
mode = param.Selector()
|
||||||
|
|
||||||
|
modal_estimators = param.ListSelector()
|
||||||
|
modal_plot_view = param.ListSelector()
|
||||||
|
modal_mode_prev = param.ListSelector(
|
||||||
|
objects=valid_plot_modes[0], default=valid_plot_modes[0]
|
||||||
|
)
|
||||||
|
modal_mode_avg = param.ListSelector(
|
||||||
|
objects=valid_plot_modes["avg"], default=valid_plot_modes["avg"]
|
||||||
|
)
|
||||||
|
|
||||||
|
param_pane = param.Parameter()
|
||||||
|
plot_pane = param.Parameter()
|
||||||
|
modal_pane = param.Parameter()
|
||||||
|
|
||||||
|
def __init__(self, **params):
|
||||||
|
super().__init__(**params)
|
||||||
|
|
||||||
|
self.__setup_watchers()
|
||||||
|
self.update_datasets()
|
||||||
|
# self._update_on_dataset()
|
||||||
|
self.__create_param_pane()
|
||||||
|
self.__create_modal_pane()
|
||||||
|
|
||||||
|
def __save_callback(self, event):
|
||||||
|
_home = Path("output")
|
||||||
|
_save_input_val = self.save_input.value_input
|
||||||
|
_config = "default" if len(_save_input_val) == 0 else _save_input_val
|
||||||
|
base_path = _home / self.dataset / _config
|
||||||
|
os.makedirs(base_path, exist_ok=True)
|
||||||
|
base_plot = base_path / "plot"
|
||||||
|
os.makedirs(base_plot, exist_ok=True)
|
||||||
|
|
||||||
|
l_dr = self.datasets_[self.dataset]
|
||||||
|
res = l_dr.to_md(
|
||||||
|
conf=_config,
|
||||||
|
metric=self.metric,
|
||||||
|
estimators=CE.name[self.modal_estimators],
|
||||||
|
dr_modes=self.modal_mode_avg,
|
||||||
|
cr_modes=self.modal_mode_prev,
|
||||||
|
cr_prevs=self.modal_plot_view,
|
||||||
|
plot_path=base_plot,
|
||||||
|
)
|
||||||
|
with open(base_path / f"{self.metric}.md", "w") as f:
|
||||||
|
f.write(res)
|
||||||
|
|
||||||
|
pn.state.notifications.success(f'"{_config}" successfully saved')
|
||||||
|
|
||||||
|
def __create_param_pane(self):
|
||||||
|
self.dataset_widget = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["dataset"],
|
||||||
|
widgets={"dataset": {"widget_type": pn.widgets.Select}},
|
||||||
|
)
|
||||||
|
self.metric_widget = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["metric"],
|
||||||
|
widgets={"metric": {"widget_type": pn.widgets.Select}},
|
||||||
|
)
|
||||||
|
self.estimators_widgets = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["estimators"],
|
||||||
|
widgets={
|
||||||
|
"estimators": {
|
||||||
|
"widget_type": pn.widgets.CheckButtonGroup,
|
||||||
|
"orientation": "vertical",
|
||||||
|
"sizing_mode": "scale_width",
|
||||||
|
"button_type": "primary",
|
||||||
|
"button_style": "outline",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.plot_view_widget = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["plot_view"],
|
||||||
|
widgets={
|
||||||
|
"plot_view": {
|
||||||
|
"widget_type": pn.widgets.RadioButtonGroup,
|
||||||
|
"orientation": "vertical",
|
||||||
|
"button_type": "primary",
|
||||||
|
"button_style": "outline",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.mode_widget = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["mode"],
|
||||||
|
widgets={
|
||||||
|
"mode": {
|
||||||
|
"widget_type": pn.widgets.RadioButtonGroup,
|
||||||
|
"orientation": "vertical",
|
||||||
|
"sizing_mode": "scale_width",
|
||||||
|
"button_type": "primary",
|
||||||
|
"button_style": "outline",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
self.param_pane = pn.Column(
|
||||||
|
self.dataset_widget,
|
||||||
|
self.metric_widget,
|
||||||
|
pn.Row(
|
||||||
|
self.plot_view_widget,
|
||||||
|
self.mode_widget,
|
||||||
|
),
|
||||||
|
self.estimators_widgets,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __create_modal_pane(self):
|
||||||
|
self.modal_estimators_widgets = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["modal_estimators"],
|
||||||
|
widgets={
|
||||||
|
"modal_estimators": {
|
||||||
|
"widget_type": pn.widgets.CheckButtonGroup,
|
||||||
|
"orientation": "vertical",
|
||||||
|
"sizing_mode": "scale_width",
|
||||||
|
"button_type": "primary",
|
||||||
|
"button_style": "outline",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.modal_plot_view_widget = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["modal_plot_view"],
|
||||||
|
widgets={
|
||||||
|
"modal_plot_view": {
|
||||||
|
"widget_type": pn.widgets.CheckButtonGroup,
|
||||||
|
"orientation": "vertical",
|
||||||
|
"button_type": "primary",
|
||||||
|
"button_style": "outline",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.modal_mode_prev_widget = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["modal_mode_prev"],
|
||||||
|
widgets={
|
||||||
|
"modal_mode_prev": {
|
||||||
|
"widget_type": pn.widgets.CheckButtonGroup,
|
||||||
|
"orientation": "vertical",
|
||||||
|
"sizing_mode": "scale_width",
|
||||||
|
"button_type": "primary",
|
||||||
|
"button_style": "outline",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
self.modal_mode_avg_widget = pn.Param(
|
||||||
|
self,
|
||||||
|
show_name=False,
|
||||||
|
parameters=["modal_mode_avg"],
|
||||||
|
widgets={
|
||||||
|
"modal_mode_avg": {
|
||||||
|
"widget_type": pn.widgets.CheckButtonGroup,
|
||||||
|
"orientation": "vertical",
|
||||||
|
"sizing_mode": "scale_width",
|
||||||
|
"button_type": "primary",
|
||||||
|
"button_style": "outline",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
align="center",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.save_input = pn.widgets.TextInput(
|
||||||
|
name="Configuration Name", placeholder="default", sizing_mode="scale_width"
|
||||||
|
)
|
||||||
|
self.save_button = pn.widgets.Button(
|
||||||
|
name="Save",
|
||||||
|
sizing_mode="scale_width",
|
||||||
|
button_style="solid",
|
||||||
|
button_type="success",
|
||||||
|
)
|
||||||
|
self.save_button.on_click(self.__save_callback)
|
||||||
|
|
||||||
|
_title_styles = {
|
||||||
|
"font-size": "14pt",
|
||||||
|
"font-weight": "bold",
|
||||||
|
}
|
||||||
|
self.modal_pane = pn.Column(
|
||||||
|
pn.Column(
|
||||||
|
pn.pane.Str("Avg. configuration", styles=_title_styles),
|
||||||
|
self.modal_mode_avg_widget,
|
||||||
|
pn.pane.Str("Train prevs. configuration", styles=_title_styles),
|
||||||
|
pn.Row(
|
||||||
|
self.modal_plot_view_widget,
|
||||||
|
self.modal_mode_prev_widget,
|
||||||
|
),
|
||||||
|
pn.pane.Str("Estimators configuration", styles=_title_styles),
|
||||||
|
self.modal_estimators_widgets,
|
||||||
|
self.save_input,
|
||||||
|
self.save_button,
|
||||||
|
pn.Spacer(height=20),
|
||||||
|
width=450,
|
||||||
|
align="center",
|
||||||
|
scroll=True,
|
||||||
|
),
|
||||||
|
sizing_mode="stretch_both",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_datasets(self):
|
||||||
|
__base_path = "output"
|
||||||
|
dataset_paths = sorted(
|
||||||
|
explore_datasets(__base_path), key=lambda t: (-len(t.parts), t)
|
||||||
|
)
|
||||||
|
self.datasets_ = {
|
||||||
|
str(dp.parent.relative_to(Path(__base_path))): DatasetReport.unpickle(dp)
|
||||||
|
for dp in dataset_paths
|
||||||
|
}
|
||||||
|
|
||||||
|
self.available_datasets = list(self.datasets_.keys())
|
||||||
|
self.param["dataset"].objects = self.available_datasets
|
||||||
|
self.dataset = self.available_datasets[0]
|
||||||
|
|
||||||
|
def __setup_watchers(self):
|
||||||
|
self.param.watch(
|
||||||
|
self._update_on_dataset,
|
||||||
|
["dataset"],
|
||||||
|
queued=True,
|
||||||
|
precedence=0,
|
||||||
|
)
|
||||||
|
self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
|
||||||
|
self.param.watch(
|
||||||
|
self._update_plot,
|
||||||
|
["dataset", "metric", "estimators", "plot_view", "mode"],
|
||||||
|
# ["metric", "estimators", "mode"],
|
||||||
|
onlychanged=False,
|
||||||
|
precedence=2,
|
||||||
|
)
|
||||||
|
self.param.watch(
|
||||||
|
self._update_on_estimators,
|
||||||
|
["estimators"],
|
||||||
|
queued=True,
|
||||||
|
precedence=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_on_dataset(self, *events):
|
||||||
|
l_dr = self.datasets_[self.dataset]
|
||||||
|
l_data = l_dr.data()
|
||||||
|
l_metrics = l_data.columns.unique(0)
|
||||||
|
l_estimators = l_data.columns.unique(1)
|
||||||
|
|
||||||
|
l_valid_estimators = [e for e in l_estimators if e != "ref"]
|
||||||
|
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
|
||||||
|
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
|
||||||
|
|
||||||
|
self.param["metric"].objects = l_valid_metrics
|
||||||
|
self.metric = l_valid_metrics[0]
|
||||||
|
|
||||||
|
self.param["estimators"].objects = l_valid_estimators
|
||||||
|
self.estimators = l_valid_estimators
|
||||||
|
|
||||||
|
self.param["plot_view"].objects = ["avg"] + l_valid_views
|
||||||
|
self.plot_view = "avg"
|
||||||
|
|
||||||
|
self.param["mode"].objects = valid_plot_modes["avg"]
|
||||||
|
self.mode = valid_plot_modes["avg"][0]
|
||||||
|
|
||||||
|
self.param["modal_estimators"].objects = l_valid_estimators
|
||||||
|
self.modal_estimators = []
|
||||||
|
|
||||||
|
self.param["modal_plot_view"].objects = l_valid_views
|
||||||
|
self.modal_plot_view = l_valid_views.copy()
|
||||||
|
|
||||||
|
def _update_on_view(self, *events):
|
||||||
|
self.param["mode"].objects = valid_plot_modes[self.plot_view]
|
||||||
|
self.mode = valid_plot_modes[self.plot_view][0]
|
||||||
|
|
||||||
|
def _update_on_estimators(self, *events):
|
||||||
|
self.modal_estimators = self.estimators.copy()
|
||||||
|
|
||||||
|
def _update_plot(self, *events):
|
||||||
|
self.plot_pane = build_plot(
|
||||||
|
datasets=self.datasets_,
|
||||||
|
dst=self.dataset,
|
||||||
|
metric=self.metric,
|
||||||
|
estimators=self.estimators,
|
||||||
|
view=self.plot_view,
|
||||||
|
mode=self.mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_plot(self):
|
||||||
|
return self.plot_pane
|
||||||
|
|
||||||
|
def get_param_pane(self):
|
||||||
|
return self.param_pane
|
Loading…
Reference in New Issue