Merge branch 'dash'
This commit is contained in:
commit
44d820d4ab
|
@ -6,11 +6,13 @@ quavenv/*
|
|||
__pycache__/*
|
||||
baselines/__pycache__/*
|
||||
baselines/densratio/__pycache__/*
|
||||
qcdash/__pycache__/*
|
||||
qcpanel/__pycache__/*
|
||||
quacc/__pycache__/*
|
||||
quacc/evaluation/__pycache__/*
|
||||
quacc/method/__pycache__/*
|
||||
quacc/quantification/__pycache__/*
|
||||
quacc/plot/__pycache__/*
|
||||
tests/__pycache__/*
|
||||
tests/*/__pycache__/*
|
||||
tests/*/*/__pycache__/*
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -18,6 +18,7 @@ abstention = "^0.1.3.1"
|
|||
main = "quacc.main:main"
|
||||
run = "run:run"
|
||||
panel = "qcpanel.run:run"
|
||||
dash = "qcdash.app:run"
|
||||
sync_up = "remote:sync_code"
|
||||
sync_down = "remote:sync_output"
|
||||
merge_data = "merge_data:run"
|
||||
|
@ -27,6 +28,7 @@ poetry_command = ""
|
|||
|
||||
[tool.poe.tasks]
|
||||
ilona = "ssh volpi@ilona.isti.cnr.it"
|
||||
dash = "gunicorn qcdash.app:server -b ilona.isti.cnr.it:33421"
|
||||
|
||||
[tool.poe.tasks.logr]
|
||||
shell = """
|
||||
|
@ -48,6 +50,9 @@ ipympl = "^0.9.3"
|
|||
ipykernel = "^6.26.0"
|
||||
ipywidgets-bokeh = "^1.5.0"
|
||||
pandas-stubs = "^2.1.1.230928"
|
||||
dash = "^2.14.1"
|
||||
dash-bootstrap-components = "^1.5.0"
|
||||
gunicorn = "^21.2.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov=quacc --capture=tee-sys"
|
||||
|
|
|
@ -0,0 +1,413 @@
|
|||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from urllib.parse import parse_qsl, quote, urlencode, urlparse
|
||||
|
||||
import dash_bootstrap_components as dbc
|
||||
import numpy as np
|
||||
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
|
||||
from dash.dash_table.Format import Format, Scheme
|
||||
|
||||
from quacc import plot
|
||||
from quacc.evaluation.estimators import CE
|
||||
from quacc.evaluation.report import CompReport, DatasetReport
|
||||
from quacc.evaluation.stats import ttest_rel
|
||||
|
||||
backend = plot.get_backend("plotly")
|
||||
|
||||
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
|
||||
valid_plot_modes["avg"] = DatasetReport._default_dr_modes
|
||||
|
||||
|
||||
def get_datasets(root: str | Path) -> List[DatasetReport]:
|
||||
def load_dataset(dataset):
|
||||
dataset = Path(dataset)
|
||||
return DatasetReport.unpickle(dataset)
|
||||
|
||||
def explore_datasets(root: str | Path) -> List[Path]:
|
||||
if isinstance(root, str):
|
||||
root = Path(root)
|
||||
|
||||
if root.name == "plot":
|
||||
return []
|
||||
|
||||
if not root.exists():
|
||||
return []
|
||||
|
||||
dr_paths = []
|
||||
for f in os.listdir(root):
|
||||
if (root / f).is_dir():
|
||||
dr_paths += explore_datasets(root / f)
|
||||
elif f == f"{root.name}.pickle":
|
||||
dr_paths.append(root / f)
|
||||
|
||||
return dr_paths
|
||||
|
||||
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
|
||||
return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
|
||||
|
||||
|
||||
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
|
||||
estimators = CE.name[estimators]
|
||||
match (view, mode):
|
||||
case ("avg", _):
|
||||
return dr.get_plots(
|
||||
mode=mode,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="plotly",
|
||||
save_fig=False,
|
||||
backend=backend,
|
||||
)
|
||||
case (_, _):
|
||||
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
|
||||
return cr.get_plots(
|
||||
mode=mode,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="plotly",
|
||||
save_fig=False,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
|
||||
def get_table(dr: DatasetReport, metric, estimators, view, mode):
|
||||
estimators = CE.name[estimators]
|
||||
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||
match (view, mode):
|
||||
case ("avg", "train_table"):
|
||||
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
|
||||
case ("avg", "test_table"):
|
||||
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
case ("avg", "shift_table"):
|
||||
return (
|
||||
dr.shift_data(metric=metric, estimators=estimators)
|
||||
.groupby(level=0)
|
||||
.mean()
|
||||
)
|
||||
case ("avg", "stats_table"):
|
||||
return ttest_rel(dr, metric=metric, estimators=estimators)
|
||||
case (_, "train_table"):
|
||||
cr = dr.crs[_prevs.index(view)]
|
||||
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||
case (_, "shift_table"):
|
||||
cr = dr.crs[_prevs.index(view)]
|
||||
return (
|
||||
cr.shift_data(metric=metric, estimators=estimators)
|
||||
.groupby(level=0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
|
||||
def get_DataTable(df):
|
||||
_primary = "#0d6efd"
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
df = df.reset_index()
|
||||
columns = {
|
||||
c: dict(
|
||||
id=c,
|
||||
name=c,
|
||||
type="numeric",
|
||||
format=Format(precision=6, scheme=Scheme.exponent),
|
||||
)
|
||||
for c in df.columns
|
||||
}
|
||||
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
|
||||
columns = list(columns.values())
|
||||
data = df.to_dict("records")
|
||||
|
||||
return html.Div(
|
||||
[
|
||||
dash_table.DataTable(
|
||||
data=data,
|
||||
columns=columns,
|
||||
id="table1",
|
||||
style_cell={
|
||||
"padding": "0 12px",
|
||||
"border": "0",
|
||||
"border-bottom": f"1px solid {_primary}",
|
||||
},
|
||||
style_table={
|
||||
"margin": "6vh 15px",
|
||||
"padding": "15px",
|
||||
"maxWidth": "80vw",
|
||||
"overflowX": "auto",
|
||||
"border": f"0px solid {_primary}",
|
||||
"border-radius": "6px",
|
||||
},
|
||||
)
|
||||
],
|
||||
style={
|
||||
"display": "flex",
|
||||
"flex-direction": "column",
|
||||
# "justify-content": "center",
|
||||
"align-items": "center",
|
||||
"height": "100vh",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_Graph(fig):
|
||||
if fig is None:
|
||||
return None
|
||||
|
||||
return dcc.Graph(
|
||||
id="graph1",
|
||||
figure=fig,
|
||||
style={
|
||||
"margin": 0,
|
||||
"height": "100vh",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
datasets = get_datasets("output")
|
||||
|
||||
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
||||
# app.config.suppress_callback_exceptions = True
|
||||
sidebar_style = {
|
||||
"top": 0,
|
||||
"left": 0,
|
||||
"bottom": 0,
|
||||
"padding": "1vw",
|
||||
"padding-top": "2vw",
|
||||
"margin": "0px",
|
||||
"flex": 1,
|
||||
"overflow": "scroll",
|
||||
"height": "100vh",
|
||||
}
|
||||
|
||||
content_style = {
|
||||
"flex": 5,
|
||||
"maxWidth": "84vw",
|
||||
}
|
||||
|
||||
|
||||
def parse_href(href: str):
|
||||
parse_result = urlparse(href)
|
||||
params = parse_qsl(parse_result.query)
|
||||
return dict(params)
|
||||
|
||||
|
||||
def get_sidebar():
|
||||
return [
|
||||
html.H4("Parameters:", style={"margin-bottom": "1vw"}),
|
||||
dbc.Select(
|
||||
# options=list(datasets.keys()),
|
||||
# value=list(datasets.keys())[0],
|
||||
id="dataset",
|
||||
),
|
||||
dbc.Select(
|
||||
id="metric",
|
||||
style={"margin-top": "1vh"},
|
||||
),
|
||||
html.Div(
|
||||
[
|
||||
dbc.RadioItems(
|
||||
id="view",
|
||||
class_name="btn-group mt-3",
|
||||
input_class_name="btn-check",
|
||||
label_class_name="btn btn-outline-primary",
|
||||
label_checked_class_name="active",
|
||||
),
|
||||
dbc.RadioItems(
|
||||
id="mode",
|
||||
class_name="btn-group mt-3",
|
||||
input_class_name="btn-check",
|
||||
label_class_name="btn btn-outline-primary",
|
||||
label_checked_class_name="active",
|
||||
),
|
||||
],
|
||||
className="radio-group-v d-flex justify-content-around",
|
||||
),
|
||||
html.Div(
|
||||
[
|
||||
dbc.Checklist(
|
||||
id="estimators",
|
||||
className="btn-group mt-3",
|
||||
inputClassName="btn-check",
|
||||
labelClassName="btn btn-outline-primary",
|
||||
labelCheckedClassName="active",
|
||||
),
|
||||
],
|
||||
className="radio-group-wide",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
app.layout = html.Div(
|
||||
[
|
||||
dcc.Interval(id="reload", interval=10 * 60 * 1000),
|
||||
dcc.Location(id="url", refresh=False),
|
||||
html.Div(
|
||||
[
|
||||
html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
|
||||
html.Div(id="app_content", style=content_style),
|
||||
],
|
||||
id="page_layout",
|
||||
style={"display": "flex", "flexDirection": "row"},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
server = app.server
|
||||
|
||||
|
||||
def apply_param(href, triggered_id, id, curr):
|
||||
match triggered_id:
|
||||
case "url":
|
||||
params = parse_href(href)
|
||||
return params.get(id, None)
|
||||
case _:
|
||||
return curr
|
||||
|
||||
|
||||
@callback(
|
||||
Output("dataset", "value"),
|
||||
Output("dataset", "options"),
|
||||
Input("url", "href"),
|
||||
Input("reload", "n_intervals"),
|
||||
State("dataset", "value"),
|
||||
)
|
||||
def update_dataset(href, n_intervals, dataset):
|
||||
match ctx.triggered_id:
|
||||
case "reload":
|
||||
new_datasets = get_datasets("output")
|
||||
global datasets
|
||||
datasets = new_datasets
|
||||
req_dataset = dataset
|
||||
case "url":
|
||||
params = parse_href(href)
|
||||
req_dataset = params.get("dataset", None)
|
||||
|
||||
available_datasets = list(datasets.keys())
|
||||
new_dataset = (
|
||||
req_dataset if req_dataset in available_datasets else available_datasets[0]
|
||||
)
|
||||
return new_dataset, available_datasets
|
||||
|
||||
|
||||
@callback(
|
||||
Output("metric", "options"),
|
||||
Output("metric", "value"),
|
||||
Input("url", "href"),
|
||||
Input("dataset", "value"),
|
||||
State("metric", "value"),
|
||||
)
|
||||
def update_metrics(href, dataset, curr_metric):
|
||||
dr = datasets[dataset]
|
||||
old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
|
||||
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
|
||||
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
|
||||
return valid_metrics, new_metric
|
||||
|
||||
|
||||
@callback(
|
||||
Output("estimators", "options"),
|
||||
Output("estimators", "value"),
|
||||
Input("url", "href"),
|
||||
Input("dataset", "value"),
|
||||
Input("metric", "value"),
|
||||
State("estimators", "value"),
|
||||
)
|
||||
def update_estimators(href, dataset, metric, curr_estimators):
|
||||
dr = datasets[dataset]
|
||||
old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
|
||||
if isinstance(old_estimators, str):
|
||||
try:
|
||||
old_estimators = json.loads(old_estimators)
|
||||
except JSONDecodeError:
|
||||
old_estimators = []
|
||||
valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
|
||||
new_estimators = valid_estimators[
|
||||
np.isin(valid_estimators, old_estimators)
|
||||
].tolist()
|
||||
return valid_estimators, new_estimators
|
||||
|
||||
|
||||
@callback(
|
||||
Output("view", "options"),
|
||||
Output("view", "value"),
|
||||
Input("url", "href"),
|
||||
Input("dataset", "value"),
|
||||
State("view", "value"),
|
||||
)
|
||||
def update_view(href, dataset, curr_view):
|
||||
dr = datasets[dataset]
|
||||
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
|
||||
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||
new_view = old_view if old_view in valid_views else valid_views[0]
|
||||
return valid_views, new_view
|
||||
|
||||
|
||||
@callback(
|
||||
Output("mode", "options"),
|
||||
Output("mode", "value"),
|
||||
Input("url", "href"),
|
||||
Input("view", "value"),
|
||||
State("mode", "value"),
|
||||
)
|
||||
def update_mode(href, view, curr_mode):
|
||||
old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
|
||||
valid_modes = valid_plot_modes[view]
|
||||
new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
|
||||
return valid_modes, new_mode
|
||||
|
||||
|
||||
@callback(
|
||||
Output("app_content", "children"),
|
||||
Output("url", "search"),
|
||||
Input("dataset", "value"),
|
||||
Input("metric", "value"),
|
||||
Input("estimators", "value"),
|
||||
Input("view", "value"),
|
||||
Input("mode", "value"),
|
||||
)
|
||||
def update_content(dataset, metric, estimators, view, mode):
|
||||
search = urlencode(
|
||||
dict(
|
||||
dataset=dataset,
|
||||
metric=metric,
|
||||
estimators=json.dumps(estimators),
|
||||
view=view,
|
||||
mode=mode,
|
||||
),
|
||||
quote_via=quote,
|
||||
)
|
||||
dr = datasets[dataset]
|
||||
match mode:
|
||||
case m if m.endswith("table"):
|
||||
df = get_table(
|
||||
dr=dr,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
view=view,
|
||||
mode=mode,
|
||||
)
|
||||
dt = get_DataTable(df)
|
||||
app_content = [] if dt is None else [dt]
|
||||
case _:
|
||||
fig = get_fig(
|
||||
dr=dr,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
view=view,
|
||||
mode=mode,
|
||||
)
|
||||
g = get_Graph(fig)
|
||||
app_content = [] if g is None else [g]
|
||||
|
||||
return app_content, f"?{search}"
|
||||
|
||||
|
||||
def run():
|
||||
app.run(debug=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
|
@ -0,0 +1,86 @@
|
|||
/* restyle radio items */
|
||||
.radio-group .form-check {
|
||||
padding-left: 0;
|
||||
}
|
||||
|
||||
.radio-group .btn-group > .form-check:not(:last-child) > .btn {
|
||||
border-top-right-radius: 0;
|
||||
border-bottom-right-radius: 0;
|
||||
}
|
||||
|
||||
.radio-group .btn-group > .form-check:not(:first-child) > .btn {
|
||||
border-top-left-radius: 0;
|
||||
border-bottom-left-radius: 0;
|
||||
margin-left: -1px;
|
||||
}
|
||||
|
||||
.radio-group-v{
|
||||
padding: 0 10px;
|
||||
}
|
||||
.radio-group-v .form-check {
|
||||
padding-top: 0;
|
||||
padding-left: 2px;
|
||||
}
|
||||
|
||||
.radio-group-v .btn-group {
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: stretch;
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
.radio-group-v > .btn-group:first-child {
|
||||
flex-grow: 0;
|
||||
}
|
||||
|
||||
.radio-group-v > .btn-group:last-child {
|
||||
margin-left: 20px;
|
||||
}
|
||||
|
||||
.radio-group-v .btn-group > .form-check:not(:last-child) > .btn {
|
||||
border-bottom-right-radius: 0;
|
||||
border-bottom-left-radius: 0;
|
||||
}
|
||||
|
||||
.radio-group-v .btn-group > .form-check:not(:first-child) > .btn {
|
||||
border-top-right-radius: 0;
|
||||
border-top-left-radius: 0;
|
||||
margin-top: -3px;
|
||||
}
|
||||
|
||||
|
||||
.radio-group-v .btn-group .btn{
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.radio-group-wide .form-check {
|
||||
padding-left: 0px;
|
||||
}
|
||||
|
||||
.radio-group-wide .btn-group{
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.radio-group-wide .btn-group .form-check{
|
||||
flex: 1;
|
||||
margin-top: -3px;
|
||||
margin-left: -1px;
|
||||
}
|
||||
|
||||
|
||||
.radio-group-wide .btn-group .form-check .btn{
|
||||
width: 100%;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.radio-group-wide .btn-group .form-check:first-child > .btn{
|
||||
border-top-left-radius: 10px;
|
||||
}
|
||||
|
||||
.radio-group-wide .btn-group .form-check:last-child > .btn{
|
||||
border-bottom-right-radius: 10px;
|
||||
}
|
||||
|
||||
div#app-sidebar{
|
||||
border-right: 2px solid var(--primary);
|
||||
}
|
|
@ -1,290 +0,0 @@
|
|||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import panel as pn
|
||||
import param
|
||||
|
||||
from quacc.evaluation.estimators import CE
|
||||
from quacc.evaluation.report import DatasetReport
|
||||
|
||||
pn.extension(design="bootstrap")
|
||||
|
||||
|
||||
def create_cr_plots(
|
||||
dr: DatasetReport,
|
||||
mode="delta",
|
||||
metric="acc",
|
||||
estimators=None,
|
||||
prev=None,
|
||||
):
|
||||
idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev)
|
||||
cr = dr.crs[idx]
|
||||
estimators = CE.name[estimators]
|
||||
_dpi = 112
|
||||
return pn.pane.Matplotlib(
|
||||
cr.get_plots(
|
||||
mode=mode,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="panel",
|
||||
return_fig=True,
|
||||
),
|
||||
tight=True,
|
||||
format="png",
|
||||
sizing_mode="scale_height",
|
||||
# sizing_mode="scale_both",
|
||||
)
|
||||
|
||||
|
||||
def create_avg_plots(
|
||||
dr: DatasetReport,
|
||||
mode="delta",
|
||||
metric="acc",
|
||||
estimators=None,
|
||||
prev=None,
|
||||
):
|
||||
estimators = CE.name[estimators]
|
||||
return pn.pane.Matplotlib(
|
||||
dr.get_plots(
|
||||
mode=mode,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="panel",
|
||||
return_fig=True,
|
||||
),
|
||||
tight=True,
|
||||
format="png",
|
||||
sizing_mode="scale_height",
|
||||
# sizing_mode="scale_both",
|
||||
)
|
||||
|
||||
|
||||
def build_cr_tab(dr: DatasetReport):
|
||||
_data = dr.data()
|
||||
_metrics = _data.columns.unique(0)
|
||||
_estimators = _data.columns.unique(1)
|
||||
|
||||
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
||||
metric_widget = pn.widgets.Select(
|
||||
name="metric",
|
||||
value="acc",
|
||||
options=valid_metrics,
|
||||
align="center",
|
||||
)
|
||||
|
||||
valid_estimators = [e for e in _estimators if e != "ref"]
|
||||
estimators_widget = pn.widgets.CheckButtonGroup(
|
||||
name="estimators",
|
||||
options=valid_estimators,
|
||||
value=valid_estimators,
|
||||
button_style="outline",
|
||||
button_type="primary",
|
||||
align="center",
|
||||
orientation="vertical",
|
||||
sizing_mode="scale_width",
|
||||
)
|
||||
|
||||
valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"]
|
||||
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
||||
name="mode",
|
||||
value=valid_plot_modes[0],
|
||||
options=valid_plot_modes,
|
||||
button_style="outline",
|
||||
button_type="primary",
|
||||
align="center",
|
||||
orientation="vertical",
|
||||
sizing_mode="scale_width",
|
||||
)
|
||||
|
||||
valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
||||
prevs_widget = pn.widgets.RadioButtonGroup(
|
||||
name="train prevalence",
|
||||
value=valid_prevs[0],
|
||||
options=valid_prevs,
|
||||
button_style="outline",
|
||||
button_type="primary",
|
||||
align="center",
|
||||
orientation="vertical",
|
||||
)
|
||||
|
||||
plot_pane = pn.bind(
|
||||
create_cr_plots,
|
||||
dr=dr,
|
||||
mode=plot_mode_widget,
|
||||
metric=metric_widget,
|
||||
estimators=estimators_widget,
|
||||
prev=prevs_widget,
|
||||
)
|
||||
|
||||
return pn.Row(
|
||||
pn.Spacer(width=20),
|
||||
pn.Column(
|
||||
metric_widget,
|
||||
pn.Row(
|
||||
prevs_widget,
|
||||
plot_mode_widget,
|
||||
),
|
||||
estimators_widget,
|
||||
align="center",
|
||||
),
|
||||
pn.Spacer(sizing_mode="scale_width"),
|
||||
plot_pane,
|
||||
)
|
||||
|
||||
|
||||
def build_avg_tab(dr: DatasetReport):
|
||||
_data = dr.data()
|
||||
_metrics = _data.columns.unique(0)
|
||||
_estimators = _data.columns.unique(1)
|
||||
|
||||
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
||||
metric_widget = pn.widgets.Select(
|
||||
name="metric",
|
||||
value="acc",
|
||||
options=valid_metrics,
|
||||
align="center",
|
||||
)
|
||||
|
||||
valid_estimators = [e for e in _estimators if e != "ref"]
|
||||
estimators_widget = pn.widgets.CheckButtonGroup(
|
||||
name="estimators",
|
||||
options=valid_estimators,
|
||||
value=valid_estimators,
|
||||
button_style="outline",
|
||||
button_type="primary",
|
||||
align="center",
|
||||
orientation="vertical",
|
||||
sizing_mode="scale_width",
|
||||
)
|
||||
|
||||
valid_plot_modes = [
|
||||
"delta_train",
|
||||
"stdev_train",
|
||||
"delta_test",
|
||||
"stdev_test",
|
||||
"shift",
|
||||
]
|
||||
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
||||
name="mode",
|
||||
value=valid_plot_modes[0],
|
||||
options=valid_plot_modes,
|
||||
button_style="outline",
|
||||
button_type="primary",
|
||||
align="center",
|
||||
orientation="vertical",
|
||||
sizing_mode="scale_width",
|
||||
)
|
||||
|
||||
plot_pane = pn.bind(
|
||||
create_avg_plots,
|
||||
dr=dr,
|
||||
mode=plot_mode_widget,
|
||||
metric=metric_widget,
|
||||
estimators=estimators_widget,
|
||||
)
|
||||
|
||||
return pn.Row(
|
||||
pn.Spacer(width=20),
|
||||
pn.Column(
|
||||
metric_widget,
|
||||
plot_mode_widget,
|
||||
estimators_widget,
|
||||
align="center",
|
||||
),
|
||||
pn.Spacer(sizing_mode="scale_width"),
|
||||
plot_pane,
|
||||
)
|
||||
|
||||
|
||||
def build_dataset(dataset_path: Path):
|
||||
dr: DatasetReport = DatasetReport.unpickle(dataset_path)
|
||||
|
||||
prevs_tab = ("train prevs.", build_cr_tab(dr))
|
||||
avg_tab = ("avg", build_avg_tab(dr))
|
||||
|
||||
app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False)
|
||||
app.servable()
|
||||
return app
|
||||
|
||||
|
||||
def explore_datasets(root: Path | str):
|
||||
if isinstance(root, str):
|
||||
root = Path(root)
|
||||
|
||||
if root.name == "plot":
|
||||
return []
|
||||
|
||||
if not root.exists():
|
||||
return []
|
||||
|
||||
drs = []
|
||||
for f in os.listdir(root):
|
||||
if (root / f).is_dir():
|
||||
drs += explore_datasets(root / f)
|
||||
elif f == f"{root.name}.pickle":
|
||||
drs.append((root, build_dataset(root / f)))
|
||||
# drs.append((str(root),))
|
||||
|
||||
return drs
|
||||
|
||||
|
||||
class PlotSelector(param.Parameterized):
|
||||
metric = param.Selector(objects=["acc", "f1"])
|
||||
view = param.Selector(objects=["train prevs", "avg"])
|
||||
|
||||
|
||||
def plot_selector_widget():
|
||||
return pn.Param(
|
||||
PlotSelector.param,
|
||||
widgets={
|
||||
"metric": pn.widgets.Select,
|
||||
"view": pn.widgets.Select,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def serve(address="localhost"):
|
||||
# app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle"))
|
||||
__base_path = "output"
|
||||
__tabs = sorted(
|
||||
explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0])
|
||||
)
|
||||
__tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs]
|
||||
if len(__tabs) > 0:
|
||||
app = pn.Tabs(
|
||||
objects=__tabs,
|
||||
tabs_location="left",
|
||||
dynamic=False,
|
||||
)
|
||||
else:
|
||||
app = pn.Column(
|
||||
pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}),
|
||||
align="center",
|
||||
)
|
||||
|
||||
__port = 33420
|
||||
__allowed = [address]
|
||||
if address == "localhost":
|
||||
__allowed.append("127.0.0.1")
|
||||
|
||||
pn.serve(
|
||||
app,
|
||||
autoreload=True,
|
||||
port=__port,
|
||||
show=False,
|
||||
address=address,
|
||||
websocket_origin=[f"{_a}:{__port}" for _a in __allowed],
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--address",
|
||||
action="store",
|
||||
dest="address",
|
||||
default="localhost",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
serve(address=args.address)
|
|
@ -1,10 +1,11 @@
|
|||
import argparse
|
||||
|
||||
import panel as pn
|
||||
from panel.theme.fast import FastDarkTheme, FastDefaultTheme
|
||||
|
||||
from qcpanel.viewer import QuaccTestViewer
|
||||
|
||||
# pn.config.design = pn.theme.Bootstrap
|
||||
# pn.config.design = Fast
|
||||
# pn.config.theme = "dark"
|
||||
pn.config.notifications = True
|
||||
|
||||
|
@ -59,8 +60,8 @@ def app_instance():
|
|||
],
|
||||
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
|
||||
modal=[qtv.modal_pane],
|
||||
theme=pn.theme.DarkTheme,
|
||||
theme_toggle=False,
|
||||
# theme=FastDefaultTheme,
|
||||
theme_toggle=True,
|
||||
)
|
||||
|
||||
app.servable()
|
||||
|
|
|
@ -52,7 +52,7 @@ def create_plots(
|
|||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="panel",
|
||||
return_fig=True,
|
||||
save_fig=False,
|
||||
)
|
||||
return (
|
||||
pn.pane.Matplotlib(
|
||||
|
@ -91,7 +91,7 @@ def create_plots(
|
|||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="panel",
|
||||
return_fig=True,
|
||||
save_fig=False,
|
||||
)
|
||||
return (
|
||||
pn.pane.Matplotlib(
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import List, Tuple
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from quacc import plot
|
||||
import quacc.plot as plot
|
||||
from quacc.utils import fmt_line_md
|
||||
|
||||
|
||||
|
@ -215,16 +215,17 @@ class CompReport:
|
|||
|
||||
def get_plots(
|
||||
self,
|
||||
mode="delta",
|
||||
mode="delta_train",
|
||||
metric="acc",
|
||||
estimators=None,
|
||||
conf="default",
|
||||
return_fig=False,
|
||||
save_fig=True,
|
||||
base_path=None,
|
||||
backend=None,
|
||||
) -> List[Tuple[str, Path]]:
|
||||
if mode == "delta_train":
|
||||
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
||||
if avg_data.empty is True:
|
||||
if avg_data.empty:
|
||||
return None
|
||||
|
||||
return plot.plot_delta(
|
||||
|
@ -234,8 +235,9 @@ class CompReport:
|
|||
metric=metric,
|
||||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "stdev_train":
|
||||
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
||||
|
@ -251,8 +253,9 @@ class CompReport:
|
|||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
stdevs=st_data.T.to_numpy(),
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "diagonal":
|
||||
f_data = self.data(metric=metric + "_score", estimators=estimators)
|
||||
|
@ -268,8 +271,9 @@ class CompReport:
|
|||
metric=metric,
|
||||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "shift":
|
||||
_shift_data = self.shift_data(metric=metric, estimators=estimators)
|
||||
|
@ -290,8 +294,9 @@ class CompReport:
|
|||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
counts=shift_counts.T.to_numpy(),
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
def to_md(
|
||||
|
@ -323,11 +328,12 @@ class CompReport:
|
|||
plot_modes = [m for m in modes if not m.endswith("table")]
|
||||
for mode in plot_modes:
|
||||
res += f"### {mode}\n"
|
||||
op = self.get_plots(
|
||||
_, op = self.get_plots(
|
||||
mode=mode,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
save_fig=True,
|
||||
base_path=plot_path,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
@ -424,12 +430,15 @@ class DatasetReport:
|
|||
metric="acc",
|
||||
estimators=None,
|
||||
conf="default",
|
||||
return_fig=False,
|
||||
save_fig=True,
|
||||
base_path=None,
|
||||
backend=None,
|
||||
):
|
||||
if mode == "delta_train":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
avg_on_train = _data.groupby(level=1).mean()
|
||||
if avg_on_train.empty:
|
||||
return None
|
||||
prevs_on_train = np.sort(avg_on_train.index.unique(0))
|
||||
return plot.plot_delta(
|
||||
base_prevs=np.around(
|
||||
|
@ -441,12 +450,15 @@ class DatasetReport:
|
|||
name=conf,
|
||||
train_prev=None,
|
||||
avg="train",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "stdev_train":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
avg_on_train = _data.groupby(level=1).mean()
|
||||
if avg_on_train.empty:
|
||||
return None
|
||||
prevs_on_train = np.sort(avg_on_train.index.unique(0))
|
||||
stdev_on_train = _data.groupby(level=1).std()
|
||||
return plot.plot_delta(
|
||||
|
@ -460,12 +472,15 @@ class DatasetReport:
|
|||
train_prev=None,
|
||||
stdevs=stdev_on_train.T.to_numpy(),
|
||||
avg="train",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "delta_test":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
avg_on_test = _data.groupby(level=0).mean()
|
||||
if avg_on_test.empty:
|
||||
return None
|
||||
prevs_on_test = np.sort(avg_on_test.index.unique(0))
|
||||
return plot.plot_delta(
|
||||
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
|
||||
|
@ -475,12 +490,15 @@ class DatasetReport:
|
|||
name=conf,
|
||||
train_prev=None,
|
||||
avg="test",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "stdev_test":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
avg_on_test = _data.groupby(level=0).mean()
|
||||
if avg_on_test.empty:
|
||||
return None
|
||||
prevs_on_test = np.sort(avg_on_test.index.unique(0))
|
||||
stdev_on_test = _data.groupby(level=0).std()
|
||||
return plot.plot_delta(
|
||||
|
@ -492,12 +510,15 @@ class DatasetReport:
|
|||
train_prev=None,
|
||||
stdevs=stdev_on_test.T.to_numpy(),
|
||||
avg="test",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "shift":
|
||||
_shift_data = self.shift_data(metric, estimators) if data is None else data
|
||||
avg_shift = _shift_data.groupby(level=0).mean()
|
||||
if avg_shift.empty:
|
||||
return None
|
||||
count_shift = _shift_data.groupby(level=0).count()
|
||||
prevs_shift = np.sort(avg_shift.index.unique(0))
|
||||
return plot.plot_shift(
|
||||
|
@ -508,8 +529,9 @@ class DatasetReport:
|
|||
name=conf,
|
||||
train_prev=None,
|
||||
counts=count_shift.T.to_numpy(),
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
def to_md(
|
||||
|
@ -545,24 +567,26 @@ class DatasetReport:
|
|||
res += avg_on_train_tbl.to_html() + "\n\n"
|
||||
|
||||
if "delta_train" in dr_modes:
|
||||
delta_op = self.get_plots(
|
||||
_, delta_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="delta_train",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
if "stdev_train" in dr_modes:
|
||||
delta_stdev_op = self.get_plots(
|
||||
_, delta_stdev_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="stdev_train",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
|
@ -575,24 +599,26 @@ class DatasetReport:
|
|||
res += avg_on_test_tbl.to_html() + "\n\n"
|
||||
|
||||
if "delta_test" in dr_modes:
|
||||
delta_op = self.get_plots(
|
||||
_, delta_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="delta_test",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
if "stdev_test" in dr_modes:
|
||||
delta_stdev_op = self.get_plots(
|
||||
_, delta_stdev_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="stdev_test",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
|
@ -605,13 +631,14 @@ class DatasetReport:
|
|||
res += shift_on_train_tbl.to_html() + "\n\n"
|
||||
|
||||
if "shift" in dr_modes:
|
||||
shift_op = self.get_plots(
|
||||
_, shift_op = self.get_plots(
|
||||
data=_shift_data,
|
||||
mode="shift",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
|
|
265
quacc/plot.py
265
quacc/plot.py
|
@ -1,265 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from cycler import cycler
|
||||
|
||||
from quacc import utils
|
||||
|
||||
matplotlib.use("agg")
|
||||
|
||||
|
||||
def _get_markers(n: int):
|
||||
ls = "ovx+sDph*^1234X><.Pd"
|
||||
if n > len(ls):
|
||||
ls = ls * (n / len(ls) + 1)
|
||||
return list(ls)[:n]
|
||||
|
||||
|
||||
def plot_delta(
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
avg=None,
|
||||
return_fig=False,
|
||||
base_path=None,
|
||||
) -> Path:
|
||||
_base_title = "delta_stdev" if stdevs is not None else "delta"
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
|
||||
|
||||
if base_path is None:
|
||||
base_path = utils.get_quacc_home() / "plots"
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||
|
||||
base_prevs = base_prevs[:, pos_class]
|
||||
for method, deltas, _cy in zip(columns, data, cy):
|
||||
ax.plot(
|
||||
base_prevs,
|
||||
deltas,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
marker="o",
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
)
|
||||
if stdevs is not None:
|
||||
_col_idx = np.where(columns == method)[0]
|
||||
stdev = stdevs[_col_idx].flatten()
|
||||
nn_idx = np.intersect1d(
|
||||
np.where(deltas != np.nan)[0],
|
||||
np.where(stdev != np.nan)[0],
|
||||
)
|
||||
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
|
||||
ax.fill_between(
|
||||
_bps,
|
||||
_ds - _st,
|
||||
_ds + _st,
|
||||
color=_cy["color"],
|
||||
alpha=0.25,
|
||||
)
|
||||
|
||||
x_label = "test" if avg is None or avg == "train" else "train"
|
||||
ax.set(
|
||||
xlabel=f"{x_label} prevalence",
|
||||
ylabel=metric,
|
||||
title=title,
|
||||
)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if return_fig:
|
||||
return fig
|
||||
|
||||
output_path = base_path / f"{title}.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
return output_path
|
||||
|
||||
|
||||
def plot_diagonal(
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
return_fig=False,
|
||||
base_path=None,
|
||||
):
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"diagonal_{name}_{metric}"
|
||||
|
||||
if base_path is None:
|
||||
base_path = utils.get_quacc_home() / "plots"
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
ax.set_aspect("equal")
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(
|
||||
color=[cm(i) for i in range(NUM_COLORS)],
|
||||
marker=_get_markers(NUM_COLORS),
|
||||
)
|
||||
|
||||
reference = np.array(reference)
|
||||
x_ticks = np.unique(reference)
|
||||
x_ticks.sort()
|
||||
|
||||
for deltas, _cy in zip(data, cy):
|
||||
ax.plot(
|
||||
reference,
|
||||
deltas,
|
||||
color=_cy["color"],
|
||||
linestyle="None",
|
||||
marker=_cy["marker"],
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
alpha=0.25,
|
||||
)
|
||||
|
||||
# ensure limits are equal for both axes
|
||||
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
|
||||
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
|
||||
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
|
||||
|
||||
for method, deltas, _cy in zip(columns, data, cy):
|
||||
slope, interc = np.polyfit(reference, deltas, 1)
|
||||
y_lr = np.array([slope * x + interc for x in _lims])
|
||||
ax.plot(
|
||||
_lims,
|
||||
y_lr,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
markersize="0",
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
# plot reference line
|
||||
ax.plot(
|
||||
_lims,
|
||||
_lims,
|
||||
color="black",
|
||||
linestyle="--",
|
||||
markersize=0,
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if return_fig:
|
||||
return fig
|
||||
|
||||
output_path = base_path / f"{title}.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
return output_path
|
||||
|
||||
|
||||
def plot_shift(
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
return_fig=False,
|
||||
base_path=None,
|
||||
) -> Path:
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"shift_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"shift_{name}_avg_{metric}"
|
||||
|
||||
if base_path is None:
|
||||
base_path = utils.get_quacc_home() / "plots"
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||
|
||||
shift_prevs = shift_prevs[:, pos_class]
|
||||
for method, shifts, _cy in zip(columns, data, cy):
|
||||
ax.plot(
|
||||
shift_prevs,
|
||||
shifts,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
marker="o",
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
)
|
||||
if counts is not None:
|
||||
_col_idx = np.where(columns == method)[0]
|
||||
count = counts[_col_idx].flatten()
|
||||
for prev, shift, cnt in zip(shift_prevs, shifts, count):
|
||||
label = f"{cnt}"
|
||||
plt.annotate(
|
||||
label,
|
||||
(prev, shift),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 10),
|
||||
ha="center",
|
||||
color=_cy["color"],
|
||||
fontsize=12.0,
|
||||
)
|
||||
|
||||
ax.set(xlabel="dataset shift", ylabel=metric, title=title)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if return_fig:
|
||||
return fig
|
||||
|
||||
output_path = base_path / f"{title}.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
|
||||
return output_path
|
|
@ -0,0 +1 @@
|
|||
from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift
|
|
@ -0,0 +1,54 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
class BasePlot:
|
||||
@classmethod
|
||||
def save_fig(cls, fig, base_path, title) -> Path:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def plot_diagonal(
|
||||
cls,
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def plot_delta(
|
||||
cls,
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="prevs.",
|
||||
y_label="error",
|
||||
legend=True,
|
||||
):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def plot_shift(
|
||||
cls,
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
):
|
||||
...
|
|
@ -0,0 +1,222 @@
|
|||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from cycler import cycler
|
||||
|
||||
from quacc import utils
|
||||
from quacc.plot.base import BasePlot
|
||||
|
||||
matplotlib.use("agg")
|
||||
|
||||
|
||||
class MplPlot(BasePlot):
|
||||
def _get_markers(self, n: int):
|
||||
ls = "ovx+sDph*^1234X><.Pd"
|
||||
if n > len(ls):
|
||||
ls = ls * (n / len(ls) + 1)
|
||||
return list(ls)[:n]
|
||||
|
||||
def save_fig(self, fig, base_path, title) -> Path:
|
||||
if base_path is None:
|
||||
base_path = utils.get_quacc_home() / "plots"
|
||||
output_path = base_path / f"{title}.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
return output_path
|
||||
|
||||
def plot_delta(
|
||||
self,
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="prevs.",
|
||||
y_label="error",
|
||||
legend=True,
|
||||
):
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||
|
||||
base_prevs = base_prevs[:, pos_class]
|
||||
for method, deltas, _cy in zip(columns, data, cy):
|
||||
ax.plot(
|
||||
base_prevs,
|
||||
deltas,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
marker="o",
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
)
|
||||
if stdevs is not None:
|
||||
_col_idx = np.where(columns == method)[0]
|
||||
stdev = stdevs[_col_idx].flatten()
|
||||
nn_idx = np.intersect1d(
|
||||
np.where(deltas != np.nan)[0],
|
||||
np.where(stdev != np.nan)[0],
|
||||
)
|
||||
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
|
||||
ax.fill_between(
|
||||
_bps,
|
||||
_ds - _st,
|
||||
_ds + _st,
|
||||
color=_cy["color"],
|
||||
alpha=0.25,
|
||||
)
|
||||
|
||||
ax.set(
|
||||
xlabel=f"{x_label} prevalence",
|
||||
ylabel=y_label,
|
||||
title=title,
|
||||
)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
return fig
|
||||
|
||||
def plot_diagonal(
|
||||
self,
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
):
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
ax.set_aspect("equal")
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(
|
||||
color=[cm(i) for i in range(NUM_COLORS)],
|
||||
marker=self._get_markers(NUM_COLORS),
|
||||
)
|
||||
|
||||
reference = np.array(reference)
|
||||
x_ticks = np.unique(reference)
|
||||
x_ticks.sort()
|
||||
|
||||
for deltas, _cy in zip(data, cy):
|
||||
ax.plot(
|
||||
reference,
|
||||
deltas,
|
||||
color=_cy["color"],
|
||||
linestyle="None",
|
||||
marker=_cy["marker"],
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
alpha=0.25,
|
||||
)
|
||||
|
||||
# ensure limits are equal for both axes
|
||||
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
|
||||
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
|
||||
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
|
||||
|
||||
for method, deltas, _cy in zip(columns, data, cy):
|
||||
slope, interc = np.polyfit(reference, deltas, 1)
|
||||
y_lr = np.array([slope * x + interc for x in _lims])
|
||||
ax.plot(
|
||||
_lims,
|
||||
y_lr,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
markersize="0",
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
# plot reference line
|
||||
ax.plot(
|
||||
_lims,
|
||||
_lims,
|
||||
color="black",
|
||||
linestyle="--",
|
||||
markersize=0,
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
ax.set(xlabel=x_label, ylabel=y_label, title=title)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
return fig
|
||||
|
||||
def plot_shift(
|
||||
self,
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
):
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||
|
||||
shift_prevs = shift_prevs[:, pos_class]
|
||||
for method, shifts, _cy in zip(columns, data, cy):
|
||||
ax.plot(
|
||||
shift_prevs,
|
||||
shifts,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
marker="o",
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
)
|
||||
if counts is not None:
|
||||
_col_idx = np.where(columns == method)[0]
|
||||
count = counts[_col_idx].flatten()
|
||||
for prev, shift, cnt in zip(shift_prevs, shifts, count):
|
||||
label = f"{cnt}"
|
||||
plt.annotate(
|
||||
label,
|
||||
(prev, shift),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 10),
|
||||
ha="center",
|
||||
color=_cy["color"],
|
||||
fontsize=12.0,
|
||||
)
|
||||
|
||||
ax.set(xlabel=x_label, ylabel=y_label, title=title)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
return fig
|
|
@ -0,0 +1,144 @@
|
|||
from quacc.plot.base import BasePlot
|
||||
from quacc.plot.mpl import MplPlot
|
||||
from quacc.plot.plotly import PlotlyPlot
|
||||
|
||||
__backend: BasePlot = MplPlot()
|
||||
|
||||
|
||||
def get_backend(be, theme=None):
|
||||
match be:
|
||||
case "matplotlib" | "mpl":
|
||||
return MplPlot()
|
||||
case "plotly":
|
||||
return PlotlyPlot(theme=theme)
|
||||
case _:
|
||||
return MplPlot()
|
||||
|
||||
|
||||
def plot_delta(
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
avg=None,
|
||||
save_fig=False,
|
||||
base_path=None,
|
||||
backend=None,
|
||||
):
|
||||
backend = __backend if backend is None else backend
|
||||
_base_title = "delta_stdev" if stdevs is not None else "delta"
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
|
||||
|
||||
x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
|
||||
y_label = f"{metric} error"
|
||||
fig = backend.plot_delta(
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
stdevs=stdevs,
|
||||
pos_class=pos_class,
|
||||
title=title,
|
||||
x_label=x_label,
|
||||
y_label=y_label,
|
||||
legend=legend,
|
||||
)
|
||||
|
||||
if save_fig:
|
||||
output_path = backend.save_fig(fig, base_path, title)
|
||||
return fig, output_path
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_diagonal(
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
save_fig=False,
|
||||
base_path=None,
|
||||
backend=None,
|
||||
):
|
||||
backend = __backend if backend is None else backend
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"diagonal_{name}_{metric}"
|
||||
|
||||
x_label = f"true {metric}"
|
||||
y_label = f"estim. {metric}"
|
||||
fig = backend.plot_diagonal(
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
pos_class=pos_class,
|
||||
title=title,
|
||||
x_label=x_label,
|
||||
y_label=y_label,
|
||||
legend=legend,
|
||||
)
|
||||
|
||||
if save_fig:
|
||||
output_path = backend.save_fig(fig, base_path, title)
|
||||
return fig, output_path
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_shift(
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
save_fig=False,
|
||||
base_path=None,
|
||||
backend=None,
|
||||
):
|
||||
backend = __backend if backend is None else backend
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"shift_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"shift_{name}_avg_{metric}"
|
||||
|
||||
x_label = "dataset shift"
|
||||
y_label = f"{metric} error"
|
||||
fig = backend.plot_shift(
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
counts=counts,
|
||||
pos_class=pos_class,
|
||||
title=title,
|
||||
x_label=x_label,
|
||||
y_label=y_label,
|
||||
legend=legend,
|
||||
)
|
||||
|
||||
if save_fig:
|
||||
output_path = backend.save_fig(fig, base_path, title)
|
||||
return fig, output_path
|
||||
|
||||
return fig
|
|
@ -0,0 +1,201 @@
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import plotly
|
||||
import plotly.graph_objects as go
|
||||
|
||||
from quacc.plot.base import BasePlot
|
||||
|
||||
|
||||
class PlotlyPlot(BasePlot):
|
||||
__themes = defaultdict(
|
||||
lambda: {
|
||||
"template": "seaborn",
|
||||
}
|
||||
)
|
||||
__themes = __themes | {
|
||||
"dark": {
|
||||
"template": "plotly_dark",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, theme=None):
|
||||
self.theme = PlotlyPlot.__themes[theme]
|
||||
|
||||
def hex_to_rgb(self, hex: str, t: float | None = None):
|
||||
hex = hex.lstrip("#")
|
||||
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
|
||||
if t is not None:
|
||||
rgb.append(t)
|
||||
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
|
||||
|
||||
def get_colors(self, num):
|
||||
match num:
|
||||
case v if v > 10:
|
||||
__colors = plotly.colors.qualitative.Light24
|
||||
case _:
|
||||
__colors = plotly.colors.qualitative.Plotly
|
||||
|
||||
def __generator(cs):
|
||||
while True:
|
||||
for c in cs:
|
||||
yield c
|
||||
|
||||
return __generator(__colors)
|
||||
|
||||
def update_layout(self, fig, title, x_label, y_label):
|
||||
fig.update_layout(
|
||||
title=title,
|
||||
xaxis_title=x_label,
|
||||
yaxis_title=y_label,
|
||||
template=self.theme["template"],
|
||||
)
|
||||
|
||||
def save_fig(self, fig, base_path, title) -> Path:
|
||||
return None
|
||||
|
||||
def plot_delta(
|
||||
self,
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="prevs.",
|
||||
y_label="error",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = base_prevs[:, pos_class]
|
||||
line_colors = self.get_colors(len(columns))
|
||||
for name, delta in zip(columns, data):
|
||||
color = next(line_colors)
|
||||
_line = [
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=delta,
|
||||
mode="lines+markers",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color)),
|
||||
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
|
||||
)
|
||||
]
|
||||
_error = []
|
||||
if stdevs is not None:
|
||||
_col_idx = np.where(columns == name)[0]
|
||||
stdev = stdevs[_col_idx].flatten()
|
||||
_error = [
|
||||
go.Scatter(
|
||||
x=np.concatenate([x, x[::-1]]),
|
||||
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
|
||||
name=int(_col_idx[0]),
|
||||
fill="toself",
|
||||
fillcolor=self.hex_to_rgb(color, t=0.2),
|
||||
line=dict(color="rgba(255, 255, 255, 0)"),
|
||||
hoverinfo="skip",
|
||||
showlegend=False,
|
||||
)
|
||||
]
|
||||
fig.add_traces(_line + _error)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
return fig
|
||||
|
||||
def plot_diagonal(
|
||||
self,
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = reference
|
||||
line_colors = self.get_colors(len(columns))
|
||||
|
||||
_edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
|
||||
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
|
||||
|
||||
for name, val in zip(columns, data):
|
||||
color = next(line_colors)
|
||||
slope, interc = np.polyfit(x, val, 1)
|
||||
y_lr = np.array([slope * _x + interc for _x in _lims[0]])
|
||||
fig.add_traces(
|
||||
[
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=val,
|
||||
customdata=np.stack((val - x,), axis=-1),
|
||||
mode="markers",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color, t=0.5)),
|
||||
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
|
||||
),
|
||||
go.Scatter(
|
||||
x=_lims[0],
|
||||
y=y_lr,
|
||||
mode="lines",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color), width=3),
|
||||
showlegend=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=_lims[0],
|
||||
y=_lims[1],
|
||||
mode="lines",
|
||||
name="reference",
|
||||
showlegend=False,
|
||||
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
|
||||
)
|
||||
)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
|
||||
return fig
|
||||
|
||||
def plot_shift(
|
||||
self,
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = shift_prevs[:, pos_class]
|
||||
line_colors = self.get_colors(len(columns))
|
||||
for name, delta in zip(columns, data):
|
||||
col_idx = (columns == name).nonzero()[0][0]
|
||||
color = next(line_colors)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=delta,
|
||||
customdata=np.stack((counts[col_idx],), axis=-1),
|
||||
mode="lines+markers",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color)),
|
||||
hovertemplate="shift: %{x}<br>error: %{y}"
|
||||
+ "<br>count: %{customdata[0]}"
|
||||
if counts is not None
|
||||
else "",
|
||||
)
|
||||
)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
return fig
|
Loading…
Reference in New Issue