# QuAcc/qcdash/app.py
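"""Dash dashboard for browsing QuAcc evaluation reports: walks an output
folder for pickled DatasetReport objects and renders them as plotly figures
and Dash DataTables, keeping the selected dataset, metric, estimators, view
and mode synchronized with the URL query string.
"""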


import json
import os
from collections import defaultdict
from json import JSONDecodeError
from pathlib import Path
from typing import List
from urllib.parse import parse_qsl, quote, urlencode, urlparse
import dash_bootstrap_components as dbc
import numpy as np
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
from dash.dash_table.Format import Format, Scheme
from quacc import plot
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import wilcoxon

# Plot modes available for each view: DatasetReport modes for the "avg"
# view, CompReport modes for a specific training prevalence.
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes

root_folder = "output"


def _get_prev_str(prev: np.ndarray):
    return str(tuple(np.around(prev, decimals=2)))
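
# Usage sketch: an array [0.25, 0.75] becomes the string "(0.25, 0.75)"
# (assuming NumPy 1.x scalar formatting). This string is how a CompReport's
# training prevalence is identified in the "view" selector and in the URL.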


def get_datasets(root: str | Path) -> dict[str, DatasetReport]:
    def load_dataset(dataset):
        dataset = Path(dataset)
        return DatasetReport.unpickle(dataset)

    def explore_datasets(root: str | Path) -> List[Path]:
        if isinstance(root, str):
            root = Path(root)

        if root.name == "plot":
            return []

        if not root.exists():
            return []

        dr_paths = []
        for f in os.listdir(root):
            if (root / f).is_dir():
                dr_paths += explore_datasets(root / f)
            elif f == f"{root.name}.pickle":
                dr_paths.append(root / f)

        return dr_paths

    dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
    return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
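
# Expected on-disk layout (hypothetical paths): each dataset folder holds a
# pickle named after the folder itself, e.g.
#   output/imdb/imdb.pickle  ->  datasets["output/imdb"]
# Folders named "plot" are skipped; deeper paths sort before shallower ones.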


def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
    _backend = backend or plot.get_backend("plotly")
    estimators = CE.name[estimators]
    match (view, mode):
        case ("avg", _):
            return dr.get_plots(
                mode=mode,
                metric=metric,
                estimators=estimators,
                conf="plotly",
                save_fig=False,
                backend=_backend,
            )
        case (_, _):
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            return cr.get_plots(
                mode=mode,
                metric=metric,
                estimators=estimators,
                conf="plotly",
                save_fig=False,
                backend=_backend,
            )


def get_table(dr: DatasetReport, metric, estimators, view, mode):
    estimators = CE.name[estimators]
    match (view, mode):
        case ("avg", "train_table"):
            # return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
            return dr.train_table(metric=metric, estimators=estimators)
        case ("avg", "train_std_table"):
            return dr.train_std_table(metric=metric, estimators=estimators)
        case ("avg", "test_table"):
            # return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            return dr.test_table(metric=metric, estimators=estimators)
        case ("avg", "shift_table"):
            # return (
            #     dr.shift_data(metric=metric, estimators=estimators)
            #     .groupby(level=0)
            #     .mean()
            # )
            return dr.shift_table(metric=metric, estimators=estimators)
        case ("avg", "stats_table"):
            return wilcoxon(dr, metric=metric, estimators=estimators)
        case (_, "train_table"):
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            # return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            return cr.train_table(metric=metric, estimators=estimators)
        case (_, "shift_table"):
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            # return (
            #     cr.shift_data(metric=metric, estimators=estimators)
            #     .groupby(level=0)
            #     .mean()
            # )
            return cr.shift_table(metric=metric, estimators=estimators)
        case (_, "stats_table"):
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            return wilcoxon(cr, metric=metric, estimators=estimators)


def get_DataTable(df, mode):
    _primary = "#0d6efd"
    if df.empty:
        return None

    # Label for the index column, depending on which table is shown.
    _index_name = dict(
        train_table="test prev.",
        train_std_table="train prev.",
        test_table="train prev.",
        shift_table="shift",
        stats_table="method",
    )
    df = df.reset_index()

    if mode == "train_std_table":
        # Merge each ("avg", k) / ("std", k) column pair into a single
        # "avg~std" cell per estimator.
        columns_format = Format()
        df_columns = np.concatenate([["index"], df.columns.unique(1)[1:]])
        data = [
            dict(
                index="(" + ", ".join([f"{v:.2f}" for v in idx]) + ")"
                if isinstance(idx, tuple | list | np.ndarray)
                else str(idx)
            )
            | {
                k: f"{df.loc[i, ('avg', k)]:.4f}~{df.loc[i, ('std', k)]:.3f}"
                for k in df.columns.unique(1)[1:]
            }
            for i, idx in zip(df.index, df.loc[:, ("index", "")])
        ]
    else:
        columns_format = Format(precision=6, scheme=Scheme.exponent, nully="nan")
        df_columns = df.columns
        data = df.to_dict("records")

    columns = {
        c: dict(
            id=c,
            name=_index_name[mode] if c == "index" else c,
            type="numeric",
            format=columns_format,
        )
        for c in df_columns
    }
    columns["index"]["format"] = Format()
    columns = list(columns.values())

    # Render prevalence tuples and floats in the index column as fixed-point
    # strings.
    for d in data:
        if isinstance(d["index"], tuple | list | np.ndarray):
            d["index"] = "(" + ", ".join([f"{v:.2f}" for v in d["index"]]) + ")"
        elif isinstance(d["index"], float):
            d["index"] = f"{d['index']:.2f}"

    _style_cell = {
        "padding": "0 12px",
        "border": "0",
        "border-bottom": f"1px solid {_primary}",
    }
    _style_cell_conditional = [
        {
            "if": {"column_id": "index"},
            "text_align": "center",
        },
    ]
    # Emphasize the last row in every table except the stats one.
    _style_data_conditional = []
    if mode != "stats_table":
        _style_data_conditional += [
            {
                "if": {"column_id": "index", "row_index": len(data) - 1},
                "font_weight": "bold",
            },
            {
                "if": {"row_index": len(data) - 1},
                "background_color": "#0d6efd",
                "color": "white",
            },
        ]
    _style_table = {
        "margin": "6vh 15px",
        "padding": "15px",
        "maxWidth": "80vw",
        "overflowX": "auto",
        "border": f"0px solid {_primary}",
        "border-radius": "6px",
    }
    return html.Div(
        [
            dash_table.DataTable(
                data=data,
                columns=columns,
                id="table1",
                style_cell=_style_cell,
                style_cell_conditional=_style_cell_conditional,
                style_data_conditional=_style_data_conditional,
                style_table=_style_table,
            )
        ],
        style={
            "display": "flex",
            "flex-direction": "column",
            # "justify-content": "center",
            "align-items": "center",
            "height": "100vh",
        },
    )


def get_Graph(fig):
    if fig is None:
        return None

    return dcc.Graph(
        id="graph1",
        figure=fig,
        style={
            "margin": 0,
            "height": "100vh",
        },
    )


datasets = get_datasets(root_folder)


def get_dr(root, dataset):
    ds = str(Path(root) / dataset)
    return datasets[ds]


app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# app.config.suppress_callback_exceptions = True

sidebar_style = {
    "top": 0,
    "left": 0,
    "bottom": 0,
    "padding": "1vw",
    "padding-top": "2vw",
    "margin": "0px",
    "flex": 1,
    "overflow": "scroll",
    "height": "100vh",
}

content_style = {
    "flex": 5,
    "maxWidth": "84vw",
}


def parse_href(href: str):
    parse_result = urlparse(href)
    params = parse_qsl(parse_result.query)
    return dict(params)
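
# Usage sketch (hypothetical URL):
#   parse_href("http://localhost:8050/?dataset=imdb&metric=acc")
#   ->  {"dataset": "imdb", "metric": "acc"}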


def get_sidebar():
    return [
        html.H4("Parameters:", style={"margin-bottom": "1vw"}),
        dbc.Select(
            # options=list(datasets.keys()),
            # value=list(datasets.keys())[0],
            id="dataset",
        ),
        dbc.Select(
            id="metric",
            style={"margin-top": "1vh"},
        ),
        html.Div(
            [
                dbc.RadioItems(
                    id="view",
                    class_name="btn-group mt-3",
                    input_class_name="btn-check",
                    label_class_name="btn btn-outline-primary",
                    label_checked_class_name="active",
                ),
                dbc.RadioItems(
                    id="mode",
                    class_name="btn-group mt-3",
                    input_class_name="btn-check",
                    label_class_name="btn btn-outline-primary",
                    label_checked_class_name="active",
                ),
            ],
            className="radio-group-v d-flex justify-content-around",
        ),
        html.Div(
            [
                dbc.Checklist(
                    id="estimators",
                    className="btn-group mt-3",
                    inputClassName="btn-check",
                    labelClassName="btn btn-outline-primary",
                    labelCheckedClassName="active",
                ),
            ],
            className="radio-group-wide",
        ),
    ]


app.layout = html.Div(
    [
        # Re-scan the output folder every 10 minutes.
        dcc.Interval(id="reload", interval=10 * 60 * 1000),
        dcc.Location(id="url", refresh=False),
        dcc.Store(id="root", storage_type="session", data=root_folder),
        html.Div(
            [
                html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
                html.Div(id="app_content", style=content_style),
            ],
            id="page_layout",
            style={"display": "flex", "flexDirection": "row"},
        ),
    ]
)

server = app.server


def apply_param(href, triggered_id, id, curr):
    """Return the value for control `id`: the URL query parameter when the
    update was triggered by navigation, the current widget value otherwise."""
    match triggered_id:
        case "url":
            params = parse_href(href)
            return params.get(id, None)
        case _:
            return curr


@callback(
    Output("dataset", "value"),
    Output("dataset", "options"),
    Output("root", "data"),
    Input("url", "href"),
    Input("reload", "n_intervals"),
    State("dataset", "value"),
)
def update_dataset(href, n_intervals, dataset):
    if ctx.triggered_id == "reload":
        global datasets
        datasets = get_datasets(root_folder)

    params = parse_href(href)
    req_dataset = params.get("dataset", None)
    root = params.get("root", "")

    def filter_datasets(root, available_datasets):
        return [
            str(Path(d).relative_to(root))
            for d in available_datasets
            if d.startswith(root)
        ]

    available_datasets = filter_datasets(root, datasets.keys())
    new_dataset = (
        req_dataset if req_dataset in available_datasets else available_datasets[0]
    )
    return new_dataset, available_datasets, root
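
# The dcc.Interval in the layout fires this callback every 10 minutes, so
# reports written to the output folder while the app is running show up
# without a restart.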


@callback(
    Output("metric", "options"),
    Output("metric", "value"),
    Input("url", "href"),
    Input("dataset", "value"),
    State("metric", "value"),
    State("root", "data"),
)
def update_metrics(href, dataset, curr_metric, root):
    dr = get_dr(root, dataset)
    old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
    valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
    new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
    return valid_metrics, new_metric


@callback(
    Output("estimators", "options"),
    Output("estimators", "value"),
    Input("url", "href"),
    Input("dataset", "value"),
    Input("metric", "value"),
    State("estimators", "value"),
    State("root", "data"),
)
def update_estimators(href, dataset, metric, curr_estimators, root):
    dr = get_dr(root, dataset)
    old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
    if isinstance(old_estimators, str):
        try:
            old_estimators = json.loads(old_estimators)
        except JSONDecodeError:
            old_estimators = []

    valid_estimators: np.ndarray = dr.data(metric=metric).columns.unique(0).to_numpy()
    new_estimators = valid_estimators[
        np.isin(valid_estimators, old_estimators)
    ].tolist()
    valid_estimators = CE.name.sort(valid_estimators.tolist())
    return valid_estimators, new_estimators
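
# The "estimators" value round-trips through the URL as a JSON array, e.g. a
# query string like ?estimators=%5B%22sld%22%5D decodes to ["sld"] (estimator
# name hypothetical), hence the json.loads above.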


@callback(
    Output("view", "options"),
    Output("view", "value"),
    Input("url", "href"),
    Input("dataset", "value"),
    State("view", "value"),
    State("root", "data"),
)
def update_view(href, dataset, curr_view, root):
    dr = get_dr(root, dataset)
    old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
    valid_views = ["avg"] + [_get_prev_str(cr.train_prev) for cr in dr.crs]
    new_view = old_view if old_view in valid_views else valid_views[0]
    return valid_views, new_view


@callback(
    Output("mode", "options"),
    Output("mode", "value"),
    Input("url", "href"),
    Input("view", "value"),
    State("mode", "value"),
)
def update_mode(href, view, curr_mode):
    old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
    valid_modes = valid_plot_modes[view]
    new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
    return valid_modes, new_mode


@callback(
    Output("app_content", "children"),
    Output("url", "search"),
    Input("dataset", "value"),
    Input("metric", "value"),
    Input("estimators", "value"),
    Input("view", "value"),
    Input("mode", "value"),
    State("root", "data"),
)
def update_content(dataset, metric, estimators, view, mode, root):
    search = urlencode(
        dict(
            dataset=dataset,
            metric=metric,
            estimators=json.dumps(estimators),
            view=view,
            mode=mode,
            root=root,
        ),
        quote_via=quote,
    )
    dr = get_dr(root, dataset)
    match mode:
        case m if m.endswith("table"):
            df = get_table(
                dr=dr,
                metric=metric,
                estimators=estimators,
                view=view,
                mode=mode,
            )
            dt = get_DataTable(df, mode)
            app_content = [] if dt is None else [dt]
        case _:
            fig = get_fig(
                dr=dr,
                metric=metric,
                estimators=estimators,
                view=view,
                mode=mode,
            )
            g = get_Graph(fig)
            app_content = [] if g is None else [g]

    return app_content, f"?{search}"


def run():
    app.run(debug=True)


if __name__ == "__main__":
    run()
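
# To serve the dashboard locally (assuming an "output" folder of pickled
# DatasetReports in the working directory):
#   python qcdash/app.py        # Dash dev server, default http://127.0.0.1:8050
# For production, a WSGI server can be pointed at the module-level `server`
# object, e.g. (hypothetical command):
#   gunicorn qcdash.app:server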