dash base implementation completed

This commit is contained in:
Lorenzo Volpi 2023-11-29 03:56:17 +01:00
parent c670f48b5b
commit 020728ed5d
4 changed files with 684 additions and 30 deletions

396
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -18,6 +18,7 @@ abstention = "^0.1.3.1"
main = "quacc.main:main"
run = "run:run"
panel = "qcpanel.run:run"
dash = "qcdash.app:run"
sync_up = "remote:sync_code"
sync_down = "remote:sync_output"
merge_data = "merge_data:run"
@ -48,6 +49,8 @@ ipympl = "^0.9.3"
ipykernel = "^6.26.0"
ipywidgets-bokeh = "^1.5.0"
pandas-stubs = "^2.1.1.230928"
dash = "^2.14.1"
dash-bootstrap-components = "^1.5.0"
[tool.pytest.ini_options]
addopts = "--cov=quacc --capture=tee-sys"

View File

@ -1,43 +1,226 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import List
import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, dash_table, dcc, html
from dash import Dash, Input, Output, State, callback, dash_table, dcc, html
from quacc.evaluation.report import DatasetReport
from quacc import plot
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
backend = plot.get_backend("plotly")
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes
def get_fig(data: pd.DataFrame):
fig = go.Figure()
xs = data.index.to_numpy()
for col in data.columns.unique(0):
_line = go.Scatter(x=xs, y=data.loc[:, col], mode="lines+markers", name=col)
fig.add_trace(_line)
def get_datasets(root: str | Path) -> List[DatasetReport]:
def explore_datasets(root: str | Path) -> List[Path]:
if isinstance(root, str):
root = Path(root)
fig.update_layout(xaxis_title="test_prevalence", yaxis_title="acc. error")
if root.name == "plot":
return []
return fig
if not root.exists():
return []
dr_paths = []
for f in os.listdir(root):
if (root / f).is_dir():
dr_paths += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
dr_paths.append(root / f)
return dr_paths
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
return {str(drp.parent): DatasetReport.unpickle(drp) for drp in dr_paths}
def app_instance():
dr: DatasetReport = DatasetReport.unpickle(Path("output/debug/imdb/imdb.pickle"))
data = dr.data(metric="acc").groupby(level=1).mean()
app = Dash(__name__)
app.layout = html.Div(
[
# html.Div(children="Hello World"),
# dash_table.DataTable(data=df.to_dict("records")),
dcc.Graph(figure=get_fig(data), style={"height": "95vh"}),
]
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
estimators = CE.name[estimators]
match (view, mode):
case ("avg", _):
return dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="plotly",
save_fig=False,
backend=backend,
)
return app
case (_, _):
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
return cr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="plotly",
save_fig=False,
backend=backend,
)
datasets = get_datasets("output")
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
sidebar_style = {
"top": 0,
"left": 0,
"bottom": 0,
"padding": "1vw",
"padding-top": "2vw",
"margin": "0px",
"flex": 2,
}
content_style = {
# "margin-left": "18vw",
"flex": 9,
}
sidebar = html.Div(
children=[
html.H4("Parameters:", style={"margin-bottom": "1vw"}),
dbc.Select(
options=list(datasets.keys()),
value=list(datasets.keys())[0],
id="dataset",
),
dbc.Select(
# clearable=False,
# searchable=False,
id="metric",
style={"margin-top": "1vh"},
),
html.Div(
[
dbc.RadioItems(
id="view",
class_name="btn-group mt-3",
input_class_name="btn-check",
label_class_name="btn btn-outline-primary",
label_checked_class_name="active",
),
dbc.RadioItems(
id="mode",
class_name="btn-group mt-3",
input_class_name="btn-check",
label_class_name="btn btn-outline-primary",
label_checked_class_name="active",
),
],
className="radio-group-v d-flex justify-content-around",
),
html.Div(
[
dbc.Checklist(
id="estimators",
className="btn-group mt-3",
inputClassName="btn-check",
labelClassName="btn btn-outline-primary",
labelCheckedClassName="active",
),
],
className="radio-group-wide",
),
],
style=sidebar_style,
id="app-sidebar",
)
content = html.Div(
children=[
dcc.Graph(
style={"margin": 0, "height": "100vh"},
id="graph1",
),
],
style=content_style,
)
app.layout = html.Div(
children=[sidebar, content],
style={"display": "flex", "flexDirection": "row"},
)
@callback(
Output("metric", "options"),
Output("metric", "value"),
Input("dataset", "value"),
State("metric", "value"),
)
def update_metrics(dataset, old_metric):
dr = datasets[dataset]
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
return valid_metrics, new_metric
@callback(
Output("estimators", "options"),
Output("estimators", "value"),
Input("dataset", "value"),
Input("metric", "value"),
State("estimators", "value"),
)
def update_estimators(dataset, metric, old_estimators):
dr = datasets[dataset]
valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
new_estimators = valid_estimators[
np.isin(valid_estimators, old_estimators)
].tolist()
return valid_estimators, new_estimators
@callback(
Output("view", "options"),
Output("view", "value"),
Input("dataset", "value"),
State("view", "value"),
)
def update_view(dataset, old_view):
dr = datasets[dataset]
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
new_view = old_view if old_view in valid_views else valid_views[0]
return valid_views, new_view
@callback(
Output("mode", "options"),
Output("mode", "value"),
Input("view", "value"),
State("mode", "value"),
)
def update_mode(view, old_mode):
valid_modes = [m for m in valid_plot_modes[view] if not m.endswith("table")]
new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
return valid_modes, new_mode
@callback(
Output("graph1", "figure"),
Input("dataset", "value"),
Input("metric", "value"),
Input("estimators", "value"),
Input("view", "value"),
Input("mode", "value"),
)
def update_graph(dataset, metric, estimators, view, mode):
dr = datasets[dataset]
return get_fig(dr=dr, metric=metric, estimators=estimators, view=view, mode=mode)
def run():
app = app_instance()
app.run(debug=True)

View File

@ -0,0 +1,86 @@
/* restyle radio items */
.radio-group .form-check {
padding-left: 0;
}
.radio-group .btn-group > .form-check:not(:last-child) > .btn {
border-top-right-radius: 0;
border-bottom-right-radius: 0;
}
.radio-group .btn-group > .form-check:not(:first-child) > .btn {
border-top-left-radius: 0;
border-bottom-left-radius: 0;
margin-left: -1px;
}
.radio-group-v{
padding: 0 10px;
}
.radio-group-v .form-check {
padding-top: 0;
padding-left: 2px;
}
.radio-group-v .btn-group {
flex-direction: column;
justify-content: center;
align-items: stretch;
flex-grow: 1;
}
.radio-group-v > .btn-group:first-child {
flex-grow: 0;
}
.radio-group-v > .btn-group:last-child {
margin-left: 20px;
}
.radio-group-v .btn-group > .form-check:not(:last-child) > .btn {
border-bottom-right-radius: 0;
border-bottom-left-radius: 0;
}
.radio-group-v .btn-group > .form-check:not(:first-child) > .btn {
border-top-right-radius: 0;
border-top-left-radius: 0;
margin-top: -3px;
}
.radio-group-v .btn-group .btn{
width: 100%;
}
.radio-group-wide .form-check {
padding-left: 0px;
}
.radio-group-wide .btn-group{
flex-wrap: wrap;
}
.radio-group-wide .btn-group .form-check{
flex: 1;
margin-top: -3px;
margin-left: -1px;
}
.radio-group-wide .btn-group .form-check .btn{
width: 100%;
border-radius: 0;
}
.radio-group-wide .btn-group .form-check:first-child > .btn{
border-top-left-radius: 10px;
}
.radio-group-wide .btn-group .form-check:last-child > .btn{
border-bottom-right-radius: 10px;
}
div#app-sidebar{
border-right: 2px solid var(--primary);
}