QuAcc/qcpanel/run.py

613 lines
18 KiB
Python
Raw Normal View History

2023-11-11 16:44:12 +01:00
import argparse
import os
2023-11-16 01:34:01 +01:00
from collections import defaultdict
2023-11-11 16:44:12 +01:00
from pathlib import Path
2023-11-16 01:34:01 +01:00
from typing import Dict, List
2023-11-11 16:44:12 +01:00
import panel as pn
2023-11-16 01:34:01 +01:00
import param
2023-11-11 16:44:12 +01:00
2023-11-16 01:34:01 +01:00
from quacc import utils
2023-11-11 16:44:12 +01:00
from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport
2023-11-16 01:34:01 +01:00
# Global Panel look & feel: Bootstrap design, dark theme, toast notifications
# (the notifications are used by the viewer's save callback).
pn.config.design = pn.theme.Bootstrap
pn.config.theme = "dark"
pn.config.notifications = True

# Plot modes offered for a single train-prevalence view; every key that is not
# explicitly registered below (e.g. "10", "50") falls back to a fresh copy of this.
_CR_PLOT_MODES = ("delta", "delta_stdev", "diagonal", "shift", "table", "shift_table")
valid_plot_modes = defaultdict(lambda: list(_CR_PLOT_MODES))

# Plot modes offered for the "avg" (averaged over train prevalences) view.
valid_plot_modes["avg"] = [
    "delta_train",
    "stdev_train",
    "delta_test",
    "stdev_test",
    "shift",
    "train_table",
    "test_table",
    "shift_table",
]
2023-11-11 16:44:12 +01:00
def create_cr_plots(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    prev=None,
):
    """Build the pane for one train-prevalence report contained in ``dr``.

    Args:
        dr: dataset report; its ``crs`` hold one report per training prevalence.
        mode: plot mode. ``None`` selects the first mode registered for
            ``str(prev)`` in ``valid_plot_modes``. "table" and "shift_table"
            produce a DataFrame pane; any other value is forwarded to
            ``cr.get_plots``.
        metric: metric to display.
        estimators: estimator names, resolved through ``CE.name``.
        prev: training prevalence as an integer percentage, matched against
            ``round(cr.train_prev[1] * 100)`` for each report.

    Returns:
        A ``pn.pane.DataFrame`` for table modes, otherwise a ``pn.pane.Matplotlib``.

    Raises:
        ValueError: if no report in ``dr.crs`` matches ``prev``.
    """
    available = [round(cr.train_prev[1] * 100) for cr in dr.crs]
    try:
        idx = available.index(prev)
    except ValueError:
        raise ValueError(
            f"no report with train prevalence {prev}; available: {available}"
        ) from None
    cr = dr.crs[idx]
    estimators = CE.name[estimators]

    if mode is None:
        mode = valid_plot_modes[str(prev)][0]

    if mode == "table":
        return pn.pane.DataFrame(
            cr.data(metric=metric, estimators=estimators).groupby(level=0).mean(),
            align="center",
        )
    if mode == "shift_table":
        return pn.pane.DataFrame(
            cr.shift_data(metric=metric, estimators=estimators).groupby(level=0).mean(),
            align="center",
        )
    return pn.pane.Matplotlib(
        cr.get_plots(
            mode=mode,
            metric=metric,
            estimators=estimators,
            conf="panel",
            return_fig=True,
        ),
        tight=True,
        format="png",
        sizing_mode="scale_height",
    )
2023-11-11 16:44:12 +01:00
def create_avg_plots(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
):
    """Build the pane for the prevalence-averaged view of ``dr``.

    Table modes ("train_table", "test_table", "shift_table") return a
    ``pn.pane.DataFrame`` of group means; every other mode is forwarded to
    ``dr.get_plots`` and wrapped in a ``pn.pane.Matplotlib``. A ``None`` mode
    selects the first entry of ``valid_plot_modes["avg"]``.
    """
    estimators = CE.name[estimators]
    if mode is None:
        mode = valid_plot_modes["avg"][0]

    # table mode -> (name of the report accessor, index level to group on)
    table_specs = {
        "train_table": ("data", 1),
        "test_table": ("data", 0),
        "shift_table": ("shift_data", 0),
    }
    spec = table_specs.get(mode)
    if spec is not None:
        accessor, level = spec
        frame = getattr(dr, accessor)(metric=metric, estimators=estimators)
        return pn.pane.DataFrame(frame.groupby(level=level).mean(), align="center")

    figure = dr.get_plots(
        mode=mode,
        metric=metric,
        estimators=estimators,
        conf="panel",
        return_fig=True,
    )
    return pn.pane.Matplotlib(
        figure,
        tight=True,
        format="png",
        sizing_mode="scale_height",
    )
2023-11-16 01:34:01 +01:00
def build_widgets(datasets: Dict[str, DatasetReport]):
    """Create the selection widgets and wire their inter-dependencies.

    Changing the dataset refreshes the metric/estimator/view choices; changing
    the view refreshes the available plot modes.

    Returns:
        ``(pane, widgets)`` where ``pane`` is a ``pn.Column`` laying the widgets
        out and ``widgets`` maps "dataset", "metric", "view", "plot_mode" and
        "estimators" to the individual widget objects.
    """

    def _choices(report):
        # Selectable metrics, estimators and train-prevalence views of a report.
        frame = report.data()
        metrics = [m for m in frame.columns.unique(0) if not m.endswith("_score")]
        estimators = [e for e in frame.columns.unique(1) if e != "ref"]
        views = [str(round(cr.train_prev[1] * 100)) for cr in report.crs]
        return metrics, estimators, views

    dataset_widget = pn.widgets.Select(
        name="dataset",
        options=list(datasets.keys()),
        align="center",
    )
    metrics, estimators, views = _choices(datasets[dataset_widget.value])

    metric_widget = pn.widgets.Select(
        name="metric",
        value="acc",
        options=metrics,
        align="center",
    )
    estimators_widget = pn.widgets.CheckButtonGroup(
        name="estimators",
        options=estimators,
        value=estimators,
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
        sizing_mode="scale_width",
    )
    view_widget = pn.widgets.RadioButtonGroup(
        name="view",
        options=views + ["avg"],
        value="avg",
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
    )

    @pn.depends(dataset_widget.param.value, watch=True)
    def _update_from_dataset(_dataset):
        new_metrics, new_estimators, new_views = _choices(
            datasets[dataset_widget.value]
        )
        metric_widget.options = new_metrics
        metric_widget.value = new_metrics[0]
        estimators_widget.options = new_estimators
        estimators_widget.value = new_estimators
        view_widget.options = new_views + ["avg"]
        view_widget.value = "avg"

    avg_modes = valid_plot_modes["avg"]
    plot_mode_widget = pn.widgets.RadioButtonGroup(
        name="mode",
        value=avg_modes[0],
        options=avg_modes,
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
        sizing_mode="scale_width",
    )

    @pn.depends(view_widget.param.value, watch=True)
    def _update_from_view(_view):
        modes = valid_plot_modes[_view]
        plot_mode_widget.options = modes
        plot_mode_widget.value = modes[0]

    widget_pane = pn.Column(
        dataset_widget,
        metric_widget,
        pn.Row(view_widget, plot_mode_widget),
        estimators_widget,
    )
    widgets = {
        "dataset": dataset_widget,
        "metric": metric_widget,
        "view": view_widget,
        "plot_mode": plot_mode_widget,
        "estimators": estimators_widget,
    }
    return widget_pane, widgets
2023-11-11 16:44:12 +01:00
2023-11-16 01:34:01 +01:00
def build_plot(
    datasets: Dict[str, DatasetReport],
    dst: str,
    metric: str,
    estimators: List[str],
    view: str,
    mode: str,
):
    """Dispatch to the averaged or per-prevalence pane builder for ``dst``.

    ``view`` is either "avg" or a train prevalence as a numeric string
    (converted with ``int``) identifying one report of the dataset.
    """
    report = datasets[dst]
    if view != "avg":
        return create_cr_plots(
            report,
            mode=mode,
            metric=metric,
            estimators=estimators,
            prev=int(view),
        )
    return create_avg_plots(report, mode=mode, metric=metric, estimators=estimators)
def build_modal(datasets, dst, metric):
    """Return a placeholder text pane labelled ``<dst>_<metric>`` (``datasets`` is unused)."""
    label = f"{dst}_{metric}"
    return pn.pane.Str(label)
2023-11-11 16:44:12 +01:00
2023-11-16 01:34:01 +01:00
def build_save_pane(datasets: Dict[str, DatasetReport], dst: str, metric: str):
    """Return a text pane built from the selected report and the metric name."""
    # NOTE(review): this interpolates the DatasetReport object itself, whereas
    # build_modal uses the dataset key ``dst`` -- confirm which one is intended.
    report = datasets[dst]
    return pn.pane.Str(f"{report}_{metric}")
2023-11-11 16:44:12 +01:00
def explore_datasets(root: Path | str):
if isinstance(root, str):
root = Path(root)
2023-11-16 01:34:01 +01:00
if root.name == "plot":
return []
if not root.exists():
return []
2023-11-11 16:44:12 +01:00
drs = []
for f in os.listdir(root):
if (root / f).is_dir():
drs += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
2023-11-16 01:34:01 +01:00
drs.append(root / f)
2023-11-11 16:44:12 +01:00
# drs.append((str(root),))
return drs
2023-11-16 01:34:01 +01:00
class QuaccTestViewer(param.Parameterized):
    """Interactive viewer for the dataset reports found under ``./output``.

    The sidebar parameters (dataset, metric, estimators, plot_view, mode) drive
    ``plot_pane``; the ``modal_*`` parameters configure the markdown export
    performed by the modal's save button.
    """

    dataset = param.Selector()
    metric = param.Selector()
    estimators = param.ListSelector()
    plot_view = param.Selector()
    mode = param.Selector()
    modal_estimators = param.ListSelector()
    modal_plot_view = param.ListSelector()
    # Separate list copies so that ``objects`` and ``default`` never alias one
    # shared (module-level) list, and no spurious key is added to the defaultdict.
    modal_mode_prev = param.ListSelector(
        objects=valid_plot_modes.default_factory(),
        default=valid_plot_modes.default_factory(),
    )
    modal_mode_avg = param.ListSelector(
        objects=list(valid_plot_modes["avg"]),
        default=list(valid_plot_modes["avg"]),
    )
    param_pane = param.Parameter()
    plot_pane = param.Parameter()
    modal_pane = param.Parameter()

    def __init__(self, **params):
        super().__init__(**params)
        # Watchers are registered first so that selecting the initial dataset
        # (inside __import_datasets) populates the dependent parameters and plot.
        self.__setup_watchers()
        self.__import_datasets()
        self.__create_param_pane()
        self.__create_modal_pane()

    def __save_callback(self, event):
        """Export the modal configuration as markdown (plus plots) and notify.

        Writes ``<quacc home>/output/<dataset>/<config>/<metric>.md`` with plots
        under ``.../<config>/plot``; ``<config>`` is the typed name or "default".
        """
        _config = self.save_input.value_input or "default"
        # The name becomes a directory component: reject anything that could
        # escape the dataset's output directory (e.g. "../x", "a/b", "..").
        if Path(_config).name != _config or _config in (".", ".."):
            pn.state.notifications.error(
                f'"{_config}" is not a valid configuration name'
            )
            return

        base_path = utils.get_quacc_home() / "output" / self.dataset / _config
        base_plot = base_path / "plot"
        # makedirs creates the intermediate base_path as well.
        os.makedirs(base_plot, exist_ok=True)

        l_dr = self.datasets_[self.dataset]
        res = l_dr.to_md(
            conf=_config,
            metric=self.metric,
            estimators=CE.name[self.modal_estimators],
            dr_modes=self.modal_mode_avg,
            cr_modes=self.modal_mode_prev,
            plot_path=base_plot,
        )
        with open(base_path / f"{self.metric}.md", "w", encoding="utf-8") as f:
            f.write(res)
        pn.state.notifications.success(f'"{_config}" successfully saved')

    def __create_param_pane(self):
        """Build the sidebar widgets bound to the viewer parameters."""
        self.dataset_widget = pn.Param(
            self,
            show_name=False,
            parameters=["dataset"],
            widgets={"dataset": {"widget_type": pn.widgets.Select}},
        )
        self.metric_widget = pn.Param(
            self,
            show_name=False,
            parameters=["metric"],
            widgets={"metric": {"widget_type": pn.widgets.Select}},
        )
        self.estimators_widgets = pn.Param(
            self,
            show_name=False,
            parameters=["estimators"],
            widgets={
                "estimators": {
                    "widget_type": pn.widgets.CheckButtonGroup,
                    "orientation": "vertical",
                    "sizing_mode": "scale_width",
                    "button_type": "primary",
                    "button_style": "outline",
                }
            },
        )
        self.plot_view_widget = pn.Param(
            self,
            show_name=False,
            parameters=["plot_view"],
            widgets={
                "plot_view": {
                    "widget_type": pn.widgets.RadioButtonGroup,
                    "orientation": "vertical",
                    "button_type": "primary",
                    "button_style": "outline",
                }
            },
        )
        self.mode_widget = pn.Param(
            self,
            show_name=False,
            parameters=["mode"],
            widgets={
                "mode": {
                    "widget_type": pn.widgets.RadioButtonGroup,
                    "orientation": "vertical",
                    "sizing_mode": "scale_width",
                    "button_type": "primary",
                    "button_style": "outline",
                }
            },
            align="center",
        )
        self.param_pane = pn.Column(
            self.dataset_widget,
            self.metric_widget,
            pn.Row(
                self.plot_view_widget,
                self.mode_widget,
            ),
            self.estimators_widgets,
        )

    def __create_modal_pane(self):
        """Build the export ("save") modal: mode/view/estimator pickers + save button."""
        self.modal_estimators_widgets = pn.Param(
            self,
            show_name=False,
            parameters=["modal_estimators"],
            widgets={
                "modal_estimators": {
                    "widget_type": pn.widgets.CheckButtonGroup,
                    "orientation": "vertical",
                    "sizing_mode": "scale_width",
                    "button_type": "primary",
                    "button_style": "outline",
                }
            },
        )
        self.modal_plot_view_widget = pn.Param(
            self,
            show_name=False,
            parameters=["modal_plot_view"],
            widgets={
                "modal_plot_view": {
                    "widget_type": pn.widgets.CheckButtonGroup,
                    "orientation": "vertical",
                    "button_type": "primary",
                    "button_style": "outline",
                }
            },
        )
        self.modal_mode_prev_widget = pn.Param(
            self,
            show_name=False,
            parameters=["modal_mode_prev"],
            widgets={
                "modal_mode_prev": {
                    "widget_type": pn.widgets.CheckButtonGroup,
                    "orientation": "vertical",
                    "sizing_mode": "scale_width",
                    "button_type": "primary",
                    "button_style": "outline",
                }
            },
            align="center",
        )
        self.modal_mode_avg_widget = pn.Param(
            self,
            show_name=False,
            parameters=["modal_mode_avg"],
            widgets={
                "modal_mode_avg": {
                    "widget_type": pn.widgets.CheckButtonGroup,
                    "orientation": "vertical",
                    "sizing_mode": "scale_width",
                    "button_type": "primary",
                    "button_style": "outline",
                }
            },
            align="center",
        )
        self.save_input = pn.widgets.TextInput(
            name="Configuration Name", placeholder="default", sizing_mode="scale_width"
        )
        self.save_button = pn.widgets.Button(
            name="Save",
            sizing_mode="scale_width",
            button_style="solid",
            button_type="success",
        )
        self.save_button.on_click(self.__save_callback)
        _title_styles = {
            "font-size": "14pt",
            "font-weight": "bold",
        }
        self.modal_pane = pn.Column(
            pn.Column(
                pn.pane.Str("Avg. configuration", styles=_title_styles),
                self.modal_mode_avg_widget,
                pn.pane.Str("Train prevs. configuration", styles=_title_styles),
                pn.Row(
                    self.modal_plot_view_widget,
                    self.modal_mode_prev_widget,
                ),
                pn.pane.Str("Estimators configuration", styles=_title_styles),
                self.modal_estimators_widgets,
                self.save_input,
                self.save_button,
                width=450,
                align="center",
                scroll=True,
            ),
            sizing_mode="stretch_both",
        )

    def __import_datasets(self):
        """Load every report under ``./output`` and select the first one.

        Keys are report directories relative to ``output``; deeper paths sort first.

        Raises:
            FileNotFoundError: if no dataset report is found.
        """
        base_path = "output"  # relative to the current working directory
        dataset_paths = sorted(
            explore_datasets(base_path), key=lambda t: (-len(t.parts), t)
        )
        # NOTE: unpickling executes code; only point this at locally produced reports.
        self.datasets_ = {
            str(dp.parent.relative_to(Path(base_path))): DatasetReport.unpickle(dp)
            for dp in dataset_paths
        }
        self.available_datasets = list(self.datasets_.keys())
        if not self.available_datasets:
            raise FileNotFoundError(
                f"no dataset reports found under {Path(base_path).resolve()}"
            )
        self.param["dataset"].objects = self.available_datasets
        self.dataset = self.available_datasets[0]

    def __setup_watchers(self):
        """Register parameter watchers; lower ``precedence`` values run first.

        ``queued=True`` defers the parameter changes made inside those callbacks
        until the current batch of watchers has run (param ``watch`` semantics).
        """
        self.param.watch(
            self._update_on_dataset,
            ["dataset"],
            queued=True,
            precedence=0,
        )
        self.param.watch(self._update_on_view, ["plot_view"], queued=True, precedence=1)
        self.param.watch(
            self._update_plot,
            ["dataset", "metric", "estimators", "plot_view", "mode"],
            onlychanged=False,
            precedence=2,
        )
        self.param.watch(
            self._update_on_estimators,
            ["estimators"],
            queued=True,
            precedence=3,
        )

    def _update_on_dataset(self, *events):
        """Reset every dependent selector to the options of the chosen dataset."""
        l_dr = self.datasets_[self.dataset]
        l_data = l_dr.data()
        l_metrics = l_data.columns.unique(0)
        l_estimators = l_data.columns.unique(1)
        l_valid_estimators = [e for e in l_estimators if e != "ref"]
        l_valid_metrics = [m for m in l_metrics if not m.endswith("_score")]
        l_valid_views = [str(round(cr.train_prev[1] * 100)) for cr in l_dr.crs]
        self.param["metric"].objects = l_valid_metrics
        self.metric = l_valid_metrics[0]
        self.param["estimators"].objects = l_valid_estimators
        self.estimators = l_valid_estimators
        self.param["plot_view"].objects = ["avg"] + l_valid_views
        self.plot_view = "avg"
        self.param["mode"].objects = valid_plot_modes["avg"]
        self.mode = valid_plot_modes["avg"][0]
        self.param["modal_estimators"].objects = l_valid_estimators
        self.modal_estimators = []
        self.param["modal_plot_view"].objects = l_valid_views
        self.modal_plot_view = l_valid_views.copy()

    def _update_on_view(self, *events):
        """Restrict ``mode`` to the modes valid for the selected view."""
        modes = valid_plot_modes[self.plot_view]
        self.param["mode"].objects = modes
        self.mode = modes[0]

    def _update_on_estimators(self, *events):
        """Mirror the sidebar estimator selection into the export modal."""
        self.modal_estimators = self.estimators.copy()

    def _update_plot(self, *events):
        """Rebuild ``plot_pane`` from the current selection."""
        self.plot_pane = build_plot(
            datasets=self.datasets_,
            dst=self.dataset,
            metric=self.metric,
            estimators=self.estimators,
            view=self.plot_view,
            mode=self.mode,
        )

    def get_plot(self):
        """Return the current plot pane."""
        return self.plot_pane

    def get_param_pane(self):
        """Return the sidebar widget column."""
        return self.param_pane
2023-11-11 16:44:12 +01:00
def serve(address="localhost", port=33420):
    """Build the quacc viewer app and serve it with Panel (blocking).

    Args:
        address: interface to bind. For "localhost", 127.0.0.1 is also accepted
            as a websocket origin.
        port: TCP port to listen on (defaults to the previously hard-coded 33420).
    """
    qtv = QuaccTestViewer()

    def save_callback(event):
        # ``app`` is bound below; this only runs on click, after it exists.
        app.open_modal()

    save_button = pn.widgets.Button(
        name="Save",
        sizing_mode="scale_width",
        button_style="solid",
        button_type="success",
    )
    save_button.on_click(save_callback)

    app = pn.template.MaterialTemplate(
        title="quacc tests",
        sidebar=[save_button, qtv.get_param_pane],
        main=[qtv.get_plot],
        modal=[qtv.modal_pane],
    )
    app.servable()

    allowed = [address]
    if address == "localhost":
        allowed.append("127.0.0.1")

    pn.serve(
        app,
        autoreload=True,
        port=port,
        show=False,
        address=address,
        websocket_origin=[f"{host}:{port}" for host in allowed],
    )
def run():
    """Command-line entry point: parse ``--address`` and start the server."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--address", action="store", dest="address", default="localhost")
    options = cli.parse_args()
    serve(address=options.address)
2023-11-16 01:34:01 +01:00
if __name__ == "__main__":
    # Go through run() so CLI flags such as --address are honoured; with no
    # arguments this is equivalent to serve().
    run()