# QuAcc/qcdash/app.py
# Dash dashboard for browsing the pickled DatasetReport results found under "output".

import os
from collections import defaultdict
from pathlib import Path
from typing import List
import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, Input, Output, State, callback, dash_table, dcc, html
from quacc import plot
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
# Plotting backend handed to every get_plots(...) call made by the dashboard.
backend = plot.get_backend("plotly")
# Plot modes offered per view: the "avg" view gets the dataset-level modes,
# every other view (a single CompReport) falls back to CompReport's defaults.
valid_plot_modes = defaultdict(
    lambda: CompReport._default_modes,
    avg=DatasetReport._default_dr_modes,
)
def get_datasets(root: str | Path) -> List[DatasetReport]:
def explore_datasets(root: str | Path) -> List[Path]:
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
dr_paths = []
for f in os.listdir(root):
if (root / f).is_dir():
dr_paths += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
dr_paths.append(root / f)
return dr_paths
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
return {str(drp.parent): DatasetReport.unpickle(drp) for drp in dr_paths}
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
    """Build the figure for *dr* according to the current selectors.

    ``view == "avg"`` plots the dataset-level report itself. Any other view
    selects the entry of ``dr.crs`` whose label
    ``str(round(c.train_prev[1] * 100))`` equals *view*, and plots that
    report instead. A label that matches no report raises ValueError.
    """
    selected = CE.name[estimators]
    if view == "avg":
        report = dr
    else:
        labels = [str(round(c.train_prev[1] * 100)) for c in dr.crs]
        report = dr.crs[labels.index(view)]
    return report.get_plots(
        mode=mode,
        metric=metric,
        estimators=selected,
        conf="plotly",
        save_fig=False,
        backend=backend,
    )
# Reports are discovered once, at import time; "output" is resolved
# relative to the current working directory.
datasets = get_datasets("output")
if not datasets:
    # Fail fast with a clear message: the dataset selector and every
    # callback assume at least one report is available.
    raise RuntimeError("no dataset reports found under 'output'")
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# Inline styles for the two flex children of the root layout. Dash style
# dicts use camelCase CSS property names (e.g. "paddingTop").
# The sidebar takes flex 2 and the graph area flex 9.
sidebar_style = {
    "top": 0,
    "left": 0,
    "bottom": 0,
    "padding": "1vw",
    "paddingTop": "2vw",
    "margin": "0px",
    "flex": 2,
}
content_style = {
    "flex": 9,
}
# Left-hand parameter panel: dataset and metric selectors, view/mode radio
# button groups, and the estimator checklist. The options and values of
# every control except "dataset" are filled in by the callbacks.
sidebar = html.Div(
    children=[
        html.H4("Parameters:", style={"marginBottom": "1vw"}),
        dbc.Select(
            options=list(datasets),
            value=list(datasets)[0],
            id="dataset",
        ),
        dbc.Select(
            id="metric",
            style={"marginTop": "1vh"},
        ),
        html.Div(
            [
                dbc.RadioItems(
                    id="view",
                    class_name="btn-group mt-3",
                    input_class_name="btn-check",
                    label_class_name="btn btn-outline-primary",
                    label_checked_class_name="active",
                ),
                dbc.RadioItems(
                    id="mode",
                    class_name="btn-group mt-3",
                    input_class_name="btn-check",
                    label_class_name="btn btn-outline-primary",
                    label_checked_class_name="active",
                ),
            ],
            className="radio-group-v d-flex justify-content-around",
        ),
        html.Div(
            [
                # Same snake_case dbc props as the RadioItems above; the
                # camelCase spellings are the deprecated dbc aliases.
                dbc.Checklist(
                    id="estimators",
                    class_name="btn-group mt-3",
                    input_class_name="btn-check",
                    label_class_name="btn btn-outline-primary",
                    label_checked_class_name="active",
                ),
            ],
            className="radio-group-wide",
        ),
    ],
    style=sidebar_style,
    id="app-sidebar",
)
# Main area: a single graph ("graph1") sized to the full viewport height;
# its figure is supplied by the update_graph callback.
content = html.Div(
    [dcc.Graph(id="graph1", style={"margin": 0, "height": "100vh"})],
    style=content_style,
)
# Root layout: sidebar and graph area side by side in a flex row.
app.layout = html.Div(
    [sidebar, content],
    style=dict(display="flex", flexDirection="row"),
)
@callback(
    Output("metric", "options"),
    Output("metric", "value"),
    Input("dataset", "value"),
    State("metric", "value"),
)
def update_metrics(dataset, old_metric):
    """Refresh the metric selector when the dataset changes.

    Level-0 columns of the report's data ending in ``_score`` are hidden.
    The current metric is kept if the new dataset still offers it;
    otherwise the first available metric is selected.
    """
    metric_names = []
    for name in datasets[dataset].data().columns.unique(0):
        if not name.endswith("_score"):
            metric_names.append(name)
    if old_metric in metric_names:
        return metric_names, old_metric
    return metric_names, metric_names[0]
@callback(
    Output("estimators", "options"),
    Output("estimators", "value"),
    Input("dataset", "value"),
    Input("metric", "value"),
    State("estimators", "value"),
)
def update_estimators(dataset, metric, old_estimators):
    """Refresh the estimator checklist for the chosen dataset and metric.

    The options are the level-0 columns of the report's data for *metric*.
    Previously ticked estimators stay ticked if still available, in option
    order; the rest are dropped. Both outputs are plain Python lists, like
    the other selector callbacks return.
    """
    dr = datasets[dataset]
    valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
    # The checklist is created without a value, so this may be None on the
    # first call; treat that as "nothing ticked".
    previous = old_estimators if old_estimators is not None else []
    keep = np.isin(valid_estimators, previous)
    return valid_estimators.tolist(), valid_estimators[keep].tolist()
@callback(
    Output("view", "options"),
    Output("view", "value"),
    Input("dataset", "value"),
    State("view", "value"),
)
def update_view(dataset, old_view):
    """Refresh the view selector when the dataset changes.

    The options are "avg" followed by one label per CompReport in
    ``dr.crs``, computed as ``str(round(cr.train_prev[1] * 100))``. These
    labels must match the ones get_fig uses to find the report again. The
    current view is kept if still offered; otherwise "avg" is selected.
    """
    views = ["avg"]
    views.extend(str(round(cr.train_prev[1] * 100)) for cr in datasets[dataset].crs)
    selected = old_view if old_view in views else views[0]
    return views, selected
@callback(
    Output("mode", "options"),
    Output("mode", "value"),
    Input("view", "value"),
    State("mode", "value"),
)
def update_mode(view, old_mode):
    """Refresh the plot-mode selector when the view changes.

    Modes whose name ends in "table" are not offered. The current mode is
    kept if still valid; otherwise the first valid mode is selected.
    """
    modes = list(filter(lambda m: not m.endswith("table"), valid_plot_modes[view]))
    return modes, (old_mode if old_mode in modes else modes[0])
@callback(
    Output("graph1", "figure"),
    Input("dataset", "value"),
    Input("metric", "value"),
    Input("estimators", "value"),
    Input("view", "value"),
    Input("mode", "value"),
)
def update_graph(dataset, metric, estimators, view, mode):
    """Redraw the main graph whenever any of the five selectors changes."""
    return get_fig(
        dr=datasets[dataset],
        metric=metric,
        estimators=estimators,
        view=view,
        mode=mode,
    )
def run(*, debug: bool = True, **kwargs):
    """Start the Dash server for this dashboard.

    Args:
        debug: Passed to ``app.run``. Defaults to True.
        **kwargs: Extra options forwarded to ``app.run`` unchanged
            (for example ``host`` or ``port``).
    """
    app.run(debug=debug, **kwargs)


if __name__ == "__main__":
    run()