From deeb522ccb2d7948ea38fe24e98201899438a922 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Wed, 29 Nov 2023 03:55:13 +0100 Subject: [PATCH] theme updated --- qcpanel/old_run.py | 290 --------------------------------------------- qcpanel/run.py | 7 +- 2 files changed, 4 insertions(+), 293 deletions(-) delete mode 100644 qcpanel/old_run.py diff --git a/qcpanel/old_run.py b/qcpanel/old_run.py deleted file mode 100644 index 1089e7a..0000000 --- a/qcpanel/old_run.py +++ /dev/null @@ -1,290 +0,0 @@ -import argparse -import os -from pathlib import Path - -import panel as pn -import param - -from quacc.evaluation.estimators import CE -from quacc.evaluation.report import DatasetReport - -pn.extension(design="bootstrap") - - -def create_cr_plots( - dr: DatasetReport, - mode="delta", - metric="acc", - estimators=None, - prev=None, -): - idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev) - cr = dr.crs[idx] - estimators = CE.name[estimators] - _dpi = 112 - return pn.pane.Matplotlib( - cr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="panel", - return_fig=True, - ), - tight=True, - format="png", - sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - - -def create_avg_plots( - dr: DatasetReport, - mode="delta", - metric="acc", - estimators=None, - prev=None, -): - estimators = CE.name[estimators] - return pn.pane.Matplotlib( - dr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="panel", - return_fig=True, - ), - tight=True, - format="png", - sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - - -def build_cr_tab(dr: DatasetReport): - _data = dr.data() - _metrics = _data.columns.unique(0) - _estimators = _data.columns.unique(1) - - valid_metrics = [m for m in _metrics if not m.endswith("_score")] - metric_widget = pn.widgets.Select( - name="metric", - value="acc", - options=valid_metrics, - align="center", - ) - - valid_estimators = [e for e in _estimators if e != "ref"] - estimators_widget = pn.widgets.CheckButtonGroup( - name="estimators", - options=valid_estimators, - value=valid_estimators, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"] - plot_mode_widget = pn.widgets.RadioButtonGroup( - name="mode", - value=valid_plot_modes[0], - options=valid_plot_modes, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] - prevs_widget = pn.widgets.RadioButtonGroup( - name="train prevalence", - value=valid_prevs[0], - options=valid_prevs, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - ) - - plot_pane = pn.bind( - create_cr_plots, - dr=dr, - mode=plot_mode_widget, - metric=metric_widget, - estimators=estimators_widget, - prev=prevs_widget, - ) - - return pn.Row( - pn.Spacer(width=20), - pn.Column( - metric_widget, - pn.Row( - prevs_widget, - plot_mode_widget, - ), - estimators_widget, - align="center", - ), - pn.Spacer(sizing_mode="scale_width"), - plot_pane, - ) - - -def build_avg_tab(dr: DatasetReport): - _data = dr.data() - _metrics = _data.columns.unique(0) - _estimators = _data.columns.unique(1) - - valid_metrics = [m for m in _metrics if not m.endswith("_score")] - metric_widget = pn.widgets.Select( - name="metric", - value="acc", - options=valid_metrics, - align="center", - ) - - valid_estimators = [e for e in _estimators if e != "ref"] - estimators_widget = pn.widgets.CheckButtonGroup( - name="estimators", - options=valid_estimators, - value=valid_estimators, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - valid_plot_modes = [ - "delta_train", - "stdev_train", - "delta_test", - "stdev_test", - "shift", - ] - plot_mode_widget = pn.widgets.RadioButtonGroup( - name="mode", - value=valid_plot_modes[0], - options=valid_plot_modes, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - plot_pane = pn.bind( - create_avg_plots, - dr=dr, - mode=plot_mode_widget, - metric=metric_widget, - estimators=estimators_widget, - ) - - return pn.Row( - pn.Spacer(width=20), - pn.Column( - metric_widget, - plot_mode_widget, - estimators_widget, - align="center", - ), - pn.Spacer(sizing_mode="scale_width"), - plot_pane, - ) - - -def build_dataset(dataset_path: Path): - dr: DatasetReport = DatasetReport.unpickle(dataset_path) - - prevs_tab = ("train prevs.", build_cr_tab(dr)) - avg_tab = ("avg", build_avg_tab(dr)) - - app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False) - app.servable() - return app - - -def explore_datasets(root: Path | str): - if isinstance(root, str): - root = Path(root) - - if root.name == "plot": - return [] - - if not root.exists(): - return [] - - drs = [] - for f in os.listdir(root): - if (root / f).is_dir(): - drs += explore_datasets(root / f) - elif f == f"{root.name}.pickle": - drs.append((root, build_dataset(root / f))) - # drs.append((str(root),)) - - return drs - - -class PlotSelector(param.Parameterized): - metric = param.Selector(objects=["acc", "f1"]) - view = param.Selector(objects=["train prevs", "avg"]) - - -def plot_selector_widget(): - return pn.Param( - PlotSelector.param, - widgets={ - "metric": pn.widgets.Select, - "view": pn.widgets.Select, - }, - ) - - -def serve(address="localhost"): - # app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle")) - __base_path = "output" - __tabs = sorted( - explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0]) - ) - __tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs] - if len(__tabs) > 0: - app = pn.Tabs( - objects=__tabs, - tabs_location="left", - dynamic=False, - ) - else: - app = pn.Column( - pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}), - align="center", - ) - - __port = 33420 - __allowed = [address] - if address == "localhost": - __allowed.append("127.0.0.1") - - pn.serve( - app, - autoreload=True, - port=__port, - show=False, - address=address, - websocket_origin=[f"{_a}:{__port}" for _a in __allowed], - ) - - -def run(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--address", - action="store", - dest="address", - default="localhost", - ) - args = parser.parse_args() - serve(address=args.address) diff --git a/qcpanel/run.py b/qcpanel/run.py index fdaf2b0..ca6a477 100644 --- a/qcpanel/run.py +++ b/qcpanel/run.py @@ -1,10 +1,11 @@ import argparse import panel as pn +from panel.theme.fast import FastDarkTheme, FastDefaultTheme from qcpanel.viewer import QuaccTestViewer -# pn.config.design = pn.theme.Bootstrap +# pn.config.design = Fast # pn.config.theme = "dark" pn.config.notifications = True @@ -59,8 +60,8 @@ def app_instance(): ], main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")], modal=[qtv.modal_pane], - theme=pn.theme.DarkTheme, - theme_toggle=False, + # theme=FastDefaultTheme, + theme_toggle=True, ) app.servable()