location sync added, missing plot bug fixed
This commit is contained in:
parent
1123940954
commit
f7b566c4a4
|
@ -9,8 +9,23 @@ from qcpanel.viewer import QuaccTestViewer
|
||||||
pn.config.notifications = True
|
pn.config.notifications = True
|
||||||
|
|
||||||
|
|
||||||
def serve(address="localhost"):
|
def app_instance():
|
||||||
qtv = QuaccTestViewer()
|
param_init = {
|
||||||
|
k: v
|
||||||
|
for k, v in pn.state.location.query_params.items()
|
||||||
|
if k in ["dataset", "metric", "plot_view", "mode", "estimators"]
|
||||||
|
}
|
||||||
|
qtv = QuaccTestViewer(param_init=param_init)
|
||||||
|
pn.state.location.sync(
|
||||||
|
qtv,
|
||||||
|
{
|
||||||
|
"dataset": "dataset",
|
||||||
|
"metric": "metric",
|
||||||
|
"plot_view": "plot_view",
|
||||||
|
"mode": "mode",
|
||||||
|
"estimators": "estimators",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def save_callback(event):
|
def save_callback(event):
|
||||||
app.open_modal()
|
app.open_modal()
|
||||||
|
@ -48,13 +63,17 @@ def serve(address="localhost"):
|
||||||
)
|
)
|
||||||
|
|
||||||
app.servable()
|
app.servable()
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def serve(address="localhost"):
|
||||||
__port = 33420
|
__port = 33420
|
||||||
__allowed = [address]
|
__allowed = [address]
|
||||||
if address == "localhost":
|
if address == "localhost":
|
||||||
__allowed.append("127.0.0.1")
|
__allowed.append("127.0.0.1")
|
||||||
|
|
||||||
pn.serve(
|
pn.serve(
|
||||||
app,
|
app_instance,
|
||||||
autoreload=True,
|
autoreload=True,
|
||||||
port=__port,
|
port=__port,
|
||||||
show=False,
|
show=False,
|
||||||
|
@ -76,4 +95,4 @@ def run():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
serve()
|
run()
|
||||||
|
|
171
qcpanel/util.py
171
qcpanel/util.py
|
@ -1,7 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import panel as pn
|
import panel as pn
|
||||||
|
|
||||||
|
@ -10,118 +9,112 @@ from quacc.evaluation.report import DatasetReport
|
||||||
|
|
||||||
_plot_sizing_mode = "stretch_both"
|
_plot_sizing_mode = "stretch_both"
|
||||||
valid_plot_modes = defaultdict(
|
valid_plot_modes = defaultdict(
|
||||||
lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"]
|
lambda: [
|
||||||
|
"delta_train",
|
||||||
|
"stdev_train",
|
||||||
|
"train_table",
|
||||||
|
"shift",
|
||||||
|
"shift_table",
|
||||||
|
"diagonal",
|
||||||
|
]
|
||||||
)
|
)
|
||||||
valid_plot_modes["avg"] = [
|
valid_plot_modes["avg"] = [
|
||||||
"delta_train",
|
"delta_train",
|
||||||
"stdev_train",
|
"stdev_train",
|
||||||
|
"train_table",
|
||||||
|
"shift",
|
||||||
|
"shift_table",
|
||||||
"delta_test",
|
"delta_test",
|
||||||
"stdev_test",
|
"stdev_test",
|
||||||
"shift",
|
|
||||||
"train_table",
|
|
||||||
"test_table",
|
"test_table",
|
||||||
"shift_table",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_cr_plots(
|
def create_plots(
|
||||||
dr: DatasetReport,
|
dr: DatasetReport,
|
||||||
mode="delta",
|
mode="delta",
|
||||||
metric="acc",
|
metric="acc",
|
||||||
estimators=None,
|
estimators=None,
|
||||||
prev=None,
|
plot_view=None,
|
||||||
):
|
):
|
||||||
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
||||||
idx = _prevs.index(prev)
|
|
||||||
cr = dr.crs[idx]
|
|
||||||
estimators = CE.name[estimators]
|
estimators = CE.name[estimators]
|
||||||
if mode is None:
|
if mode is None:
|
||||||
mode = valid_plot_modes[str(prev)][0]
|
mode = valid_plot_modes[plot_view][0]
|
||||||
_dpi = 112
|
_dpi = 112
|
||||||
if mode == "table":
|
match (plot_view, mode):
|
||||||
return pn.pane.DataFrame(
|
case ("avg", "train_table"):
|
||||||
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
_data = (
|
||||||
align="center",
|
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
|
||||||
)
|
)
|
||||||
elif mode == "shift_table":
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
||||||
return pn.pane.DataFrame(
|
case ("avg", "test_table"):
|
||||||
cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
_data = (
|
||||||
align="center",
|
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||||
)
|
)
|
||||||
else:
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
||||||
return pn.pane.Matplotlib(
|
case ("avg", "shift_table"):
|
||||||
cr.get_plots(
|
_data = (
|
||||||
|
dr.shift_data(metric=metric, estimators=estimators)
|
||||||
|
.groupby(level=0)
|
||||||
|
.mean()
|
||||||
|
)
|
||||||
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
||||||
|
case ("avg", _ as plot_mode):
|
||||||
|
_plot = dr.get_plots(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf="panel",
|
conf="panel",
|
||||||
return_fig=True,
|
return_fig=True,
|
||||||
),
|
)
|
||||||
tight=True,
|
return (
|
||||||
format="png",
|
pn.pane.Matplotlib(
|
||||||
sizing_mode=_plot_sizing_mode,
|
_plot,
|
||||||
# sizing_mode="scale_height",
|
tight=True,
|
||||||
# sizing_mode="scale_both",
|
format="png",
|
||||||
)
|
# sizing_mode="scale_height",
|
||||||
|
sizing_mode=_plot_sizing_mode,
|
||||||
|
# sizing_mode="scale_both",
|
||||||
def create_avg_plots(
|
)
|
||||||
dr: DatasetReport,
|
if _plot is not None
|
||||||
mode="delta",
|
else None
|
||||||
metric="acc",
|
)
|
||||||
estimators=None,
|
case (_, "train_table"):
|
||||||
):
|
cr = dr.crs[_prevs.index(int(plot_view))]
|
||||||
estimators = CE.name[estimators]
|
_data = (
|
||||||
if mode is None:
|
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||||
mode = valid_plot_modes["avg"][0]
|
)
|
||||||
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
||||||
if mode == "train_table":
|
case (_, "shift_table"):
|
||||||
return pn.pane.DataFrame(
|
cr = dr.crs[_prevs.index(int(plot_view))]
|
||||||
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(),
|
_data = (
|
||||||
align="center",
|
cr.shift_data(metric=metric, estimators=estimators)
|
||||||
)
|
.groupby(level=0)
|
||||||
elif mode == "test_table":
|
.mean()
|
||||||
return pn.pane.DataFrame(
|
)
|
||||||
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
||||||
align="center",
|
case (_, _ as plot_mode):
|
||||||
)
|
cr = dr.crs[_prevs.index(int(plot_view))]
|
||||||
elif mode == "shift_table":
|
_plot = cr.get_plots(
|
||||||
return pn.pane.DataFrame(
|
mode=plot_mode,
|
||||||
dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
|
metric=metric,
|
||||||
align="center",
|
estimators=estimators,
|
||||||
)
|
conf="panel",
|
||||||
return pn.pane.Matplotlib(
|
return_fig=True,
|
||||||
dr.get_plots(
|
)
|
||||||
mode=mode,
|
return (
|
||||||
metric=metric,
|
pn.pane.Matplotlib(
|
||||||
estimators=estimators,
|
_plot,
|
||||||
conf="panel",
|
tight=True,
|
||||||
return_fig=True,
|
format="png",
|
||||||
),
|
sizing_mode=_plot_sizing_mode,
|
||||||
tight=True,
|
# sizing_mode="scale_height",
|
||||||
format="png",
|
# sizing_mode="scale_both",
|
||||||
# sizing_mode="scale_height",
|
)
|
||||||
sizing_mode=_plot_sizing_mode,
|
if _plot is not None
|
||||||
# sizing_mode="scale_both",
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_plot(
|
|
||||||
datasets: Dict[str, DatasetReport],
|
|
||||||
dst: str,
|
|
||||||
metric: str,
|
|
||||||
estimators: List[str],
|
|
||||||
view: str,
|
|
||||||
mode: str,
|
|
||||||
):
|
|
||||||
_dr = datasets[dst]
|
|
||||||
if view == "avg":
|
|
||||||
return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators)
|
|
||||||
else:
|
|
||||||
prev = int(view)
|
|
||||||
return create_cr_plots(
|
|
||||||
_dr, mode=mode, metric=metric, estimators=estimators, prev=prev
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def explore_datasets(root: Path | str):
|
def explore_datasets(root: Path | str):
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import panel as pn
|
import panel as pn
|
||||||
import param
|
import param
|
||||||
|
|
||||||
from qcpanel.util import build_plot, explore_datasets, valid_plot_modes
|
from qcpanel.util import create_plots, explore_datasets, valid_plot_modes
|
||||||
from quacc.evaluation.comp import CE
|
from quacc.evaluation.comp import CE
|
||||||
from quacc.evaluation.report import DatasetReport
|
from quacc.evaluation.report import DatasetReport
|
||||||
|
|
||||||
|
@ -29,15 +31,24 @@ class QuaccTestViewer(param.Parameterized):
|
||||||
plot_pane = param.Parameter()
|
plot_pane = param.Parameter()
|
||||||
modal_pane = param.Parameter()
|
modal_pane = param.Parameter()
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, param_init=None, **params):
|
||||||
super().__init__(**params)
|
super().__init__(**params)
|
||||||
|
|
||||||
|
self.param_init = param_init
|
||||||
self.__setup_watchers()
|
self.__setup_watchers()
|
||||||
self.update_datasets()
|
self.update_datasets()
|
||||||
# self._update_on_dataset()
|
# self._update_on_dataset()
|
||||||
self.__create_param_pane()
|
self.__create_param_pane()
|
||||||
self.__create_modal_pane()
|
self.__create_modal_pane()
|
||||||
|
|
||||||
|
def __get_param_init(self, val):
|
||||||
|
__b = val in self.param_init
|
||||||
|
if __b:
|
||||||
|
setattr(self, val, self.param_init[val])
|
||||||
|
del self.param_init[val]
|
||||||
|
|
||||||
|
return __b
|
||||||
|
|
||||||
def __save_callback(self, event):
|
def __save_callback(self, event):
|
||||||
_home = Path("output")
|
_home = Path("output")
|
||||||
_save_input_val = self.save_input.value_input
|
_save_input_val = self.save_input.value_input
|
||||||
|
@ -233,8 +244,14 @@ class QuaccTestViewer(param.Parameterized):
|
||||||
}
|
}
|
||||||
|
|
||||||
self.available_datasets = list(self.datasets_.keys())
|
self.available_datasets = list(self.datasets_.keys())
|
||||||
|
_old_dataset = self.dataset
|
||||||
self.param["dataset"].objects = self.available_datasets
|
self.param["dataset"].objects = self.available_datasets
|
||||||
self.dataset = self.available_datasets[0]
|
if not self.__get_param_init("dataset"):
|
||||||
|
self.dataset = (
|
||||||
|
_old_dataset
|
||||||
|
if _old_dataset in self.available_datasets
|
||||||
|
else self.available_datasets[0]
|
||||||
|
)
|
||||||
|
|
||||||
def __setup_watchers(self):
|
def __setup_watchers(self):
|
||||||
self.param.watch(
|
self.param.watch(
|
||||||
|
@ -244,41 +261,57 @@ class QuaccTestViewer(param.Parameterized):
|
||||||
precedence=0,
|
precedence=0,
|
||||||
)
|
)
|
||||||
self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
|
self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
|
||||||
|
self.param.watch(self._update_on_metric, ["metric"], queued=True, precedence=2)
|
||||||
self.param.watch(
|
self.param.watch(
|
||||||
self._update_plot,
|
self._update_plot,
|
||||||
["dataset", "metric", "estimators", "plot_view", "mode"],
|
["dataset", "metric", "estimators", "plot_view", "mode"],
|
||||||
# ["metric", "estimators", "mode"],
|
# ["metric", "estimators", "mode"],
|
||||||
onlychanged=False,
|
onlychanged=False,
|
||||||
precedence=2,
|
precedence=3,
|
||||||
)
|
)
|
||||||
self.param.watch(
|
self.param.watch(
|
||||||
self._update_on_estimators,
|
self._update_on_estimators,
|
||||||
["estimators"],
|
["estimators"],
|
||||||
queued=True,
|
queued=True,
|
||||||
precedence=3,
|
precedence=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _update_on_dataset(self, *events):
|
def _update_on_dataset(self, *events):
|
||||||
l_dr = self.datasets_[self.dataset]
|
l_dr = self.datasets_[self.dataset]
|
||||||
l_data = l_dr.data()
|
l_data = l_dr.data()
|
||||||
|
|
||||||
l_metrics = l_data.columns.unique(0)
|
l_metrics = l_data.columns.unique(0)
|
||||||
l_estimators = l_data.columns.unique(1)
|
|
||||||
|
|
||||||
l_valid_estimators = [e for e in l_estimators if e != "ref"]
|
|
||||||
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
|
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
|
||||||
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
|
_old_metric = self.metric
|
||||||
|
|
||||||
self.param["metric"].objects = l_valid_metrics
|
self.param["metric"].objects = l_valid_metrics
|
||||||
self.metric = l_valid_metrics[0]
|
if not self.__get_param_init("metric"):
|
||||||
|
self.metric = (
|
||||||
|
_old_metric if _old_metric in l_valid_metrics else l_valid_metrics[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
_old_estimators = self.estimators
|
||||||
|
l_valid_estimators = l_dr.data(metric=self.metric).columns.unique(0).to_numpy()
|
||||||
|
_new_estimators = l_valid_estimators[
|
||||||
|
np.isin(l_valid_estimators, _old_estimators)
|
||||||
|
].tolist()
|
||||||
self.param["estimators"].objects = l_valid_estimators
|
self.param["estimators"].objects = l_valid_estimators
|
||||||
self.estimators = l_valid_estimators
|
if not self.__get_param_init("estimators"):
|
||||||
|
self.estimators = _new_estimators
|
||||||
|
|
||||||
self.param["plot_view"].objects = ["avg"] + l_valid_views
|
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
|
||||||
self.plot_view = "avg"
|
l_valid_views = ["avg"] + l_valid_views
|
||||||
|
_old_view = self.plot_view
|
||||||
|
self.param["plot_view"].objects = l_valid_views
|
||||||
|
if not self.__get_param_init("plot_view"):
|
||||||
|
self.plot_view = _old_view if _old_view in l_valid_views else "avg"
|
||||||
|
|
||||||
self.param["mode"].objects = valid_plot_modes["avg"]
|
self.param["mode"].objects = valid_plot_modes[self.plot_view]
|
||||||
self.mode = valid_plot_modes["avg"][0]
|
if not self.__get_param_init("mode"):
|
||||||
|
_old_mode = self.mode
|
||||||
|
if _old_mode in valid_plot_modes[self.plot_view]:
|
||||||
|
self.mode = _old_mode
|
||||||
|
else:
|
||||||
|
self.mode = valid_plot_modes[self.plot_view][0]
|
||||||
|
|
||||||
self.param["modal_estimators"].objects = l_valid_estimators
|
self.param["modal_estimators"].objects = l_valid_estimators
|
||||||
self.modal_estimators = []
|
self.modal_estimators = []
|
||||||
|
@ -287,21 +320,49 @@ class QuaccTestViewer(param.Parameterized):
|
||||||
self.modal_plot_view = l_valid_views.copy()
|
self.modal_plot_view = l_valid_views.copy()
|
||||||
|
|
||||||
def _update_on_view(self, *events):
|
def _update_on_view(self, *events):
|
||||||
|
_old_mode = self.mode
|
||||||
self.param["mode"].objects = valid_plot_modes[self.plot_view]
|
self.param["mode"].objects = valid_plot_modes[self.plot_view]
|
||||||
self.mode = valid_plot_modes[self.plot_view][0]
|
if _old_mode in valid_plot_modes[self.plot_view]:
|
||||||
|
self.mode = _old_mode
|
||||||
|
else:
|
||||||
|
self.mode = valid_plot_modes[self.plot_view][0]
|
||||||
|
|
||||||
|
def _update_on_metric(self, *events):
|
||||||
|
_old_estimators = self.estimators
|
||||||
|
|
||||||
|
l_dr = self.datasets_[self.dataset]
|
||||||
|
l_data: pd.DataFrame = l_dr.data(metric=self.metric)
|
||||||
|
l_valid_estimators: np.ndarray = l_data.columns.unique(0).to_numpy()
|
||||||
|
_new_estimators = l_valid_estimators[
|
||||||
|
np.isin(l_valid_estimators, _old_estimators)
|
||||||
|
].tolist()
|
||||||
|
self.param["estimators"].objects = l_valid_estimators
|
||||||
|
self.estimators = _new_estimators
|
||||||
|
|
||||||
def _update_on_estimators(self, *events):
|
def _update_on_estimators(self, *events):
|
||||||
self.modal_estimators = self.estimators.copy()
|
self.modal_estimators = self.estimators.copy()
|
||||||
|
|
||||||
def _update_plot(self, *events):
|
def _update_plot(self, *events):
|
||||||
self.plot_pane = build_plot(
|
__svg = pn.pane.SVG(
|
||||||
datasets=self.datasets_,
|
"""<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-chart-area-filled" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||||
dst=self.dataset,
|
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
|
||||||
metric=self.metric,
|
<path d="M20 18a1 1 0 0 1 .117 1.993l-.117 .007h-16a1 1 0 0 1 -.117 -1.993l.117 -.007h16z" stroke-width="0" fill="currentColor" />
|
||||||
estimators=self.estimators,
|
<path d="M15.22 5.375a1 1 0 0 1 1.393 -.165l.094 .083l4 4a1 1 0 0 1 .284 .576l.009 .131v5a1 1 0 0 1 -.883 .993l-.117 .007h-16.022l-.11 -.009l-.11 -.02l-.107 -.034l-.105 -.046l-.1 -.059l-.094 -.07l-.06 -.055l-.072 -.082l-.064 -.089l-.054 -.096l-.016 -.035l-.04 -.103l-.027 -.106l-.015 -.108l-.004 -.11l.009 -.11l.019 -.105c.01 -.04 .022 -.077 .035 -.112l.046 -.105l.059 -.1l4 -6a1 1 0 0 1 1.165 -.39l.114 .05l3.277 1.638l3.495 -4.369z" stroke-width="0" fill="currentColor" />
|
||||||
view=self.plot_view,
|
</svg>""",
|
||||||
mode=self.mode,
|
sizing_mode="stretch_both",
|
||||||
)
|
)
|
||||||
|
if len(self.estimators) == 0:
|
||||||
|
self.plot_pane = __svg
|
||||||
|
else:
|
||||||
|
_dr = self.datasets_[self.dataset]
|
||||||
|
__plot = create_plots(
|
||||||
|
_dr,
|
||||||
|
mode=self.mode,
|
||||||
|
metric=self.metric,
|
||||||
|
estimators=self.estimators,
|
||||||
|
plot_view=self.plot_view,
|
||||||
|
)
|
||||||
|
self.plot_pane = __svg if __plot is None else __plot
|
||||||
|
|
||||||
def get_plot(self):
|
def get_plot(self):
|
||||||
return self.plot_pane
|
return self.plot_pane
|
||||||
|
|
Loading…
Reference in New Issue