QuAcc/qcdash/app.py

import json
import os
from collections import defaultdict
from json import JSONDecodeError
from pathlib import Path
from typing import List
from urllib.parse import parse_qsl, quote, urlencode, urlparse

import dash_bootstrap_components as dbc
import numpy as np
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
from dash.dash_table.Format import Format, Scheme

from quacc import plot
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import wilcoxon

valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes

root_folder = "output"


def _get_prev_str(prev: np.ndarray):
    return str(tuple(np.around(prev, decimals=2)))
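
# Illustrative example (assuming a binary task; tuple repr per numpy < 2):
#   _get_prev_str(np.array([0.25, 0.75])) -> "(0.25, 0.75)"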


def get_datasets(root: str | Path) -> dict[str, DatasetReport]:
    def load_dataset(dataset):
        dataset = Path(dataset)
        return DatasetReport.unpickle(dataset)

    def explore_datasets(root: str | Path) -> List[Path]:
        if isinstance(root, str):
            root = Path(root)

        # "plot" subfolders hold rendered figures, not reports
        if root.name == "plot":
            return []

        if not root.exists():
            return []

        dr_paths = []
        for f in os.listdir(root):
            if (root / f).is_dir():
                dr_paths += explore_datasets(root / f)
            elif f == f"{root.name}.pickle":
                dr_paths.append(root / f)

        return dr_paths

    # deepest paths first, then lexicographic
    dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
    return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
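
# Expected on-disk layout (illustrative; dataset name is hypothetical): each
# dataset folder contains a pickle named after the folder itself, so
# "output/imdb/imdb.pickle" is loaded and returned under the key "output/imdb".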


def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
    _backend = backend or plot.get_backend("plotly")
    estimators = CE.name[estimators]
    match (view, mode):
        case ("avg", _):
            return dr.get_plots(
                mode=mode,
                metric=metric,
                estimators=estimators,
                conf="plotly",
                save_fig=False,
                backend=_backend,
            )
        case (_, _):
            # select the CompReport whose (rounded) training prevalence
            # matches the requested view
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            return cr.get_plots(
                mode=mode,
                metric=metric,
                estimators=estimators,
                conf="plotly",
                save_fig=False,
                backend=_backend,
            )


def get_table(dr: DatasetReport, metric, estimators, view, mode):
    estimators = CE.name[estimators]
    match (view, mode):
        case ("avg", "train_table"):
            # return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
            return dr.train_table(metric=metric, estimators=estimators)
        case ("avg", "train_std_table"):
            return dr.train_std_table(metric=metric, estimators=estimators)
        case ("avg", "test_table"):
            # return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            return dr.test_table(metric=metric, estimators=estimators)
        case ("avg", "shift_table"):
            # return (
            #     dr.shift_data(metric=metric, estimators=estimators)
            #     .groupby(level=0)
            #     .mean()
            # )
            return dr.shift_table(metric=metric, estimators=estimators)
        case ("avg", "stats_table"):
            return wilcoxon(dr, metric=metric, estimators=estimators)
        case (_, "train_table"):
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            # return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
            return cr.train_table(metric=metric, estimators=estimators)
        case (_, "shift_table"):
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            # return (
            #     cr.shift_data(metric=metric, estimators=estimators)
            #     .groupby(level=0)
            #     .mean()
            # )
            return cr.shift_table(metric=metric, estimators=estimators)
        case (_, "stats_table"):
            cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
            return wilcoxon(cr, metric=metric, estimators=estimators)
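
# Note: only modes ending in "table" are handled here; every other mode is
# treated as a plot mode and dispatched to get_fig() by update_content() below.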


def get_DataTable(df, mode):
    _primary = "#0d6efd"
    if df.empty:
        return None

    _index_name = dict(
        train_table="test prev.",
        train_std_table="train prev.",
        test_table="train prev.",
        shift_table="shift",
        stats_table="method",
    )
    df = df.reset_index()

    if mode == "train_std_table":
        columns_format = Format()
        df_columns = np.concatenate([["index"], df.columns.unique(1)[1:]])
        # render mean and standard deviation together as "avg~std" strings
        data = [
            dict(
                index="(" + ", ".join([f"{v:.2f}" for v in idx]) + ")"
                if isinstance(idx, tuple | list | np.ndarray)
                else str(idx)
            )
            | {
                k: f"{df.loc[i, ('avg', k)]:.4f}~{df.loc[i, ('std', k)]:.3f}"
                for k in df.columns.unique(1)[1:]
            }
            for i, idx in zip(df.index, df.loc[:, ("index", "")])
        ]
    else:
        columns_format = Format(precision=6, scheme=Scheme.exponent, nully="nan")
        df_columns = df.columns
        data = df.to_dict("records")

    columns = {
        c: dict(
            id=c,
            name=_index_name[mode] if c == "index" else c,
            type="numeric",
            format=columns_format,
        )
        for c in df_columns
    }
    columns["index"]["format"] = Format()
    columns = list(columns.values())

    # pretty-print prevalence tuples and floats in the index column
    for d in data:
        if isinstance(d["index"], tuple | list | np.ndarray):
            d["index"] = "(" + ", ".join([f"{v:.2f}" for v in d["index"]]) + ")"
        elif isinstance(d["index"], float):
            d["index"] = f"{d['index']:.2f}"

    _style_cell = {
        "padding": "0 12px",
        "border": "0",
        "border-bottom": f"1px solid {_primary}",
    }
    _style_cell_conditional = [
        {
            "if": {"column_id": "index"},
            "text_align": "center",
        },
    ]
    _style_data_conditional = []
    if mode != "stats_table":
        # emphasize the last row (bold index, primary background)
        _style_data_conditional += [
            {
                "if": {"column_id": "index", "row_index": len(data) - 1},
                "font_weight": "bold",
            },
            {
                "if": {"row_index": len(data) - 1},
                "background_color": _primary,
                "color": "white",
            },
        ]
    _style_table = {
        "margin": "6vh 15px",
        "padding": "15px",
        "maxWidth": "80vw",
        "overflowX": "auto",
        "border": f"0px solid {_primary}",
        "border-radius": "6px",
    }

    return html.Div(
        [
            dash_table.DataTable(
                data=data,
                columns=columns,
                id="table1",
                style_cell=_style_cell,
                style_cell_conditional=_style_cell_conditional,
                style_data_conditional=_style_data_conditional,
                style_table=_style_table,
            )
        ],
        style={
            "display": "flex",
            "flex-direction": "column",
            # "justify-content": "center",
            "align-items": "center",
            "height": "100vh",
        },
    )


def get_Graph(fig):
    if fig is None:
        return None
    return dcc.Graph(
        id="graph1",
        figure=fig,
        style={
            "margin": 0,
            "height": "100vh",
        },
    )


datasets = get_datasets(root_folder)


def get_dr(root, dataset):
    ds = str(Path(root) / dataset)
    return datasets[ds]
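
# Illustrative lookup (dataset name is hypothetical): get_dr("output", "imdb")
# returns the DatasetReport cached under the key "output/imdb".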


app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# app.config.suppress_callback_exceptions = True

sidebar_style = {
    "top": 0,
    "left": 0,
    "bottom": 0,
    "padding": "1vw",
    "padding-top": "2vw",
    "margin": "0px",
    "flex": 1,
    "overflow": "scroll",
    "height": "100vh",
}
content_style = {
    "flex": 5,
    "maxWidth": "84vw",
}


def parse_href(href: str):
    parse_result = urlparse(href)
    params = parse_qsl(parse_result.query)
    return dict(params)
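
# Illustrative example (parameter values are hypothetical):
#   parse_href("http://localhost:8050/?dataset=imdb&metric=acc")
#   -> {"dataset": "imdb", "metric": "acc"}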


def get_sidebar():
    return [
        html.H4("Parameters:", style={"margin-bottom": "1vw"}),
        dbc.Select(
            # options=list(datasets.keys()),
            # value=list(datasets.keys())[0],
            id="dataset",
        ),
        dbc.Select(
            id="metric",
            style={"margin-top": "1vh"},
        ),
        html.Div(
            [
                dbc.RadioItems(
                    id="view",
                    class_name="btn-group mt-3",
                    input_class_name="btn-check",
                    label_class_name="btn btn-outline-primary",
                    label_checked_class_name="active",
                ),
                dbc.RadioItems(
                    id="mode",
                    class_name="btn-group mt-3",
                    input_class_name="btn-check",
                    label_class_name="btn btn-outline-primary",
                    label_checked_class_name="active",
                ),
            ],
            className="radio-group-v d-flex justify-content-around",
        ),
        html.Div(
            [
                dbc.Checklist(
                    id="estimators",
                    className="btn-group mt-3",
                    inputClassName="btn-check",
                    labelClassName="btn btn-outline-primary",
                    labelCheckedClassName="active",
                ),
            ],
            className="radio-group-wide",
        ),
    ]


app.layout = html.Div(
    [
        # fires every 10 minutes to trigger a re-scan of the reports on disk
        dcc.Interval(id="reload", interval=10 * 60 * 1000),
        dcc.Location(id="url", refresh=False),
        dcc.Store(id="root", storage_type="session", data=root_folder),
        html.Div(
            [
                html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
                html.Div(id="app_content", style=content_style),
            ],
            id="page_layout",
            style={"display": "flex", "flexDirection": "row"},
        ),
    ]
)

server = app.server


def apply_param(href, triggered_id, id, curr):
    # when the URL triggered the callback, read the value from the query
    # string; otherwise keep the control's current value
    match triggered_id:
        case "url":
            params = parse_href(href)
            return params.get(id, None)
        case _:
            return curr
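
# Illustrative example: with href ".../?metric=acc",
# apply_param(href, "url", "metric", "f1") returns "acc";
# with any other trigger it returns the current value "f1".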


@callback(
    Output("dataset", "value"),
    Output("dataset", "options"),
    Output("root", "data"),
    Input("url", "href"),
    Input("reload", "n_intervals"),
    State("dataset", "value"),
)
def update_dataset(href, n_intervals, dataset):
    if ctx.triggered_id == "reload":
        # re-scan the output folder and refresh the global cache
        global datasets
        datasets = get_datasets(root_folder)

    params = parse_href(href)
    req_dataset = params.get("dataset", None)
    root = params.get("root", "")

    def filter_datasets(root, available_datasets):
        return [
            str(Path(d).relative_to(root))
            for d in available_datasets
            if d.startswith(root)
        ]

    available_datasets = filter_datasets(root, datasets.keys())
    new_dataset = (
        req_dataset if req_dataset in available_datasets else available_datasets[0]
    )
    return new_dataset, available_datasets, root
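
# Illustrative query (paths are hypothetical): visiting /?root=output/main
# restricts the dropdown to reports under "output/main", listed relative to
# that root (e.g. "output/main/imdb" appears as "imdb").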


@callback(
    Output("metric", "options"),
    Output("metric", "value"),
    Input("url", "href"),
    Input("dataset", "value"),
    State("metric", "value"),
    State("root", "data"),
)
def update_metrics(href, dataset, curr_metric, root):
    dr = get_dr(root, dataset)
    old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
    valid_metrics = [
        m for m in dr.data().columns.unique(0) if not m.endswith("_score")
    ]
    new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
    return valid_metrics, new_metric


@callback(
    Output("estimators", "options"),
    Output("estimators", "value"),
    Input("url", "href"),
    Input("dataset", "value"),
    Input("metric", "value"),
    State("estimators", "value"),
    State("root", "data"),
)
def update_estimators(href, dataset, metric, curr_estimators, root):
    dr = get_dr(root, dataset)
    old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
    if isinstance(old_estimators, str):
        # estimator lists arrive JSON-encoded when parsed from the URL
        try:
            old_estimators = json.loads(old_estimators)
        except JSONDecodeError:
            old_estimators = []
    valid_estimators: np.ndarray = dr.data(metric=metric).columns.unique(0).to_numpy()
    new_estimators = valid_estimators[
        np.isin(valid_estimators, old_estimators)
    ].tolist()
    valid_estimators = CE.name.sort(valid_estimators.tolist())
    return valid_estimators, new_estimators


@callback(
    Output("view", "options"),
    Output("view", "value"),
    Input("url", "href"),
    Input("dataset", "value"),
    State("view", "value"),
    State("root", "data"),
)
def update_view(href, dataset, curr_view, root):
    dr = get_dr(root, dataset)
    old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
    valid_views = ["avg"] + [_get_prev_str(cr.train_prev) for cr in dr.crs]
    new_view = old_view if old_view in valid_views else valid_views[0]
    return valid_views, new_view


@callback(
    Output("mode", "options"),
    Output("mode", "value"),
    Input("url", "href"),
    Input("view", "value"),
    State("mode", "value"),
)
def update_mode(href, view, curr_mode):
    old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
    valid_modes = valid_plot_modes[view]
    new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
    return valid_modes, new_mode


@callback(
    Output("app_content", "children"),
    Output("url", "search"),
    Input("dataset", "value"),
    Input("metric", "value"),
    Input("estimators", "value"),
    Input("view", "value"),
    Input("mode", "value"),
    State("root", "data"),
)
def update_content(dataset, metric, estimators, view, mode, root):
    search = urlencode(
        dict(
            dataset=dataset,
            metric=metric,
            estimators=json.dumps(estimators),
            view=view,
            mode=mode,
            root=root,
        ),
        quote_via=quote,
    )
    dr = get_dr(root, dataset)
    match mode:
        case m if m.endswith("table"):
            df = get_table(
                dr=dr,
                metric=metric,
                estimators=estimators,
                view=view,
                mode=mode,
            )
            dt = get_DataTable(df, mode)
            app_content = [] if dt is None else [dt]
        case _:
            fig = get_fig(
                dr=dr,
                metric=metric,
                estimators=estimators,
                view=view,
                mode=mode,
            )
            g = get_Graph(fig)
            app_content = [] if g is None else [g]
    return app_content, f"?{search}"
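
# Illustrative round-trip (values are hypothetical): selecting dataset "imdb",
# metric "acc", estimators ["sld"], view "avg", mode "shift_table" writes the
# URL-quoted search string
#   ?dataset=imdb&metric=acc&estimators=%5B%22sld%22%5D&view=avg&mode=shift_table&root=output
# which the callbacks above parse back into the controls on page load.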


def run():
    app.run(debug=True)


if __name__ == "__main__":
    run()
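
# Minimal deployment sketch (assuming the module path qcdash.app): since
# `server = app.server` above exposes the underlying Flask app, the dashboard
# can also be served by a WSGI runner instead of the debug server, e.g.
#   gunicorn qcdash.app:server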