dash updated
This commit is contained in:
parent
e3b42e0648
commit
a5c54a93b7
122
qcdash/app.py
122
qcdash/app.py
|
@ -2,6 +2,7 @@ import json
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from operator import index
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from urllib.parse import parse_qsl, quote, urlencode, urlparse
|
||||
|
@ -9,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlencode, urlparse
|
|||
import dash_bootstrap_components as dbc
|
||||
import numpy as np
|
||||
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
|
||||
from dash.dash_table.Format import Format, Scheme
|
||||
from dash.dash_table.Format import Align, Format, Scheme
|
||||
|
||||
from quacc import plot
|
||||
from quacc.evaluation.estimators import CE
|
||||
|
@ -21,6 +22,10 @@ valid_plot_modes["avg"] = DatasetReport._default_dr_modes
|
|||
root_folder = "output"
|
||||
|
||||
|
||||
def _get_prev_str(prev: np.ndarray):
|
||||
return str(tuple(np.around(prev, decimals=2)))
|
||||
|
||||
|
||||
def get_datasets(root: str | Path) -> List[DatasetReport]:
|
||||
def load_dataset(dataset):
|
||||
dataset = Path(dataset)
|
||||
|
@ -63,7 +68,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
|
|||
backend=_backend,
|
||||
)
|
||||
case (_, _):
|
||||
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
|
||||
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
|
||||
return cr.get_plots(
|
||||
mode=mode,
|
||||
metric=metric,
|
||||
|
@ -76,53 +81,105 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
|
|||
|
||||
def get_table(dr: DatasetReport, metric, estimators, view, mode):
|
||||
estimators = CE.name[estimators]
|
||||
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||
match (view, mode):
|
||||
case ("avg", "train_table"):
|
||||
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
|
||||
# return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
|
||||
return dr.train_table(metric=metric, estimators=estimators)
|
||||
case ("avg", "test_table"):
|
||||
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
# return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
return dr.test_table(metric=metric, estimators=estimators)
|
||||
case ("avg", "shift_table"):
|
||||
return (
|
||||
dr.shift_data(metric=metric, estimators=estimators)
|
||||
.groupby(level=0)
|
||||
.mean()
|
||||
)
|
||||
# return (
|
||||
# dr.shift_data(metric=metric, estimators=estimators)
|
||||
# .groupby(level=0)
|
||||
# .mean()
|
||||
# )
|
||||
return dr.shift_table(metric=metric, estimators=estimators)
|
||||
case ("avg", "stats_table"):
|
||||
return wilcoxon(dr, metric=metric, estimators=estimators)
|
||||
case (_, "train_table"):
|
||||
cr = dr.crs[_prevs.index(view)]
|
||||
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
|
||||
# return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
return cr.train_table(metric=metric, estimators=estimators)
|
||||
case (_, "shift_table"):
|
||||
cr = dr.crs[_prevs.index(view)]
|
||||
return (
|
||||
cr.shift_data(metric=metric, estimators=estimators)
|
||||
.groupby(level=0)
|
||||
.mean()
|
||||
)
|
||||
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
|
||||
# return (
|
||||
# cr.shift_data(metric=metric, estimators=estimators)
|
||||
# .groupby(level=0)
|
||||
# .mean()
|
||||
# )
|
||||
return cr.shift_table(metric=metric, estimators=estimators)
|
||||
case (_, "stats_table"):
|
||||
cr = dr.crs[_prevs.index(view)]
|
||||
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
|
||||
return wilcoxon(cr, metric=metric, estimators=estimators)
|
||||
|
||||
|
||||
def get_DataTable(df):
|
||||
def get_DataTable(df, mode):
|
||||
_primary = "#0d6efd"
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
_index_name = dict(
|
||||
train_table="test prev.",
|
||||
test_table="train prev.",
|
||||
shift_table="shift",
|
||||
stats_table="method",
|
||||
)
|
||||
df = df.reset_index()
|
||||
columns = {
|
||||
c: dict(
|
||||
id=c,
|
||||
name=c,
|
||||
name=_index_name[mode] if c == "index" else c,
|
||||
type="numeric",
|
||||
format=Format(precision=6, scheme=Scheme.exponent, nully="nan"),
|
||||
)
|
||||
for c in df.columns
|
||||
}
|
||||
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
|
||||
# columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
|
||||
columns["index"]["format"] = Format()
|
||||
columns = list(columns.values())
|
||||
data = df.to_dict("records")
|
||||
for d in data:
|
||||
if isinstance(d["index"], tuple | list | np.ndarray):
|
||||
d["index"] = "(" + ", ".join([f"{v:.2f}" for v in d["index"]]) + ")"
|
||||
elif isinstance(d["index"], float):
|
||||
d["index"] = f"{d['index']:.2f}"
|
||||
|
||||
_style_cell = {
|
||||
"padding": "0 12px",
|
||||
"border": "0",
|
||||
"border-bottom": f"1px solid {_primary}",
|
||||
}
|
||||
|
||||
_style_cell_conditional = [
|
||||
{
|
||||
"if": {"column_id": "index"},
|
||||
"text_align": "center",
|
||||
},
|
||||
]
|
||||
|
||||
_style_data_conditional = []
|
||||
if mode != "stats_table":
|
||||
_style_data_conditional += [
|
||||
{
|
||||
"if": {"column_id": "index", "row_index": len(data) - 1},
|
||||
"font_weight": "bold",
|
||||
},
|
||||
{
|
||||
"if": {"row_index": len(data) - 1},
|
||||
"background_color": "#0d6efd",
|
||||
"color": "white",
|
||||
},
|
||||
]
|
||||
|
||||
_style_table = {
|
||||
"margin": "6vh 15px",
|
||||
"padding": "15px",
|
||||
"maxWidth": "80vw",
|
||||
"overflowX": "auto",
|
||||
"border": f"0px solid {_primary}",
|
||||
"border-radius": "6px",
|
||||
}
|
||||
|
||||
return html.Div(
|
||||
[
|
||||
|
@ -130,19 +187,10 @@ def get_DataTable(df):
|
|||
data=data,
|
||||
columns=columns,
|
||||
id="table1",
|
||||
style_cell={
|
||||
"padding": "0 12px",
|
||||
"border": "0",
|
||||
"border-bottom": f"1px solid {_primary}",
|
||||
},
|
||||
style_table={
|
||||
"margin": "6vh 15px",
|
||||
"padding": "15px",
|
||||
"maxWidth": "80vw",
|
||||
"overflowX": "auto",
|
||||
"border": f"0px solid {_primary}",
|
||||
"border-radius": "6px",
|
||||
},
|
||||
style_cell=_style_cell,
|
||||
style_cell_conditional=_style_cell_conditional,
|
||||
style_data_conditional=_style_data_conditional,
|
||||
style_table=_style_table,
|
||||
)
|
||||
],
|
||||
style={
|
||||
|
@ -361,7 +409,7 @@ def update_estimators(href, dataset, metric, curr_estimators, root):
|
|||
def update_view(href, dataset, curr_view, root):
|
||||
dr = get_dr(root, dataset)
|
||||
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
|
||||
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||
valid_views = ["avg"] + [_get_prev_str(cr.train_prev) for cr in dr.crs]
|
||||
new_view = old_view if old_view in valid_views else valid_views[0]
|
||||
return valid_views, new_view
|
||||
|
||||
|
@ -412,7 +460,7 @@ def update_content(dataset, metric, estimators, view, mode, root):
|
|||
view=view,
|
||||
mode=mode,
|
||||
)
|
||||
dt = get_DataTable(df)
|
||||
dt = get_DataTable(df, mode)
|
||||
app_content = [] if dt is None else [dt]
|
||||
case _:
|
||||
fig = get_fig(
|
||||
|
|
Loading…
Reference in New Issue