theme updated
This commit is contained in:
parent
7fe36308fb
commit
deeb522ccb
|
@ -1,290 +0,0 @@
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import panel as pn
|
|
||||||
import param
|
|
||||||
|
|
||||||
from quacc.evaluation.estimators import CE
|
|
||||||
from quacc.evaluation.report import DatasetReport
|
|
||||||
|
|
||||||
pn.extension(design="bootstrap")
|
|
||||||
|
|
||||||
|
|
||||||
def create_cr_plots(
|
|
||||||
dr: DatasetReport,
|
|
||||||
mode="delta",
|
|
||||||
metric="acc",
|
|
||||||
estimators=None,
|
|
||||||
prev=None,
|
|
||||||
):
|
|
||||||
idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev)
|
|
||||||
cr = dr.crs[idx]
|
|
||||||
estimators = CE.name[estimators]
|
|
||||||
_dpi = 112
|
|
||||||
return pn.pane.Matplotlib(
|
|
||||||
cr.get_plots(
|
|
||||||
mode=mode,
|
|
||||||
metric=metric,
|
|
||||||
estimators=estimators,
|
|
||||||
conf="panel",
|
|
||||||
return_fig=True,
|
|
||||||
),
|
|
||||||
tight=True,
|
|
||||||
format="png",
|
|
||||||
sizing_mode="scale_height",
|
|
||||||
# sizing_mode="scale_both",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_avg_plots(
|
|
||||||
dr: DatasetReport,
|
|
||||||
mode="delta",
|
|
||||||
metric="acc",
|
|
||||||
estimators=None,
|
|
||||||
prev=None,
|
|
||||||
):
|
|
||||||
estimators = CE.name[estimators]
|
|
||||||
return pn.pane.Matplotlib(
|
|
||||||
dr.get_plots(
|
|
||||||
mode=mode,
|
|
||||||
metric=metric,
|
|
||||||
estimators=estimators,
|
|
||||||
conf="panel",
|
|
||||||
return_fig=True,
|
|
||||||
),
|
|
||||||
tight=True,
|
|
||||||
format="png",
|
|
||||||
sizing_mode="scale_height",
|
|
||||||
# sizing_mode="scale_both",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_cr_tab(dr: DatasetReport):
|
|
||||||
_data = dr.data()
|
|
||||||
_metrics = _data.columns.unique(0)
|
|
||||||
_estimators = _data.columns.unique(1)
|
|
||||||
|
|
||||||
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
|
||||||
metric_widget = pn.widgets.Select(
|
|
||||||
name="metric",
|
|
||||||
value="acc",
|
|
||||||
options=valid_metrics,
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_estimators = [e for e in _estimators if e != "ref"]
|
|
||||||
estimators_widget = pn.widgets.CheckButtonGroup(
|
|
||||||
name="estimators",
|
|
||||||
options=valid_estimators,
|
|
||||||
value=valid_estimators,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"]
|
|
||||||
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="mode",
|
|
||||||
value=valid_plot_modes[0],
|
|
||||||
options=valid_plot_modes,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
|
||||||
prevs_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="train prevalence",
|
|
||||||
value=valid_prevs[0],
|
|
||||||
options=valid_prevs,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
)
|
|
||||||
|
|
||||||
plot_pane = pn.bind(
|
|
||||||
create_cr_plots,
|
|
||||||
dr=dr,
|
|
||||||
mode=plot_mode_widget,
|
|
||||||
metric=metric_widget,
|
|
||||||
estimators=estimators_widget,
|
|
||||||
prev=prevs_widget,
|
|
||||||
)
|
|
||||||
|
|
||||||
return pn.Row(
|
|
||||||
pn.Spacer(width=20),
|
|
||||||
pn.Column(
|
|
||||||
metric_widget,
|
|
||||||
pn.Row(
|
|
||||||
prevs_widget,
|
|
||||||
plot_mode_widget,
|
|
||||||
),
|
|
||||||
estimators_widget,
|
|
||||||
align="center",
|
|
||||||
),
|
|
||||||
pn.Spacer(sizing_mode="scale_width"),
|
|
||||||
plot_pane,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_avg_tab(dr: DatasetReport):
|
|
||||||
_data = dr.data()
|
|
||||||
_metrics = _data.columns.unique(0)
|
|
||||||
_estimators = _data.columns.unique(1)
|
|
||||||
|
|
||||||
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
|
||||||
metric_widget = pn.widgets.Select(
|
|
||||||
name="metric",
|
|
||||||
value="acc",
|
|
||||||
options=valid_metrics,
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_estimators = [e for e in _estimators if e != "ref"]
|
|
||||||
estimators_widget = pn.widgets.CheckButtonGroup(
|
|
||||||
name="estimators",
|
|
||||||
options=valid_estimators,
|
|
||||||
value=valid_estimators,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_plot_modes = [
|
|
||||||
"delta_train",
|
|
||||||
"stdev_train",
|
|
||||||
"delta_test",
|
|
||||||
"stdev_test",
|
|
||||||
"shift",
|
|
||||||
]
|
|
||||||
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="mode",
|
|
||||||
value=valid_plot_modes[0],
|
|
||||||
options=valid_plot_modes,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
plot_pane = pn.bind(
|
|
||||||
create_avg_plots,
|
|
||||||
dr=dr,
|
|
||||||
mode=plot_mode_widget,
|
|
||||||
metric=metric_widget,
|
|
||||||
estimators=estimators_widget,
|
|
||||||
)
|
|
||||||
|
|
||||||
return pn.Row(
|
|
||||||
pn.Spacer(width=20),
|
|
||||||
pn.Column(
|
|
||||||
metric_widget,
|
|
||||||
plot_mode_widget,
|
|
||||||
estimators_widget,
|
|
||||||
align="center",
|
|
||||||
),
|
|
||||||
pn.Spacer(sizing_mode="scale_width"),
|
|
||||||
plot_pane,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(dataset_path: Path):
|
|
||||||
dr: DatasetReport = DatasetReport.unpickle(dataset_path)
|
|
||||||
|
|
||||||
prevs_tab = ("train prevs.", build_cr_tab(dr))
|
|
||||||
avg_tab = ("avg", build_avg_tab(dr))
|
|
||||||
|
|
||||||
app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False)
|
|
||||||
app.servable()
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
def explore_datasets(root: Path | str):
|
|
||||||
if isinstance(root, str):
|
|
||||||
root = Path(root)
|
|
||||||
|
|
||||||
if root.name == "plot":
|
|
||||||
return []
|
|
||||||
|
|
||||||
if not root.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
drs = []
|
|
||||||
for f in os.listdir(root):
|
|
||||||
if (root / f).is_dir():
|
|
||||||
drs += explore_datasets(root / f)
|
|
||||||
elif f == f"{root.name}.pickle":
|
|
||||||
drs.append((root, build_dataset(root / f)))
|
|
||||||
# drs.append((str(root),))
|
|
||||||
|
|
||||||
return drs
|
|
||||||
|
|
||||||
|
|
||||||
class PlotSelector(param.Parameterized):
|
|
||||||
metric = param.Selector(objects=["acc", "f1"])
|
|
||||||
view = param.Selector(objects=["train prevs", "avg"])
|
|
||||||
|
|
||||||
|
|
||||||
def plot_selector_widget():
|
|
||||||
return pn.Param(
|
|
||||||
PlotSelector.param,
|
|
||||||
widgets={
|
|
||||||
"metric": pn.widgets.Select,
|
|
||||||
"view": pn.widgets.Select,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def serve(address="localhost"):
|
|
||||||
# app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle"))
|
|
||||||
__base_path = "output"
|
|
||||||
__tabs = sorted(
|
|
||||||
explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0])
|
|
||||||
)
|
|
||||||
__tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs]
|
|
||||||
if len(__tabs) > 0:
|
|
||||||
app = pn.Tabs(
|
|
||||||
objects=__tabs,
|
|
||||||
tabs_location="left",
|
|
||||||
dynamic=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
app = pn.Column(
|
|
||||||
pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}),
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
__port = 33420
|
|
||||||
__allowed = [address]
|
|
||||||
if address == "localhost":
|
|
||||||
__allowed.append("127.0.0.1")
|
|
||||||
|
|
||||||
pn.serve(
|
|
||||||
app,
|
|
||||||
autoreload=True,
|
|
||||||
port=__port,
|
|
||||||
show=False,
|
|
||||||
address=address,
|
|
||||||
websocket_origin=[f"{_a}:{__port}" for _a in __allowed],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--address",
|
|
||||||
action="store",
|
|
||||||
dest="address",
|
|
||||||
default="localhost",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
serve(address=args.address)
|
|
|
@ -1,10 +1,11 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import panel as pn
|
import panel as pn
|
||||||
|
from panel.theme.fast import FastDarkTheme, FastDefaultTheme
|
||||||
|
|
||||||
from qcpanel.viewer import QuaccTestViewer
|
from qcpanel.viewer import QuaccTestViewer
|
||||||
|
|
||||||
# pn.config.design = pn.theme.Bootstrap
|
# pn.config.design = Fast
|
||||||
# pn.config.theme = "dark"
|
# pn.config.theme = "dark"
|
||||||
pn.config.notifications = True
|
pn.config.notifications = True
|
||||||
|
|
||||||
|
@ -59,8 +60,8 @@ def app_instance():
|
||||||
],
|
],
|
||||||
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
|
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
|
||||||
modal=[qtv.modal_pane],
|
modal=[qtv.modal_pane],
|
||||||
theme=pn.theme.DarkTheme,
|
# theme=FastDefaultTheme,
|
||||||
theme_toggle=False,
|
theme_toggle=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.servable()
|
app.servable()
|
||||||
|
|
Loading…
Reference in New Issue