root added

This commit is contained in:
Lorenzo Volpi 2023-12-01 13:22:53 +01:00
parent 3252311382
commit f1a769e585
3 changed files with 41 additions and 21 deletions

View File

@ -18,6 +18,7 @@ from quacc.evaluation.stats import wilcoxon
valid_plot_modes = defaultdict(lambda: CompReport._default_modes) valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes valid_plot_modes["avg"] = DatasetReport._default_dr_modes
root_folder = "output"
def get_datasets(root: str | Path) -> List[DatasetReport]: def get_datasets(root: str | Path) -> List[DatasetReport]:
@ -168,7 +169,13 @@ def get_Graph(fig):
) )
datasets = get_datasets("output") datasets = get_datasets(root_folder)
def get_dr(root, dataset):
ds = str(Path(root) / dataset)
return datasets[ds]
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# app.config.suppress_callback_exceptions = True # app.config.suppress_callback_exceptions = True
@ -246,6 +253,7 @@ app.layout = html.Div(
[ [
dcc.Interval(id="reload", interval=10 * 60 * 1000), dcc.Interval(id="reload", interval=10 * 60 * 1000),
dcc.Location(id="url", refresh=False), dcc.Location(id="url", refresh=False),
dcc.Store(id="root", storage_type="session", data=root_folder),
html.Div( html.Div(
[ [
html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style), html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
@ -272,26 +280,33 @@ def apply_param(href, triggered_id, id, curr):
@callback( @callback(
Output("dataset", "value"), Output("dataset", "value"),
Output("dataset", "options"), Output("dataset", "options"),
Output("root", "data"),
Input("url", "href"), Input("url", "href"),
Input("reload", "n_intervals"), Input("reload", "n_intervals"),
State("dataset", "value"), State("dataset", "value"),
) )
def update_dataset(href, n_intervals, dataset): def update_dataset(href, n_intervals, dataset):
match ctx.triggered_id: if ctx.triggered_id == "reload":
case "reload":
new_datasets = get_datasets("output") new_datasets = get_datasets("output")
global datasets global datasets
datasets = new_datasets datasets = new_datasets
req_dataset = dataset
case "url":
params = parse_href(href) params = parse_href(href)
req_dataset = params.get("dataset", None) req_dataset = params.get("dataset", None)
root = params.get("root", "")
available_datasets = list(datasets.keys()) def filter_datasets(root, available_datasets):
return [
str(Path(d).relative_to(root))
for d in available_datasets
if d.startswith(root)
]
available_datasets = filter_datasets(root, datasets.keys())
new_dataset = ( new_dataset = (
req_dataset if req_dataset in available_datasets else available_datasets[0] req_dataset if req_dataset in available_datasets else available_datasets[0]
) )
return new_dataset, available_datasets return new_dataset, available_datasets, root
@callback( @callback(
@ -300,9 +315,10 @@ def update_dataset(href, n_intervals, dataset):
Input("url", "href"), Input("url", "href"),
Input("dataset", "value"), Input("dataset", "value"),
State("metric", "value"), State("metric", "value"),
State("root", "data"),
) )
def update_metrics(href, dataset, curr_metric): def update_metrics(href, dataset, curr_metric, root):
dr = datasets[dataset] dr = get_dr(root, dataset)
old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric) old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")] valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0] new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
@ -316,9 +332,10 @@ def update_metrics(href, dataset, curr_metric):
Input("dataset", "value"), Input("dataset", "value"),
Input("metric", "value"), Input("metric", "value"),
State("estimators", "value"), State("estimators", "value"),
State("root", "data"),
) )
def update_estimators(href, dataset, metric, curr_estimators): def update_estimators(href, dataset, metric, curr_estimators, root):
dr = datasets[dataset] dr = get_dr(root, dataset)
old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators) old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
if isinstance(old_estimators, str): if isinstance(old_estimators, str):
try: try:
@ -338,9 +355,10 @@ def update_estimators(href, dataset, metric, curr_estimators):
Input("url", "href"), Input("url", "href"),
Input("dataset", "value"), Input("dataset", "value"),
State("view", "value"), State("view", "value"),
State("root", "data"),
) )
def update_view(href, dataset, curr_view): def update_view(href, dataset, curr_view, root):
dr = datasets[dataset] dr = get_dr(root, dataset)
old_view = apply_param(href, ctx.triggered_id, "view", curr_view) old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs] valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
new_view = old_view if old_view in valid_views else valid_views[0] new_view = old_view if old_view in valid_views else valid_views[0]
@ -369,8 +387,9 @@ def update_mode(href, view, curr_mode):
Input("estimators", "value"), Input("estimators", "value"),
Input("view", "value"), Input("view", "value"),
Input("mode", "value"), Input("mode", "value"),
State("root", "data"),
) )
def update_content(dataset, metric, estimators, view, mode): def update_content(dataset, metric, estimators, view, mode, root):
search = urlencode( search = urlencode(
dict( dict(
dataset=dataset, dataset=dataset,
@ -378,10 +397,11 @@ def update_content(dataset, metric, estimators, view, mode):
estimators=json.dumps(estimators), estimators=json.dumps(estimators),
view=view, view=view,
mode=mode, mode=mode,
root=root,
), ),
quote_via=quote, quote_via=quote,
) )
dr = datasets[dataset] dr = get_dr(root, dataset)
match mode: match mode:
case m if m.endswith("table"): case m if m.endswith("table"):
df = get_table( df = get_table(

View File

@ -1 +1 @@
from .method_kdey_clean import KDEy from .kdey import KDEy