location sync added, missing plot bug fixed

This commit is contained in:
Lorenzo Volpi 2023-11-22 19:19:51 +01:00
parent 1123940954
commit f7b566c4a4
3 changed files with 190 additions and 117 deletions

View File

@ -9,8 +9,23 @@ from qcpanel.viewer import QuaccTestViewer
pn.config.notifications = True pn.config.notifications = True
def serve(address="localhost"): def app_instance():
qtv = QuaccTestViewer() param_init = {
k: v
for k, v in pn.state.location.query_params.items()
if k in ["dataset", "metric", "plot_view", "mode", "estimators"]
}
qtv = QuaccTestViewer(param_init=param_init)
pn.state.location.sync(
qtv,
{
"dataset": "dataset",
"metric": "metric",
"plot_view": "plot_view",
"mode": "mode",
"estimators": "estimators",
},
)
def save_callback(event): def save_callback(event):
app.open_modal() app.open_modal()
@ -48,13 +63,17 @@ def serve(address="localhost"):
) )
app.servable() app.servable()
return app
def serve(address="localhost"):
__port = 33420 __port = 33420
__allowed = [address] __allowed = [address]
if address == "localhost": if address == "localhost":
__allowed.append("127.0.0.1") __allowed.append("127.0.0.1")
pn.serve( pn.serve(
app, app_instance,
autoreload=True, autoreload=True,
port=__port, port=__port,
show=False, show=False,
@ -76,4 +95,4 @@ def run():
if __name__ == "__main__": if __name__ == "__main__":
serve() run()

View File

@ -1,7 +1,6 @@
import os import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List
import panel as pn import panel as pn
@ -10,118 +9,112 @@ from quacc.evaluation.report import DatasetReport
_plot_sizing_mode = "stretch_both" _plot_sizing_mode = "stretch_both"
valid_plot_modes = defaultdict( valid_plot_modes = defaultdict(
lambda: ["delta", "delta_stdev", "diagonal", "shift", "table", "shift_table"] lambda: [
"delta_train",
"stdev_train",
"train_table",
"shift",
"shift_table",
"diagonal",
]
) )
valid_plot_modes["avg"] = [ valid_plot_modes["avg"] = [
"delta_train", "delta_train",
"stdev_train", "stdev_train",
"train_table",
"shift",
"shift_table",
"delta_test", "delta_test",
"stdev_test", "stdev_test",
"shift",
"train_table",
"test_table", "test_table",
"shift_table",
] ]
def create_cr_plots( def create_plots(
dr: DatasetReport, dr: DatasetReport,
mode="delta", mode="delta",
metric="acc", metric="acc",
estimators=None, estimators=None,
prev=None, plot_view=None,
): ):
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
idx = _prevs.index(prev)
cr = dr.crs[idx]
estimators = CE.name[estimators] estimators = CE.name[estimators]
if mode is None: if mode is None:
mode = valid_plot_modes[str(prev)][0] mode = valid_plot_modes[plot_view][0]
_dpi = 112 _dpi = 112
if mode == "table": match (plot_view, mode):
return pn.pane.DataFrame( case ("avg", "train_table"):
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), _data = (
align="center", dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
) )
elif mode == "shift_table": return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
return pn.pane.DataFrame( case ("avg", "test_table"):
cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), _data = (
align="center", dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
) )
else: return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
return pn.pane.Matplotlib( case ("avg", "shift_table"):
cr.get_plots( _data = (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
case ("avg", _ as plot_mode):
_plot = dr.get_plots(
mode=mode, mode=mode,
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf="panel", conf="panel",
return_fig=True, return_fig=True,
), )
tight=True, return (
format="png", pn.pane.Matplotlib(
sizing_mode=_plot_sizing_mode, _plot,
# sizing_mode="scale_height", tight=True,
# sizing_mode="scale_both", format="png",
) # sizing_mode="scale_height",
sizing_mode=_plot_sizing_mode,
# sizing_mode="scale_both",
def create_avg_plots( )
dr: DatasetReport, if _plot is not None
mode="delta", else None
metric="acc", )
estimators=None, case (_, "train_table"):
): cr = dr.crs[_prevs.index(int(plot_view))]
estimators = CE.name[estimators] _data = (
if mode is None: cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
mode = valid_plot_modes["avg"][0] )
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
if mode == "train_table": case (_, "shift_table"):
return pn.pane.DataFrame( cr = dr.crs[_prevs.index(int(plot_view))]
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean(), _data = (
align="center", cr.shift_data(metric=metric, estimators=estimators)
) .groupby(level=0)
elif mode == "test_table": .mean()
return pn.pane.DataFrame( )
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean(), return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
align="center", case (_, _ as plot_mode):
) cr = dr.crs[_prevs.index(int(plot_view))]
elif mode == "shift_table": _plot = cr.get_plots(
return pn.pane.DataFrame( mode=plot_mode,
dr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(), metric=metric,
align="center", estimators=estimators,
) conf="panel",
return pn.pane.Matplotlib( return_fig=True,
dr.get_plots( )
mode=mode, return (
metric=metric, pn.pane.Matplotlib(
estimators=estimators, _plot,
conf="panel", tight=True,
return_fig=True, format="png",
), sizing_mode=_plot_sizing_mode,
tight=True, # sizing_mode="scale_height",
format="png", # sizing_mode="scale_both",
# sizing_mode="scale_height", )
sizing_mode=_plot_sizing_mode, if _plot is not None
# sizing_mode="scale_both", else None
) )
def build_plot(
datasets: Dict[str, DatasetReport],
dst: str,
metric: str,
estimators: List[str],
view: str,
mode: str,
):
_dr = datasets[dst]
if view == "avg":
return create_avg_plots(_dr, mode=mode, metric=metric, estimators=estimators)
else:
prev = int(view)
return create_cr_plots(
_dr, mode=mode, metric=metric, estimators=estimators, prev=prev
)
def explore_datasets(root: Path | str): def explore_datasets(root: Path | str):

View File

@ -1,10 +1,12 @@
import os import os
from pathlib import Path from pathlib import Path
import numpy as np
import pandas as pd
import panel as pn import panel as pn
import param import param
from qcpanel.util import build_plot, explore_datasets, valid_plot_modes from qcpanel.util import create_plots, explore_datasets, valid_plot_modes
from quacc.evaluation.comp import CE from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport from quacc.evaluation.report import DatasetReport
@ -29,15 +31,24 @@ class QuaccTestViewer(param.Parameterized):
plot_pane = param.Parameter() plot_pane = param.Parameter()
modal_pane = param.Parameter() modal_pane = param.Parameter()
def __init__(self, **params): def __init__(self, param_init=None, **params):
super().__init__(**params) super().__init__(**params)
self.param_init = param_init
self.__setup_watchers() self.__setup_watchers()
self.update_datasets() self.update_datasets()
# self._update_on_dataset() # self._update_on_dataset()
self.__create_param_pane() self.__create_param_pane()
self.__create_modal_pane() self.__create_modal_pane()
def __get_param_init(self, val):
__b = val in self.param_init
if __b:
setattr(self, val, self.param_init[val])
del self.param_init[val]
return __b
def __save_callback(self, event): def __save_callback(self, event):
_home = Path("output") _home = Path("output")
_save_input_val = self.save_input.value_input _save_input_val = self.save_input.value_input
@ -233,8 +244,14 @@ class QuaccTestViewer(param.Parameterized):
} }
self.available_datasets = list(self.datasets_.keys()) self.available_datasets = list(self.datasets_.keys())
_old_dataset = self.dataset
self.param["dataset"].objects = self.available_datasets self.param["dataset"].objects = self.available_datasets
self.dataset = self.available_datasets[0] if not self.__get_param_init("dataset"):
self.dataset = (
_old_dataset
if _old_dataset in self.available_datasets
else self.available_datasets[0]
)
def __setup_watchers(self): def __setup_watchers(self):
self.param.watch( self.param.watch(
@ -244,41 +261,57 @@ class QuaccTestViewer(param.Parameterized):
precedence=0, precedence=0,
) )
self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1) self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
self.param.watch(self._update_on_metric, ["metric"], queued=True, precedence=2)
self.param.watch( self.param.watch(
self._update_plot, self._update_plot,
["dataset", "metric", "estimators", "plot_view", "mode"], ["dataset", "metric", "estimators", "plot_view", "mode"],
# ["metric", "estimators", "mode"], # ["metric", "estimators", "mode"],
onlychanged=False, onlychanged=False,
precedence=2, precedence=3,
) )
self.param.watch( self.param.watch(
self._update_on_estimators, self._update_on_estimators,
["estimators"], ["estimators"],
queued=True, queued=True,
precedence=3, precedence=4,
) )
def _update_on_dataset(self, *events): def _update_on_dataset(self, *events):
l_dr = self.datasets_[self.dataset] l_dr = self.datasets_[self.dataset]
l_data = l_dr.data() l_data = l_dr.data()
l_metrics = l_data.columns.unique(0) l_metrics = l_data.columns.unique(0)
l_estimators = l_data.columns.unique(1)
l_valid_estimators = [e for e in l_estimators if e != "ref"]
l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")] l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs] _old_metric = self.metric
self.param["metric"].objects = l_valid_metrics self.param["metric"].objects = l_valid_metrics
self.metric = l_valid_metrics[0] if not self.__get_param_init("metric"):
self.metric = (
_old_metric if _old_metric in l_valid_metrics else l_valid_metrics[0]
)
_old_estimators = self.estimators
l_valid_estimators = l_dr.data(metric=self.metric).columns.unique(0).to_numpy()
_new_estimators = l_valid_estimators[
np.isin(l_valid_estimators, _old_estimators)
].tolist()
self.param["estimators"].objects = l_valid_estimators self.param["estimators"].objects = l_valid_estimators
self.estimators = l_valid_estimators if not self.__get_param_init("estimators"):
self.estimators = _new_estimators
self.param["plot_view"].objects = ["avg"] + l_valid_views l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
self.plot_view = "avg" l_valid_views = ["avg"] + l_valid_views
_old_view = self.plot_view
self.param["plot_view"].objects = l_valid_views
if not self.__get_param_init("plot_view"):
self.plot_view = _old_view if _old_view in l_valid_views else "avg"
self.param["mode"].objects = valid_plot_modes["avg"] self.param["mode"].objects = valid_plot_modes[self.plot_view]
self.mode = valid_plot_modes["avg"][0] if not self.__get_param_init("mode"):
_old_mode = self.mode
if _old_mode in valid_plot_modes[self.plot_view]:
self.mode = _old_mode
else:
self.mode = valid_plot_modes[self.plot_view][0]
self.param["modal_estimators"].objects = l_valid_estimators self.param["modal_estimators"].objects = l_valid_estimators
self.modal_estimators = [] self.modal_estimators = []
@ -287,21 +320,49 @@ class QuaccTestViewer(param.Parameterized):
self.modal_plot_view = l_valid_views.copy() self.modal_plot_view = l_valid_views.copy()
def _update_on_view(self, *events): def _update_on_view(self, *events):
_old_mode = self.mode
self.param["mode"].objects = valid_plot_modes[self.plot_view] self.param["mode"].objects = valid_plot_modes[self.plot_view]
self.mode = valid_plot_modes[self.plot_view][0] if _old_mode in valid_plot_modes[self.plot_view]:
self.mode = _old_mode
else:
self.mode = valid_plot_modes[self.plot_view][0]
def _update_on_metric(self, *events):
_old_estimators = self.estimators
l_dr = self.datasets_[self.dataset]
l_data: pd.DataFrame = l_dr.data(metric=self.metric)
l_valid_estimators: np.ndarray = l_data.columns.unique(0).to_numpy()
_new_estimators = l_valid_estimators[
np.isin(l_valid_estimators, _old_estimators)
].tolist()
self.param["estimators"].objects = l_valid_estimators
self.estimators = _new_estimators
def _update_on_estimators(self, *events): def _update_on_estimators(self, *events):
self.modal_estimators = self.estimators.copy() self.modal_estimators = self.estimators.copy()
def _update_plot(self, *events): def _update_plot(self, *events):
self.plot_pane = build_plot( __svg = pn.pane.SVG(
datasets=self.datasets_, """<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-chart-area-filled" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
dst=self.dataset, <path stroke="none" d="M0 0h24v24H0z" fill="none" />
metric=self.metric, <path d="M20 18a1 1 0 0 1 .117 1.993l-.117 .007h-16a1 1 0 0 1 -.117 -1.993l.117 -.007h16z" stroke-width="0" fill="currentColor" />
estimators=self.estimators, <path d="M15.22 5.375a1 1 0 0 1 1.393 -.165l.094 .083l4 4a1 1 0 0 1 .284 .576l.009 .131v5a1 1 0 0 1 -.883 .993l-.117 .007h-16.022l-.11 -.009l-.11 -.02l-.107 -.034l-.105 -.046l-.1 -.059l-.094 -.07l-.06 -.055l-.072 -.082l-.064 -.089l-.054 -.096l-.016 -.035l-.04 -.103l-.027 -.106l-.015 -.108l-.004 -.11l.009 -.11l.019 -.105c.01 -.04 .022 -.077 .035 -.112l.046 -.105l.059 -.1l4 -6a1 1 0 0 1 1.165 -.39l.114 .05l3.277 1.638l3.495 -4.369z" stroke-width="0" fill="currentColor" />
view=self.plot_view, </svg>""",
mode=self.mode, sizing_mode="stretch_both",
) )
if len(self.estimators) == 0:
self.plot_pane = __svg
else:
_dr = self.datasets_[self.dataset]
__plot = create_plots(
_dr,
mode=self.mode,
metric=self.metric,
estimators=self.estimators,
plot_view=self.plot_view,
)
self.plot_pane = __svg if __plot is None else __plot
def get_plot(self): def get_plot(self):
return self.plot_pane return self.plot_pane