dash updated

This commit is contained in:
Lorenzo Volpi 2023-12-21 16:47:07 +01:00
parent e3b42e0648
commit a5c54a93b7
1 changed files with 85 additions and 37 deletions

View File

@ -2,6 +2,7 @@ import json
import os
from collections import defaultdict
from json import JSONDecodeError
from operator import index
from pathlib import Path
from typing import List
from urllib.parse import parse_qsl, quote, urlencode, urlparse
@ -9,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlencode, urlparse
import dash_bootstrap_components as dbc
import numpy as np
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
from dash.dash_table.Format import Format, Scheme
from dash.dash_table.Format import Align, Format, Scheme
from quacc import plot
from quacc.evaluation.estimators import CE
@ -21,6 +22,10 @@ valid_plot_modes["avg"] = DatasetReport._default_dr_modes
root_folder = "output"
def _get_prev_str(prev: np.ndarray):
return str(tuple(np.around(prev, decimals=2)))
def get_datasets(root: str | Path) -> List[DatasetReport]:
def load_dataset(dataset):
dataset = Path(dataset)
@ -63,7 +68,7 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
backend=_backend,
)
case (_, _):
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
return cr.get_plots(
mode=mode,
metric=metric,
@ -76,53 +81,105 @@ def get_fig(dr: DatasetReport, metric, estimators, view, mode, backend=None):
def get_table(dr: DatasetReport, metric, estimators, view, mode):
estimators = CE.name[estimators]
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
match (view, mode):
case ("avg", "train_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
# return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
return dr.train_table(metric=metric, estimators=estimators)
case ("avg", "test_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
# return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
return dr.test_table(metric=metric, estimators=estimators)
case ("avg", "shift_table"):
return (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
# return (
# dr.shift_data(metric=metric, estimators=estimators)
# .groupby(level=0)
# .mean()
# )
return dr.shift_table(metric=metric, estimators=estimators)
case ("avg", "stats_table"):
return wilcoxon(dr, metric=metric, estimators=estimators)
case (_, "train_table"):
cr = dr.crs[_prevs.index(view)]
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
# return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
return cr.train_table(metric=metric, estimators=estimators)
case (_, "shift_table"):
cr = dr.crs[_prevs.index(view)]
return (
cr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
# return (
# cr.shift_data(metric=metric, estimators=estimators)
# .groupby(level=0)
# .mean()
# )
return cr.shift_table(metric=metric, estimators=estimators)
case (_, "stats_table"):
cr = dr.crs[_prevs.index(view)]
cr = dr.crs[[_get_prev_str(c.train_prev) for c in dr.crs].index(view)]
return wilcoxon(cr, metric=metric, estimators=estimators)
def get_DataTable(df):
def get_DataTable(df, mode):
_primary = "#0d6efd"
if df.empty:
return None
_index_name = dict(
train_table="test prev.",
test_table="train prev.",
shift_table="shift",
stats_table="method",
)
df = df.reset_index()
columns = {
c: dict(
id=c,
name=c,
name=_index_name[mode] if c == "index" else c,
type="numeric",
format=Format(precision=6, scheme=Scheme.exponent, nully="nan"),
)
for c in df.columns
}
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
# columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
columns["index"]["format"] = Format()
columns = list(columns.values())
data = df.to_dict("records")
for d in data:
if isinstance(d["index"], tuple | list | np.ndarray):
d["index"] = "(" + ", ".join([f"{v:.2f}" for v in d["index"]]) + ")"
elif isinstance(d["index"], float):
d["index"] = f"{d['index']:.2f}"
_style_cell = {
"padding": "0 12px",
"border": "0",
"border-bottom": f"1px solid {_primary}",
}
_style_cell_conditional = [
{
"if": {"column_id": "index"},
"text_align": "center",
},
]
_style_data_conditional = []
if mode != "stats_table":
_style_data_conditional += [
{
"if": {"column_id": "index", "row_index": len(data) - 1},
"font_weight": "bold",
},
{
"if": {"row_index": len(data) - 1},
"background_color": "#0d6efd",
"color": "white",
},
]
_style_table = {
"margin": "6vh 15px",
"padding": "15px",
"maxWidth": "80vw",
"overflowX": "auto",
"border": f"0px solid {_primary}",
"border-radius": "6px",
}
return html.Div(
[
@ -130,19 +187,10 @@ def get_DataTable(df):
data=data,
columns=columns,
id="table1",
style_cell={
"padding": "0 12px",
"border": "0",
"border-bottom": f"1px solid {_primary}",
},
style_table={
"margin": "6vh 15px",
"padding": "15px",
"maxWidth": "80vw",
"overflowX": "auto",
"border": f"0px solid {_primary}",
"border-radius": "6px",
},
style_cell=_style_cell,
style_cell_conditional=_style_cell_conditional,
style_data_conditional=_style_data_conditional,
style_table=_style_table,
)
],
style={
@ -361,7 +409,7 @@ def update_estimators(href, dataset, metric, curr_estimators, root):
def update_view(href, dataset, curr_view, root):
dr = get_dr(root, dataset)
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
valid_views = ["avg"] + [_get_prev_str(cr.train_prev) for cr in dr.crs]
new_view = old_view if old_view in valid_views else valid_views[0]
return valid_views, new_view
@ -412,7 +460,7 @@ def update_content(dataset, metric, estimators, view, mode, root):
view=view,
mode=mode,
)
dt = get_DataTable(df)
dt = get_DataTable(df, mode)
app_content = [] if dt is None else [dt]
case _:
fig = get_fig(