Merge branch 'dash'
This commit is contained in:
commit
44d820d4ab
|
@ -6,11 +6,13 @@ quavenv/*
|
||||||
__pycache__/*
|
__pycache__/*
|
||||||
baselines/__pycache__/*
|
baselines/__pycache__/*
|
||||||
baselines/densratio/__pycache__/*
|
baselines/densratio/__pycache__/*
|
||||||
|
qcdash/__pycache__/*
|
||||||
qcpanel/__pycache__/*
|
qcpanel/__pycache__/*
|
||||||
quacc/__pycache__/*
|
quacc/__pycache__/*
|
||||||
quacc/evaluation/__pycache__/*
|
quacc/evaluation/__pycache__/*
|
||||||
quacc/method/__pycache__/*
|
quacc/method/__pycache__/*
|
||||||
quacc/quantification/__pycache__/*
|
quacc/quantification/__pycache__/*
|
||||||
|
quacc/plot/__pycache__/*
|
||||||
tests/__pycache__/*
|
tests/__pycache__/*
|
||||||
tests/*/__pycache__/*
|
tests/*/__pycache__/*
|
||||||
tests/*/*/__pycache__/*
|
tests/*/*/__pycache__/*
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -18,6 +18,7 @@ abstention = "^0.1.3.1"
|
||||||
main = "quacc.main:main"
|
main = "quacc.main:main"
|
||||||
run = "run:run"
|
run = "run:run"
|
||||||
panel = "qcpanel.run:run"
|
panel = "qcpanel.run:run"
|
||||||
|
dash = "qcdash.app:run"
|
||||||
sync_up = "remote:sync_code"
|
sync_up = "remote:sync_code"
|
||||||
sync_down = "remote:sync_output"
|
sync_down = "remote:sync_output"
|
||||||
merge_data = "merge_data:run"
|
merge_data = "merge_data:run"
|
||||||
|
@ -27,6 +28,7 @@ poetry_command = ""
|
||||||
|
|
||||||
[tool.poe.tasks]
|
[tool.poe.tasks]
|
||||||
ilona = "ssh volpi@ilona.isti.cnr.it"
|
ilona = "ssh volpi@ilona.isti.cnr.it"
|
||||||
|
dash = "gunicorn qcdash.app:server -b ilona.isti.cnr.it:33421"
|
||||||
|
|
||||||
[tool.poe.tasks.logr]
|
[tool.poe.tasks.logr]
|
||||||
shell = """
|
shell = """
|
||||||
|
@ -48,6 +50,9 @@ ipympl = "^0.9.3"
|
||||||
ipykernel = "^6.26.0"
|
ipykernel = "^6.26.0"
|
||||||
ipywidgets-bokeh = "^1.5.0"
|
ipywidgets-bokeh = "^1.5.0"
|
||||||
pandas-stubs = "^2.1.1.230928"
|
pandas-stubs = "^2.1.1.230928"
|
||||||
|
dash = "^2.14.1"
|
||||||
|
dash-bootstrap-components = "^1.5.0"
|
||||||
|
gunicorn = "^21.2.0"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--cov=quacc --capture=tee-sys"
|
addopts = "--cov=quacc --capture=tee-sys"
|
||||||
|
|
|
@ -0,0 +1,413 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
from urllib.parse import parse_qsl, quote, urlencode, urlparse
|
||||||
|
|
||||||
|
import dash_bootstrap_components as dbc
|
||||||
|
import numpy as np
|
||||||
|
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
|
||||||
|
from dash.dash_table.Format import Format, Scheme
|
||||||
|
|
||||||
|
from quacc import plot
|
||||||
|
from quacc.evaluation.estimators import CE
|
||||||
|
from quacc.evaluation.report import CompReport, DatasetReport
|
||||||
|
from quacc.evaluation.stats import ttest_rel
|
||||||
|
|
||||||
|
backend = plot.get_backend("plotly")
|
||||||
|
|
||||||
|
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
|
||||||
|
valid_plot_modes["avg"] = DatasetReport._default_dr_modes
|
||||||
|
|
||||||
|
|
||||||
|
def get_datasets(root: str | Path) -> List[DatasetReport]:
|
||||||
|
def load_dataset(dataset):
|
||||||
|
dataset = Path(dataset)
|
||||||
|
return DatasetReport.unpickle(dataset)
|
||||||
|
|
||||||
|
def explore_datasets(root: str | Path) -> List[Path]:
|
||||||
|
if isinstance(root, str):
|
||||||
|
root = Path(root)
|
||||||
|
|
||||||
|
if root.name == "plot":
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not root.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
dr_paths = []
|
||||||
|
for f in os.listdir(root):
|
||||||
|
if (root / f).is_dir():
|
||||||
|
dr_paths += explore_datasets(root / f)
|
||||||
|
elif f == f"{root.name}.pickle":
|
||||||
|
dr_paths.append(root / f)
|
||||||
|
|
||||||
|
return dr_paths
|
||||||
|
|
||||||
|
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
|
||||||
|
return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
|
||||||
|
|
||||||
|
|
||||||
|
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
|
||||||
|
estimators = CE.name[estimators]
|
||||||
|
match (view, mode):
|
||||||
|
case ("avg", _):
|
||||||
|
return dr.get_plots(
|
||||||
|
mode=mode,
|
||||||
|
metric=metric,
|
||||||
|
estimators=estimators,
|
||||||
|
conf="plotly",
|
||||||
|
save_fig=False,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
case (_, _):
|
||||||
|
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
|
||||||
|
return cr.get_plots(
|
||||||
|
mode=mode,
|
||||||
|
metric=metric,
|
||||||
|
estimators=estimators,
|
||||||
|
conf="plotly",
|
||||||
|
save_fig=False,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_table(dr: DatasetReport, metric, estimators, view, mode):
|
||||||
|
estimators = CE.name[estimators]
|
||||||
|
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||||
|
match (view, mode):
|
||||||
|
case ("avg", "train_table"):
|
||||||
|
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
|
||||||
|
case ("avg", "test_table"):
|
||||||
|
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||||
|
case ("avg", "shift_table"):
|
||||||
|
return (
|
||||||
|
dr.shift_data(metric=metric, estimators=estimators)
|
||||||
|
.groupby(level=0)
|
||||||
|
.mean()
|
||||||
|
)
|
||||||
|
case ("avg", "stats_table"):
|
||||||
|
return ttest_rel(dr, metric=metric, estimators=estimators)
|
||||||
|
case (_, "train_table"):
|
||||||
|
cr = dr.crs[_prevs.index(view)]
|
||||||
|
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
||||||
|
case (_, "shift_table"):
|
||||||
|
cr = dr.crs[_prevs.index(view)]
|
||||||
|
return (
|
||||||
|
cr.shift_data(metric=metric, estimators=estimators)
|
||||||
|
.groupby(level=0)
|
||||||
|
.mean()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_DataTable(df):
|
||||||
|
_primary = "#0d6efd"
|
||||||
|
if df.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
df = df.reset_index()
|
||||||
|
columns = {
|
||||||
|
c: dict(
|
||||||
|
id=c,
|
||||||
|
name=c,
|
||||||
|
type="numeric",
|
||||||
|
format=Format(precision=6, scheme=Scheme.exponent),
|
||||||
|
)
|
||||||
|
for c in df.columns
|
||||||
|
}
|
||||||
|
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
|
||||||
|
columns = list(columns.values())
|
||||||
|
data = df.to_dict("records")
|
||||||
|
|
||||||
|
return html.Div(
|
||||||
|
[
|
||||||
|
dash_table.DataTable(
|
||||||
|
data=data,
|
||||||
|
columns=columns,
|
||||||
|
id="table1",
|
||||||
|
style_cell={
|
||||||
|
"padding": "0 12px",
|
||||||
|
"border": "0",
|
||||||
|
"border-bottom": f"1px solid {_primary}",
|
||||||
|
},
|
||||||
|
style_table={
|
||||||
|
"margin": "6vh 15px",
|
||||||
|
"padding": "15px",
|
||||||
|
"maxWidth": "80vw",
|
||||||
|
"overflowX": "auto",
|
||||||
|
"border": f"0px solid {_primary}",
|
||||||
|
"border-radius": "6px",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
style={
|
||||||
|
"display": "flex",
|
||||||
|
"flex-direction": "column",
|
||||||
|
# "justify-content": "center",
|
||||||
|
"align-items": "center",
|
||||||
|
"height": "100vh",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_Graph(fig):
|
||||||
|
if fig is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return dcc.Graph(
|
||||||
|
id="graph1",
|
||||||
|
figure=fig,
|
||||||
|
style={
|
||||||
|
"margin": 0,
|
||||||
|
"height": "100vh",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
datasets = get_datasets("output")
|
||||||
|
|
||||||
|
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
||||||
|
# app.config.suppress_callback_exceptions = True
|
||||||
|
sidebar_style = {
|
||||||
|
"top": 0,
|
||||||
|
"left": 0,
|
||||||
|
"bottom": 0,
|
||||||
|
"padding": "1vw",
|
||||||
|
"padding-top": "2vw",
|
||||||
|
"margin": "0px",
|
||||||
|
"flex": 1,
|
||||||
|
"overflow": "scroll",
|
||||||
|
"height": "100vh",
|
||||||
|
}
|
||||||
|
|
||||||
|
content_style = {
|
||||||
|
"flex": 5,
|
||||||
|
"maxWidth": "84vw",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_href(href: str):
|
||||||
|
parse_result = urlparse(href)
|
||||||
|
params = parse_qsl(parse_result.query)
|
||||||
|
return dict(params)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sidebar():
|
||||||
|
return [
|
||||||
|
html.H4("Parameters:", style={"margin-bottom": "1vw"}),
|
||||||
|
dbc.Select(
|
||||||
|
# options=list(datasets.keys()),
|
||||||
|
# value=list(datasets.keys())[0],
|
||||||
|
id="dataset",
|
||||||
|
),
|
||||||
|
dbc.Select(
|
||||||
|
id="metric",
|
||||||
|
style={"margin-top": "1vh"},
|
||||||
|
),
|
||||||
|
html.Div(
|
||||||
|
[
|
||||||
|
dbc.RadioItems(
|
||||||
|
id="view",
|
||||||
|
class_name="btn-group mt-3",
|
||||||
|
input_class_name="btn-check",
|
||||||
|
label_class_name="btn btn-outline-primary",
|
||||||
|
label_checked_class_name="active",
|
||||||
|
),
|
||||||
|
dbc.RadioItems(
|
||||||
|
id="mode",
|
||||||
|
class_name="btn-group mt-3",
|
||||||
|
input_class_name="btn-check",
|
||||||
|
label_class_name="btn btn-outline-primary",
|
||||||
|
label_checked_class_name="active",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
className="radio-group-v d-flex justify-content-around",
|
||||||
|
),
|
||||||
|
html.Div(
|
||||||
|
[
|
||||||
|
dbc.Checklist(
|
||||||
|
id="estimators",
|
||||||
|
className="btn-group mt-3",
|
||||||
|
inputClassName="btn-check",
|
||||||
|
labelClassName="btn btn-outline-primary",
|
||||||
|
labelCheckedClassName="active",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
className="radio-group-wide",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
app.layout = html.Div(
|
||||||
|
[
|
||||||
|
dcc.Interval(id="reload", interval=10 * 60 * 1000),
|
||||||
|
dcc.Location(id="url", refresh=False),
|
||||||
|
html.Div(
|
||||||
|
[
|
||||||
|
html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
|
||||||
|
html.Div(id="app_content", style=content_style),
|
||||||
|
],
|
||||||
|
id="page_layout",
|
||||||
|
style={"display": "flex", "flexDirection": "row"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
server = app.server
|
||||||
|
|
||||||
|
|
||||||
|
def apply_param(href, triggered_id, id, curr):
|
||||||
|
match triggered_id:
|
||||||
|
case "url":
|
||||||
|
params = parse_href(href)
|
||||||
|
return params.get(id, None)
|
||||||
|
case _:
|
||||||
|
return curr
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("dataset", "value"),
|
||||||
|
Output("dataset", "options"),
|
||||||
|
Input("url", "href"),
|
||||||
|
Input("reload", "n_intervals"),
|
||||||
|
State("dataset", "value"),
|
||||||
|
)
|
||||||
|
def update_dataset(href, n_intervals, dataset):
|
||||||
|
match ctx.triggered_id:
|
||||||
|
case "reload":
|
||||||
|
new_datasets = get_datasets("output")
|
||||||
|
global datasets
|
||||||
|
datasets = new_datasets
|
||||||
|
req_dataset = dataset
|
||||||
|
case "url":
|
||||||
|
params = parse_href(href)
|
||||||
|
req_dataset = params.get("dataset", None)
|
||||||
|
|
||||||
|
available_datasets = list(datasets.keys())
|
||||||
|
new_dataset = (
|
||||||
|
req_dataset if req_dataset in available_datasets else available_datasets[0]
|
||||||
|
)
|
||||||
|
return new_dataset, available_datasets
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("metric", "options"),
|
||||||
|
Output("metric", "value"),
|
||||||
|
Input("url", "href"),
|
||||||
|
Input("dataset", "value"),
|
||||||
|
State("metric", "value"),
|
||||||
|
)
|
||||||
|
def update_metrics(href, dataset, curr_metric):
|
||||||
|
dr = datasets[dataset]
|
||||||
|
old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
|
||||||
|
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
|
||||||
|
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
|
||||||
|
return valid_metrics, new_metric
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("estimators", "options"),
|
||||||
|
Output("estimators", "value"),
|
||||||
|
Input("url", "href"),
|
||||||
|
Input("dataset", "value"),
|
||||||
|
Input("metric", "value"),
|
||||||
|
State("estimators", "value"),
|
||||||
|
)
|
||||||
|
def update_estimators(href, dataset, metric, curr_estimators):
|
||||||
|
dr = datasets[dataset]
|
||||||
|
old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
|
||||||
|
if isinstance(old_estimators, str):
|
||||||
|
try:
|
||||||
|
old_estimators = json.loads(old_estimators)
|
||||||
|
except JSONDecodeError:
|
||||||
|
old_estimators = []
|
||||||
|
valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
|
||||||
|
new_estimators = valid_estimators[
|
||||||
|
np.isin(valid_estimators, old_estimators)
|
||||||
|
].tolist()
|
||||||
|
return valid_estimators, new_estimators
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("view", "options"),
|
||||||
|
Output("view", "value"),
|
||||||
|
Input("url", "href"),
|
||||||
|
Input("dataset", "value"),
|
||||||
|
State("view", "value"),
|
||||||
|
)
|
||||||
|
def update_view(href, dataset, curr_view):
|
||||||
|
dr = datasets[dataset]
|
||||||
|
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
|
||||||
|
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
|
||||||
|
new_view = old_view if old_view in valid_views else valid_views[0]
|
||||||
|
return valid_views, new_view
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("mode", "options"),
|
||||||
|
Output("mode", "value"),
|
||||||
|
Input("url", "href"),
|
||||||
|
Input("view", "value"),
|
||||||
|
State("mode", "value"),
|
||||||
|
)
|
||||||
|
def update_mode(href, view, curr_mode):
|
||||||
|
old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
|
||||||
|
valid_modes = valid_plot_modes[view]
|
||||||
|
new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
|
||||||
|
return valid_modes, new_mode
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("app_content", "children"),
|
||||||
|
Output("url", "search"),
|
||||||
|
Input("dataset", "value"),
|
||||||
|
Input("metric", "value"),
|
||||||
|
Input("estimators", "value"),
|
||||||
|
Input("view", "value"),
|
||||||
|
Input("mode", "value"),
|
||||||
|
)
|
||||||
|
def update_content(dataset, metric, estimators, view, mode):
|
||||||
|
search = urlencode(
|
||||||
|
dict(
|
||||||
|
dataset=dataset,
|
||||||
|
metric=metric,
|
||||||
|
estimators=json.dumps(estimators),
|
||||||
|
view=view,
|
||||||
|
mode=mode,
|
||||||
|
),
|
||||||
|
quote_via=quote,
|
||||||
|
)
|
||||||
|
dr = datasets[dataset]
|
||||||
|
match mode:
|
||||||
|
case m if m.endswith("table"):
|
||||||
|
df = get_table(
|
||||||
|
dr=dr,
|
||||||
|
metric=metric,
|
||||||
|
estimators=estimators,
|
||||||
|
view=view,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
dt = get_DataTable(df)
|
||||||
|
app_content = [] if dt is None else [dt]
|
||||||
|
case _:
|
||||||
|
fig = get_fig(
|
||||||
|
dr=dr,
|
||||||
|
metric=metric,
|
||||||
|
estimators=estimators,
|
||||||
|
view=view,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
g = get_Graph(fig)
|
||||||
|
app_content = [] if g is None else [g]
|
||||||
|
|
||||||
|
return app_content, f"?{search}"
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
app.run(debug=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
|
@ -0,0 +1,86 @@
|
||||||
|
/* restyle radio items */
|
||||||
|
.radio-group .form-check {
|
||||||
|
padding-left: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group .btn-group > .form-check:not(:last-child) > .btn {
|
||||||
|
border-top-right-radius: 0;
|
||||||
|
border-bottom-right-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group .btn-group > .form-check:not(:first-child) > .btn {
|
||||||
|
border-top-left-radius: 0;
|
||||||
|
border-bottom-left-radius: 0;
|
||||||
|
margin-left: -1px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-v{
|
||||||
|
padding: 0 10px;
|
||||||
|
}
|
||||||
|
.radio-group-v .form-check {
|
||||||
|
padding-top: 0;
|
||||||
|
padding-left: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-v .btn-group {
|
||||||
|
flex-direction: column;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: stretch;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-v > .btn-group:first-child {
|
||||||
|
flex-grow: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-v > .btn-group:last-child {
|
||||||
|
margin-left: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-v .btn-group > .form-check:not(:last-child) > .btn {
|
||||||
|
border-bottom-right-radius: 0;
|
||||||
|
border-bottom-left-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-v .btn-group > .form-check:not(:first-child) > .btn {
|
||||||
|
border-top-right-radius: 0;
|
||||||
|
border-top-left-radius: 0;
|
||||||
|
margin-top: -3px;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
.radio-group-v .btn-group .btn{
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-wide .form-check {
|
||||||
|
padding-left: 0px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-wide .btn-group{
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-wide .btn-group .form-check{
|
||||||
|
flex: 1;
|
||||||
|
margin-top: -3px;
|
||||||
|
margin-left: -1px;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
.radio-group-wide .btn-group .form-check .btn{
|
||||||
|
width: 100%;
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-wide .btn-group .form-check:first-child > .btn{
|
||||||
|
border-top-left-radius: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-group-wide .btn-group .form-check:last-child > .btn{
|
||||||
|
border-bottom-right-radius: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
div#app-sidebar{
|
||||||
|
border-right: 2px solid var(--primary);
|
||||||
|
}
|
|
@ -1,290 +0,0 @@
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import panel as pn
|
|
||||||
import param
|
|
||||||
|
|
||||||
from quacc.evaluation.estimators import CE
|
|
||||||
from quacc.evaluation.report import DatasetReport
|
|
||||||
|
|
||||||
pn.extension(design="bootstrap")
|
|
||||||
|
|
||||||
|
|
||||||
def create_cr_plots(
|
|
||||||
dr: DatasetReport,
|
|
||||||
mode="delta",
|
|
||||||
metric="acc",
|
|
||||||
estimators=None,
|
|
||||||
prev=None,
|
|
||||||
):
|
|
||||||
idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev)
|
|
||||||
cr = dr.crs[idx]
|
|
||||||
estimators = CE.name[estimators]
|
|
||||||
_dpi = 112
|
|
||||||
return pn.pane.Matplotlib(
|
|
||||||
cr.get_plots(
|
|
||||||
mode=mode,
|
|
||||||
metric=metric,
|
|
||||||
estimators=estimators,
|
|
||||||
conf="panel",
|
|
||||||
return_fig=True,
|
|
||||||
),
|
|
||||||
tight=True,
|
|
||||||
format="png",
|
|
||||||
sizing_mode="scale_height",
|
|
||||||
# sizing_mode="scale_both",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_avg_plots(
|
|
||||||
dr: DatasetReport,
|
|
||||||
mode="delta",
|
|
||||||
metric="acc",
|
|
||||||
estimators=None,
|
|
||||||
prev=None,
|
|
||||||
):
|
|
||||||
estimators = CE.name[estimators]
|
|
||||||
return pn.pane.Matplotlib(
|
|
||||||
dr.get_plots(
|
|
||||||
mode=mode,
|
|
||||||
metric=metric,
|
|
||||||
estimators=estimators,
|
|
||||||
conf="panel",
|
|
||||||
return_fig=True,
|
|
||||||
),
|
|
||||||
tight=True,
|
|
||||||
format="png",
|
|
||||||
sizing_mode="scale_height",
|
|
||||||
# sizing_mode="scale_both",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_cr_tab(dr: DatasetReport):
|
|
||||||
_data = dr.data()
|
|
||||||
_metrics = _data.columns.unique(0)
|
|
||||||
_estimators = _data.columns.unique(1)
|
|
||||||
|
|
||||||
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
|
||||||
metric_widget = pn.widgets.Select(
|
|
||||||
name="metric",
|
|
||||||
value="acc",
|
|
||||||
options=valid_metrics,
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_estimators = [e for e in _estimators if e != "ref"]
|
|
||||||
estimators_widget = pn.widgets.CheckButtonGroup(
|
|
||||||
name="estimators",
|
|
||||||
options=valid_estimators,
|
|
||||||
value=valid_estimators,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"]
|
|
||||||
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="mode",
|
|
||||||
value=valid_plot_modes[0],
|
|
||||||
options=valid_plot_modes,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
|
||||||
prevs_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="train prevalence",
|
|
||||||
value=valid_prevs[0],
|
|
||||||
options=valid_prevs,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
)
|
|
||||||
|
|
||||||
plot_pane = pn.bind(
|
|
||||||
create_cr_plots,
|
|
||||||
dr=dr,
|
|
||||||
mode=plot_mode_widget,
|
|
||||||
metric=metric_widget,
|
|
||||||
estimators=estimators_widget,
|
|
||||||
prev=prevs_widget,
|
|
||||||
)
|
|
||||||
|
|
||||||
return pn.Row(
|
|
||||||
pn.Spacer(width=20),
|
|
||||||
pn.Column(
|
|
||||||
metric_widget,
|
|
||||||
pn.Row(
|
|
||||||
prevs_widget,
|
|
||||||
plot_mode_widget,
|
|
||||||
),
|
|
||||||
estimators_widget,
|
|
||||||
align="center",
|
|
||||||
),
|
|
||||||
pn.Spacer(sizing_mode="scale_width"),
|
|
||||||
plot_pane,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_avg_tab(dr: DatasetReport):
|
|
||||||
_data = dr.data()
|
|
||||||
_metrics = _data.columns.unique(0)
|
|
||||||
_estimators = _data.columns.unique(1)
|
|
||||||
|
|
||||||
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
|
|
||||||
metric_widget = pn.widgets.Select(
|
|
||||||
name="metric",
|
|
||||||
value="acc",
|
|
||||||
options=valid_metrics,
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_estimators = [e for e in _estimators if e != "ref"]
|
|
||||||
estimators_widget = pn.widgets.CheckButtonGroup(
|
|
||||||
name="estimators",
|
|
||||||
options=valid_estimators,
|
|
||||||
value=valid_estimators,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_plot_modes = [
|
|
||||||
"delta_train",
|
|
||||||
"stdev_train",
|
|
||||||
"delta_test",
|
|
||||||
"stdev_test",
|
|
||||||
"shift",
|
|
||||||
]
|
|
||||||
plot_mode_widget = pn.widgets.RadioButtonGroup(
|
|
||||||
name="mode",
|
|
||||||
value=valid_plot_modes[0],
|
|
||||||
options=valid_plot_modes,
|
|
||||||
button_style="outline",
|
|
||||||
button_type="primary",
|
|
||||||
align="center",
|
|
||||||
orientation="vertical",
|
|
||||||
sizing_mode="scale_width",
|
|
||||||
)
|
|
||||||
|
|
||||||
plot_pane = pn.bind(
|
|
||||||
create_avg_plots,
|
|
||||||
dr=dr,
|
|
||||||
mode=plot_mode_widget,
|
|
||||||
metric=metric_widget,
|
|
||||||
estimators=estimators_widget,
|
|
||||||
)
|
|
||||||
|
|
||||||
return pn.Row(
|
|
||||||
pn.Spacer(width=20),
|
|
||||||
pn.Column(
|
|
||||||
metric_widget,
|
|
||||||
plot_mode_widget,
|
|
||||||
estimators_widget,
|
|
||||||
align="center",
|
|
||||||
),
|
|
||||||
pn.Spacer(sizing_mode="scale_width"),
|
|
||||||
plot_pane,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(dataset_path: Path):
|
|
||||||
dr: DatasetReport = DatasetReport.unpickle(dataset_path)
|
|
||||||
|
|
||||||
prevs_tab = ("train prevs.", build_cr_tab(dr))
|
|
||||||
avg_tab = ("avg", build_avg_tab(dr))
|
|
||||||
|
|
||||||
app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False)
|
|
||||||
app.servable()
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
def explore_datasets(root: Path | str):
|
|
||||||
if isinstance(root, str):
|
|
||||||
root = Path(root)
|
|
||||||
|
|
||||||
if root.name == "plot":
|
|
||||||
return []
|
|
||||||
|
|
||||||
if not root.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
drs = []
|
|
||||||
for f in os.listdir(root):
|
|
||||||
if (root / f).is_dir():
|
|
||||||
drs += explore_datasets(root / f)
|
|
||||||
elif f == f"{root.name}.pickle":
|
|
||||||
drs.append((root, build_dataset(root / f)))
|
|
||||||
# drs.append((str(root),))
|
|
||||||
|
|
||||||
return drs
|
|
||||||
|
|
||||||
|
|
||||||
class PlotSelector(param.Parameterized):
|
|
||||||
metric = param.Selector(objects=["acc", "f1"])
|
|
||||||
view = param.Selector(objects=["train prevs", "avg"])
|
|
||||||
|
|
||||||
|
|
||||||
def plot_selector_widget():
|
|
||||||
return pn.Param(
|
|
||||||
PlotSelector.param,
|
|
||||||
widgets={
|
|
||||||
"metric": pn.widgets.Select,
|
|
||||||
"view": pn.widgets.Select,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def serve(address="localhost"):
|
|
||||||
# app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle"))
|
|
||||||
__base_path = "output"
|
|
||||||
__tabs = sorted(
|
|
||||||
explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0])
|
|
||||||
)
|
|
||||||
__tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs]
|
|
||||||
if len(__tabs) > 0:
|
|
||||||
app = pn.Tabs(
|
|
||||||
objects=__tabs,
|
|
||||||
tabs_location="left",
|
|
||||||
dynamic=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
app = pn.Column(
|
|
||||||
pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}),
|
|
||||||
align="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
__port = 33420
|
|
||||||
__allowed = [address]
|
|
||||||
if address == "localhost":
|
|
||||||
__allowed.append("127.0.0.1")
|
|
||||||
|
|
||||||
pn.serve(
|
|
||||||
app,
|
|
||||||
autoreload=True,
|
|
||||||
port=__port,
|
|
||||||
show=False,
|
|
||||||
address=address,
|
|
||||||
websocket_origin=[f"{_a}:{__port}" for _a in __allowed],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--address",
|
|
||||||
action="store",
|
|
||||||
dest="address",
|
|
||||||
default="localhost",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
serve(address=args.address)
|
|
|
@ -1,10 +1,11 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import panel as pn
|
import panel as pn
|
||||||
|
from panel.theme.fast import FastDarkTheme, FastDefaultTheme
|
||||||
|
|
||||||
from qcpanel.viewer import QuaccTestViewer
|
from qcpanel.viewer import QuaccTestViewer
|
||||||
|
|
||||||
# pn.config.design = pn.theme.Bootstrap
|
# pn.config.design = Fast
|
||||||
# pn.config.theme = "dark"
|
# pn.config.theme = "dark"
|
||||||
pn.config.notifications = True
|
pn.config.notifications = True
|
||||||
|
|
||||||
|
@ -59,8 +60,8 @@ def app_instance():
|
||||||
],
|
],
|
||||||
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
|
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
|
||||||
modal=[qtv.modal_pane],
|
modal=[qtv.modal_pane],
|
||||||
theme=pn.theme.DarkTheme,
|
# theme=FastDefaultTheme,
|
||||||
theme_toggle=False,
|
theme_toggle=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.servable()
|
app.servable()
|
||||||
|
|
|
@ -52,7 +52,7 @@ def create_plots(
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf="panel",
|
conf="panel",
|
||||||
return_fig=True,
|
save_fig=False,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
pn.pane.Matplotlib(
|
pn.pane.Matplotlib(
|
||||||
|
@ -91,7 +91,7 @@ def create_plots(
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf="panel",
|
conf="panel",
|
||||||
return_fig=True,
|
save_fig=False,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
pn.pane.Matplotlib(
|
pn.pane.Matplotlib(
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import List, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from quacc import plot
|
import quacc.plot as plot
|
||||||
from quacc.utils import fmt_line_md
|
from quacc.utils import fmt_line_md
|
||||||
|
|
||||||
|
|
||||||
|
@ -215,16 +215,17 @@ class CompReport:
|
||||||
|
|
||||||
def get_plots(
|
def get_plots(
|
||||||
self,
|
self,
|
||||||
mode="delta",
|
mode="delta_train",
|
||||||
metric="acc",
|
metric="acc",
|
||||||
estimators=None,
|
estimators=None,
|
||||||
conf="default",
|
conf="default",
|
||||||
return_fig=False,
|
save_fig=True,
|
||||||
base_path=None,
|
base_path=None,
|
||||||
|
backend=None,
|
||||||
) -> List[Tuple[str, Path]]:
|
) -> List[Tuple[str, Path]]:
|
||||||
if mode == "delta_train":
|
if mode == "delta_train":
|
||||||
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
||||||
if avg_data.empty is True:
|
if avg_data.empty:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return plot.plot_delta(
|
return plot.plot_delta(
|
||||||
|
@ -234,8 +235,9 @@ class CompReport:
|
||||||
metric=metric,
|
metric=metric,
|
||||||
name=conf,
|
name=conf,
|
||||||
train_prev=self.train_prev,
|
train_prev=self.train_prev,
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
elif mode == "stdev_train":
|
elif mode == "stdev_train":
|
||||||
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
||||||
|
@ -251,8 +253,9 @@ class CompReport:
|
||||||
name=conf,
|
name=conf,
|
||||||
train_prev=self.train_prev,
|
train_prev=self.train_prev,
|
||||||
stdevs=st_data.T.to_numpy(),
|
stdevs=st_data.T.to_numpy(),
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
elif mode == "diagonal":
|
elif mode == "diagonal":
|
||||||
f_data = self.data(metric=metric + "_score", estimators=estimators)
|
f_data = self.data(metric=metric + "_score", estimators=estimators)
|
||||||
|
@ -268,8 +271,9 @@ class CompReport:
|
||||||
metric=metric,
|
metric=metric,
|
||||||
name=conf,
|
name=conf,
|
||||||
train_prev=self.train_prev,
|
train_prev=self.train_prev,
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
elif mode == "shift":
|
elif mode == "shift":
|
||||||
_shift_data = self.shift_data(metric=metric, estimators=estimators)
|
_shift_data = self.shift_data(metric=metric, estimators=estimators)
|
||||||
|
@ -290,8 +294,9 @@ class CompReport:
|
||||||
name=conf,
|
name=conf,
|
||||||
train_prev=self.train_prev,
|
train_prev=self.train_prev,
|
||||||
counts=shift_counts.T.to_numpy(),
|
counts=shift_counts.T.to_numpy(),
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_md(
|
def to_md(
|
||||||
|
@ -323,11 +328,12 @@ class CompReport:
|
||||||
plot_modes = [m for m in modes if not m.endswith("table")]
|
plot_modes = [m for m in modes if not m.endswith("table")]
|
||||||
for mode in plot_modes:
|
for mode in plot_modes:
|
||||||
res += f"### {mode}\n"
|
res += f"### {mode}\n"
|
||||||
op = self.get_plots(
|
_, op = self.get_plots(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
|
save_fig=True,
|
||||||
base_path=plot_path,
|
base_path=plot_path,
|
||||||
)
|
)
|
||||||
res += f".as_posix()})\n"
|
res += f".as_posix()})\n"
|
||||||
|
@ -424,12 +430,15 @@ class DatasetReport:
|
||||||
metric="acc",
|
metric="acc",
|
||||||
estimators=None,
|
estimators=None,
|
||||||
conf="default",
|
conf="default",
|
||||||
return_fig=False,
|
save_fig=True,
|
||||||
base_path=None,
|
base_path=None,
|
||||||
|
backend=None,
|
||||||
):
|
):
|
||||||
if mode == "delta_train":
|
if mode == "delta_train":
|
||||||
_data = self.data(metric, estimators) if data is None else data
|
_data = self.data(metric, estimators) if data is None else data
|
||||||
avg_on_train = _data.groupby(level=1).mean()
|
avg_on_train = _data.groupby(level=1).mean()
|
||||||
|
if avg_on_train.empty:
|
||||||
|
return None
|
||||||
prevs_on_train = np.sort(avg_on_train.index.unique(0))
|
prevs_on_train = np.sort(avg_on_train.index.unique(0))
|
||||||
return plot.plot_delta(
|
return plot.plot_delta(
|
||||||
base_prevs=np.around(
|
base_prevs=np.around(
|
||||||
|
@ -441,12 +450,15 @@ class DatasetReport:
|
||||||
name=conf,
|
name=conf,
|
||||||
train_prev=None,
|
train_prev=None,
|
||||||
avg="train",
|
avg="train",
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
elif mode == "stdev_train":
|
elif mode == "stdev_train":
|
||||||
_data = self.data(metric, estimators) if data is None else data
|
_data = self.data(metric, estimators) if data is None else data
|
||||||
avg_on_train = _data.groupby(level=1).mean()
|
avg_on_train = _data.groupby(level=1).mean()
|
||||||
|
if avg_on_train.empty:
|
||||||
|
return None
|
||||||
prevs_on_train = np.sort(avg_on_train.index.unique(0))
|
prevs_on_train = np.sort(avg_on_train.index.unique(0))
|
||||||
stdev_on_train = _data.groupby(level=1).std()
|
stdev_on_train = _data.groupby(level=1).std()
|
||||||
return plot.plot_delta(
|
return plot.plot_delta(
|
||||||
|
@ -460,12 +472,15 @@ class DatasetReport:
|
||||||
train_prev=None,
|
train_prev=None,
|
||||||
stdevs=stdev_on_train.T.to_numpy(),
|
stdevs=stdev_on_train.T.to_numpy(),
|
||||||
avg="train",
|
avg="train",
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
elif mode == "delta_test":
|
elif mode == "delta_test":
|
||||||
_data = self.data(metric, estimators) if data is None else data
|
_data = self.data(metric, estimators) if data is None else data
|
||||||
avg_on_test = _data.groupby(level=0).mean()
|
avg_on_test = _data.groupby(level=0).mean()
|
||||||
|
if avg_on_test.empty:
|
||||||
|
return None
|
||||||
prevs_on_test = np.sort(avg_on_test.index.unique(0))
|
prevs_on_test = np.sort(avg_on_test.index.unique(0))
|
||||||
return plot.plot_delta(
|
return plot.plot_delta(
|
||||||
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
|
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
|
||||||
|
@ -475,12 +490,15 @@ class DatasetReport:
|
||||||
name=conf,
|
name=conf,
|
||||||
train_prev=None,
|
train_prev=None,
|
||||||
avg="test",
|
avg="test",
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
elif mode == "stdev_test":
|
elif mode == "stdev_test":
|
||||||
_data = self.data(metric, estimators) if data is None else data
|
_data = self.data(metric, estimators) if data is None else data
|
||||||
avg_on_test = _data.groupby(level=0).mean()
|
avg_on_test = _data.groupby(level=0).mean()
|
||||||
|
if avg_on_test.empty:
|
||||||
|
return None
|
||||||
prevs_on_test = np.sort(avg_on_test.index.unique(0))
|
prevs_on_test = np.sort(avg_on_test.index.unique(0))
|
||||||
stdev_on_test = _data.groupby(level=0).std()
|
stdev_on_test = _data.groupby(level=0).std()
|
||||||
return plot.plot_delta(
|
return plot.plot_delta(
|
||||||
|
@ -492,12 +510,15 @@ class DatasetReport:
|
||||||
train_prev=None,
|
train_prev=None,
|
||||||
stdevs=stdev_on_test.T.to_numpy(),
|
stdevs=stdev_on_test.T.to_numpy(),
|
||||||
avg="test",
|
avg="test",
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
elif mode == "shift":
|
elif mode == "shift":
|
||||||
_shift_data = self.shift_data(metric, estimators) if data is None else data
|
_shift_data = self.shift_data(metric, estimators) if data is None else data
|
||||||
avg_shift = _shift_data.groupby(level=0).mean()
|
avg_shift = _shift_data.groupby(level=0).mean()
|
||||||
|
if avg_shift.empty:
|
||||||
|
return None
|
||||||
count_shift = _shift_data.groupby(level=0).count()
|
count_shift = _shift_data.groupby(level=0).count()
|
||||||
prevs_shift = np.sort(avg_shift.index.unique(0))
|
prevs_shift = np.sort(avg_shift.index.unique(0))
|
||||||
return plot.plot_shift(
|
return plot.plot_shift(
|
||||||
|
@ -508,8 +529,9 @@ class DatasetReport:
|
||||||
name=conf,
|
name=conf,
|
||||||
train_prev=None,
|
train_prev=None,
|
||||||
counts=count_shift.T.to_numpy(),
|
counts=count_shift.T.to_numpy(),
|
||||||
return_fig=return_fig,
|
save_fig=save_fig,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_md(
|
def to_md(
|
||||||
|
@ -545,24 +567,26 @@ class DatasetReport:
|
||||||
res += avg_on_train_tbl.to_html() + "\n\n"
|
res += avg_on_train_tbl.to_html() + "\n\n"
|
||||||
|
|
||||||
if "delta_train" in dr_modes:
|
if "delta_train" in dr_modes:
|
||||||
delta_op = self.get_plots(
|
_, delta_op = self.get_plots(
|
||||||
data=_data,
|
data=_data,
|
||||||
mode="delta_train",
|
mode="delta_train",
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
base_path=plot_path,
|
base_path=plot_path,
|
||||||
|
save_fig=True,
|
||||||
)
|
)
|
||||||
res += f".as_posix()})\n"
|
res += f".as_posix()})\n"
|
||||||
|
|
||||||
if "stdev_train" in dr_modes:
|
if "stdev_train" in dr_modes:
|
||||||
delta_stdev_op = self.get_plots(
|
_, delta_stdev_op = self.get_plots(
|
||||||
data=_data,
|
data=_data,
|
||||||
mode="stdev_train",
|
mode="stdev_train",
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
base_path=plot_path,
|
base_path=plot_path,
|
||||||
|
save_fig=True,
|
||||||
)
|
)
|
||||||
res += f".as_posix()})\n"
|
res += f".as_posix()})\n"
|
||||||
|
|
||||||
|
@ -575,24 +599,26 @@ class DatasetReport:
|
||||||
res += avg_on_test_tbl.to_html() + "\n\n"
|
res += avg_on_test_tbl.to_html() + "\n\n"
|
||||||
|
|
||||||
if "delta_test" in dr_modes:
|
if "delta_test" in dr_modes:
|
||||||
delta_op = self.get_plots(
|
_, delta_op = self.get_plots(
|
||||||
data=_data,
|
data=_data,
|
||||||
mode="delta_test",
|
mode="delta_test",
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
base_path=plot_path,
|
base_path=plot_path,
|
||||||
|
save_fig=True,
|
||||||
)
|
)
|
||||||
res += f".as_posix()})\n"
|
res += f".as_posix()})\n"
|
||||||
|
|
||||||
if "stdev_test" in dr_modes:
|
if "stdev_test" in dr_modes:
|
||||||
delta_stdev_op = self.get_plots(
|
_, delta_stdev_op = self.get_plots(
|
||||||
data=_data,
|
data=_data,
|
||||||
mode="stdev_test",
|
mode="stdev_test",
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
base_path=plot_path,
|
base_path=plot_path,
|
||||||
|
save_fig=True,
|
||||||
)
|
)
|
||||||
res += f".as_posix()})\n"
|
res += f".as_posix()})\n"
|
||||||
|
|
||||||
|
@ -605,13 +631,14 @@ class DatasetReport:
|
||||||
res += shift_on_train_tbl.to_html() + "\n\n"
|
res += shift_on_train_tbl.to_html() + "\n\n"
|
||||||
|
|
||||||
if "shift" in dr_modes:
|
if "shift" in dr_modes:
|
||||||
shift_op = self.get_plots(
|
_, shift_op = self.get_plots(
|
||||||
data=_shift_data,
|
data=_shift_data,
|
||||||
mode="shift",
|
mode="shift",
|
||||||
metric=metric,
|
metric=metric,
|
||||||
estimators=estimators,
|
estimators=estimators,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
base_path=plot_path,
|
base_path=plot_path,
|
||||||
|
save_fig=True,
|
||||||
)
|
)
|
||||||
res += f".as_posix()})\n"
|
res += f".as_posix()})\n"
|
||||||
|
|
||||||
|
|
265
quacc/plot.py
265
quacc/plot.py
|
@ -1,265 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
from cycler import cycler
|
|
||||||
|
|
||||||
from quacc import utils
|
|
||||||
|
|
||||||
matplotlib.use("agg")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_markers(n: int):
|
|
||||||
ls = "ovx+sDph*^1234X><.Pd"
|
|
||||||
if n > len(ls):
|
|
||||||
ls = ls * (n / len(ls) + 1)
|
|
||||||
return list(ls)[:n]
|
|
||||||
|
|
||||||
|
|
||||||
def plot_delta(
|
|
||||||
base_prevs,
|
|
||||||
columns,
|
|
||||||
data,
|
|
||||||
*,
|
|
||||||
stdevs=None,
|
|
||||||
pos_class=1,
|
|
||||||
metric="acc",
|
|
||||||
name="default",
|
|
||||||
train_prev=None,
|
|
||||||
legend=True,
|
|
||||||
avg=None,
|
|
||||||
return_fig=False,
|
|
||||||
base_path=None,
|
|
||||||
) -> Path:
|
|
||||||
_base_title = "delta_stdev" if stdevs is not None else "delta"
|
|
||||||
if train_prev is not None:
|
|
||||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
|
||||||
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
|
|
||||||
else:
|
|
||||||
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
|
|
||||||
|
|
||||||
if base_path is None:
|
|
||||||
base_path = utils.get_quacc_home() / "plots"
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.set_aspect("auto")
|
|
||||||
ax.grid()
|
|
||||||
|
|
||||||
NUM_COLORS = len(data)
|
|
||||||
cm = plt.get_cmap("tab10")
|
|
||||||
if NUM_COLORS > 10:
|
|
||||||
cm = plt.get_cmap("tab20")
|
|
||||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
|
||||||
|
|
||||||
base_prevs = base_prevs[:, pos_class]
|
|
||||||
for method, deltas, _cy in zip(columns, data, cy):
|
|
||||||
ax.plot(
|
|
||||||
base_prevs,
|
|
||||||
deltas,
|
|
||||||
label=method,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="-",
|
|
||||||
marker="o",
|
|
||||||
markersize=3,
|
|
||||||
zorder=2,
|
|
||||||
)
|
|
||||||
if stdevs is not None:
|
|
||||||
_col_idx = np.where(columns == method)[0]
|
|
||||||
stdev = stdevs[_col_idx].flatten()
|
|
||||||
nn_idx = np.intersect1d(
|
|
||||||
np.where(deltas != np.nan)[0],
|
|
||||||
np.where(stdev != np.nan)[0],
|
|
||||||
)
|
|
||||||
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
|
|
||||||
ax.fill_between(
|
|
||||||
_bps,
|
|
||||||
_ds - _st,
|
|
||||||
_ds + _st,
|
|
||||||
color=_cy["color"],
|
|
||||||
alpha=0.25,
|
|
||||||
)
|
|
||||||
|
|
||||||
x_label = "test" if avg is None or avg == "train" else "train"
|
|
||||||
ax.set(
|
|
||||||
xlabel=f"{x_label} prevalence",
|
|
||||||
ylabel=metric,
|
|
||||||
title=title,
|
|
||||||
)
|
|
||||||
|
|
||||||
if legend:
|
|
||||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
|
||||||
|
|
||||||
if return_fig:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
output_path = base_path / f"{title}.png"
|
|
||||||
fig.savefig(output_path, bbox_inches="tight")
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def plot_diagonal(
|
|
||||||
reference,
|
|
||||||
columns,
|
|
||||||
data,
|
|
||||||
*,
|
|
||||||
pos_class=1,
|
|
||||||
metric="acc",
|
|
||||||
name="default",
|
|
||||||
train_prev=None,
|
|
||||||
legend=True,
|
|
||||||
return_fig=False,
|
|
||||||
base_path=None,
|
|
||||||
):
|
|
||||||
if train_prev is not None:
|
|
||||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
|
||||||
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
|
|
||||||
else:
|
|
||||||
title = f"diagonal_{name}_{metric}"
|
|
||||||
|
|
||||||
if base_path is None:
|
|
||||||
base_path = utils.get_quacc_home() / "plots"
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.set_aspect("auto")
|
|
||||||
ax.grid()
|
|
||||||
ax.set_aspect("equal")
|
|
||||||
|
|
||||||
NUM_COLORS = len(data)
|
|
||||||
cm = plt.get_cmap("tab10")
|
|
||||||
if NUM_COLORS > 10:
|
|
||||||
cm = plt.get_cmap("tab20")
|
|
||||||
cy = cycler(
|
|
||||||
color=[cm(i) for i in range(NUM_COLORS)],
|
|
||||||
marker=_get_markers(NUM_COLORS),
|
|
||||||
)
|
|
||||||
|
|
||||||
reference = np.array(reference)
|
|
||||||
x_ticks = np.unique(reference)
|
|
||||||
x_ticks.sort()
|
|
||||||
|
|
||||||
for deltas, _cy in zip(data, cy):
|
|
||||||
ax.plot(
|
|
||||||
reference,
|
|
||||||
deltas,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="None",
|
|
||||||
marker=_cy["marker"],
|
|
||||||
markersize=3,
|
|
||||||
zorder=2,
|
|
||||||
alpha=0.25,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ensure limits are equal for both axes
|
|
||||||
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
|
|
||||||
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
|
|
||||||
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
|
|
||||||
|
|
||||||
for method, deltas, _cy in zip(columns, data, cy):
|
|
||||||
slope, interc = np.polyfit(reference, deltas, 1)
|
|
||||||
y_lr = np.array([slope * x + interc for x in _lims])
|
|
||||||
ax.plot(
|
|
||||||
_lims,
|
|
||||||
y_lr,
|
|
||||||
label=method,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="-",
|
|
||||||
markersize="0",
|
|
||||||
zorder=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# plot reference line
|
|
||||||
ax.plot(
|
|
||||||
_lims,
|
|
||||||
_lims,
|
|
||||||
color="black",
|
|
||||||
linestyle="--",
|
|
||||||
markersize=0,
|
|
||||||
zorder=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
|
|
||||||
|
|
||||||
if legend:
|
|
||||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
|
||||||
|
|
||||||
if return_fig:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
output_path = base_path / f"{title}.png"
|
|
||||||
fig.savefig(output_path, bbox_inches="tight")
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def plot_shift(
|
|
||||||
shift_prevs,
|
|
||||||
columns,
|
|
||||||
data,
|
|
||||||
*,
|
|
||||||
counts=None,
|
|
||||||
pos_class=1,
|
|
||||||
metric="acc",
|
|
||||||
name="default",
|
|
||||||
train_prev=None,
|
|
||||||
legend=True,
|
|
||||||
return_fig=False,
|
|
||||||
base_path=None,
|
|
||||||
) -> Path:
|
|
||||||
if train_prev is not None:
|
|
||||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
|
||||||
title = f"shift_{name}_{t_prev_pos}_{metric}"
|
|
||||||
else:
|
|
||||||
title = f"shift_{name}_avg_{metric}"
|
|
||||||
|
|
||||||
if base_path is None:
|
|
||||||
base_path = utils.get_quacc_home() / "plots"
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.set_aspect("auto")
|
|
||||||
ax.grid()
|
|
||||||
|
|
||||||
NUM_COLORS = len(data)
|
|
||||||
cm = plt.get_cmap("tab10")
|
|
||||||
if NUM_COLORS > 10:
|
|
||||||
cm = plt.get_cmap("tab20")
|
|
||||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
|
||||||
|
|
||||||
shift_prevs = shift_prevs[:, pos_class]
|
|
||||||
for method, shifts, _cy in zip(columns, data, cy):
|
|
||||||
ax.plot(
|
|
||||||
shift_prevs,
|
|
||||||
shifts,
|
|
||||||
label=method,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="-",
|
|
||||||
marker="o",
|
|
||||||
markersize=3,
|
|
||||||
zorder=2,
|
|
||||||
)
|
|
||||||
if counts is not None:
|
|
||||||
_col_idx = np.where(columns == method)[0]
|
|
||||||
count = counts[_col_idx].flatten()
|
|
||||||
for prev, shift, cnt in zip(shift_prevs, shifts, count):
|
|
||||||
label = f"{cnt}"
|
|
||||||
plt.annotate(
|
|
||||||
label,
|
|
||||||
(prev, shift),
|
|
||||||
textcoords="offset points",
|
|
||||||
xytext=(0, 10),
|
|
||||||
ha="center",
|
|
||||||
color=_cy["color"],
|
|
||||||
fontsize=12.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set(xlabel="dataset shift", ylabel=metric, title=title)
|
|
||||||
|
|
||||||
if legend:
|
|
||||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
|
||||||
|
|
||||||
if return_fig:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
output_path = base_path / f"{title}.png"
|
|
||||||
fig.savefig(output_path, bbox_inches="tight")
|
|
||||||
|
|
||||||
return output_path
|
|
|
@ -0,0 +1 @@
|
||||||
|
from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift
|
|
@ -0,0 +1,54 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class BasePlot:
|
||||||
|
@classmethod
|
||||||
|
def save_fig(cls, fig, base_path, title) -> Path:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def plot_diagonal(
|
||||||
|
cls,
|
||||||
|
reference,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def plot_delta(
|
||||||
|
cls,
|
||||||
|
base_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
stdevs=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="prevs.",
|
||||||
|
y_label="error",
|
||||||
|
legend=True,
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def plot_shift(
|
||||||
|
cls,
|
||||||
|
shift_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
counts=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
):
|
||||||
|
...
|
|
@ -0,0 +1,222 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from cycler import cycler
|
||||||
|
|
||||||
|
from quacc import utils
|
||||||
|
from quacc.plot.base import BasePlot
|
||||||
|
|
||||||
|
matplotlib.use("agg")
|
||||||
|
|
||||||
|
|
||||||
|
class MplPlot(BasePlot):
|
||||||
|
def _get_markers(self, n: int):
|
||||||
|
ls = "ovx+sDph*^1234X><.Pd"
|
||||||
|
if n > len(ls):
|
||||||
|
ls = ls * (n / len(ls) + 1)
|
||||||
|
return list(ls)[:n]
|
||||||
|
|
||||||
|
def save_fig(self, fig, base_path, title) -> Path:
|
||||||
|
if base_path is None:
|
||||||
|
base_path = utils.get_quacc_home() / "plots"
|
||||||
|
output_path = base_path / f"{title}.png"
|
||||||
|
fig.savefig(output_path, bbox_inches="tight")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
def plot_delta(
|
||||||
|
self,
|
||||||
|
base_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
stdevs=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="prevs.",
|
||||||
|
y_label="error",
|
||||||
|
legend=True,
|
||||||
|
):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.set_aspect("auto")
|
||||||
|
ax.grid()
|
||||||
|
|
||||||
|
NUM_COLORS = len(data)
|
||||||
|
cm = plt.get_cmap("tab10")
|
||||||
|
if NUM_COLORS > 10:
|
||||||
|
cm = plt.get_cmap("tab20")
|
||||||
|
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||||
|
|
||||||
|
base_prevs = base_prevs[:, pos_class]
|
||||||
|
for method, deltas, _cy in zip(columns, data, cy):
|
||||||
|
ax.plot(
|
||||||
|
base_prevs,
|
||||||
|
deltas,
|
||||||
|
label=method,
|
||||||
|
color=_cy["color"],
|
||||||
|
linestyle="-",
|
||||||
|
marker="o",
|
||||||
|
markersize=3,
|
||||||
|
zorder=2,
|
||||||
|
)
|
||||||
|
if stdevs is not None:
|
||||||
|
_col_idx = np.where(columns == method)[0]
|
||||||
|
stdev = stdevs[_col_idx].flatten()
|
||||||
|
nn_idx = np.intersect1d(
|
||||||
|
np.where(deltas != np.nan)[0],
|
||||||
|
np.where(stdev != np.nan)[0],
|
||||||
|
)
|
||||||
|
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
|
||||||
|
ax.fill_between(
|
||||||
|
_bps,
|
||||||
|
_ds - _st,
|
||||||
|
_ds + _st,
|
||||||
|
color=_cy["color"],
|
||||||
|
alpha=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set(
|
||||||
|
xlabel=f"{x_label} prevalence",
|
||||||
|
ylabel=y_label,
|
||||||
|
title=title,
|
||||||
|
)
|
||||||
|
|
||||||
|
if legend:
|
||||||
|
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_diagonal(
|
||||||
|
self,
|
||||||
|
reference,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.set_aspect("auto")
|
||||||
|
ax.grid()
|
||||||
|
ax.set_aspect("equal")
|
||||||
|
|
||||||
|
NUM_COLORS = len(data)
|
||||||
|
cm = plt.get_cmap("tab10")
|
||||||
|
if NUM_COLORS > 10:
|
||||||
|
cm = plt.get_cmap("tab20")
|
||||||
|
cy = cycler(
|
||||||
|
color=[cm(i) for i in range(NUM_COLORS)],
|
||||||
|
marker=self._get_markers(NUM_COLORS),
|
||||||
|
)
|
||||||
|
|
||||||
|
reference = np.array(reference)
|
||||||
|
x_ticks = np.unique(reference)
|
||||||
|
x_ticks.sort()
|
||||||
|
|
||||||
|
for deltas, _cy in zip(data, cy):
|
||||||
|
ax.plot(
|
||||||
|
reference,
|
||||||
|
deltas,
|
||||||
|
color=_cy["color"],
|
||||||
|
linestyle="None",
|
||||||
|
marker=_cy["marker"],
|
||||||
|
markersize=3,
|
||||||
|
zorder=2,
|
||||||
|
alpha=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure limits are equal for both axes
|
||||||
|
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
|
||||||
|
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
|
||||||
|
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
|
||||||
|
|
||||||
|
for method, deltas, _cy in zip(columns, data, cy):
|
||||||
|
slope, interc = np.polyfit(reference, deltas, 1)
|
||||||
|
y_lr = np.array([slope * x + interc for x in _lims])
|
||||||
|
ax.plot(
|
||||||
|
_lims,
|
||||||
|
y_lr,
|
||||||
|
label=method,
|
||||||
|
color=_cy["color"],
|
||||||
|
linestyle="-",
|
||||||
|
markersize="0",
|
||||||
|
zorder=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# plot reference line
|
||||||
|
ax.plot(
|
||||||
|
_lims,
|
||||||
|
_lims,
|
||||||
|
color="black",
|
||||||
|
linestyle="--",
|
||||||
|
markersize=0,
|
||||||
|
zorder=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set(xlabel=x_label, ylabel=y_label, title=title)
|
||||||
|
|
||||||
|
if legend:
|
||||||
|
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_shift(
|
||||||
|
self,
|
||||||
|
shift_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
counts=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.set_aspect("auto")
|
||||||
|
ax.grid()
|
||||||
|
|
||||||
|
NUM_COLORS = len(data)
|
||||||
|
cm = plt.get_cmap("tab10")
|
||||||
|
if NUM_COLORS > 10:
|
||||||
|
cm = plt.get_cmap("tab20")
|
||||||
|
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||||
|
|
||||||
|
shift_prevs = shift_prevs[:, pos_class]
|
||||||
|
for method, shifts, _cy in zip(columns, data, cy):
|
||||||
|
ax.plot(
|
||||||
|
shift_prevs,
|
||||||
|
shifts,
|
||||||
|
label=method,
|
||||||
|
color=_cy["color"],
|
||||||
|
linestyle="-",
|
||||||
|
marker="o",
|
||||||
|
markersize=3,
|
||||||
|
zorder=2,
|
||||||
|
)
|
||||||
|
if counts is not None:
|
||||||
|
_col_idx = np.where(columns == method)[0]
|
||||||
|
count = counts[_col_idx].flatten()
|
||||||
|
for prev, shift, cnt in zip(shift_prevs, shifts, count):
|
||||||
|
label = f"{cnt}"
|
||||||
|
plt.annotate(
|
||||||
|
label,
|
||||||
|
(prev, shift),
|
||||||
|
textcoords="offset points",
|
||||||
|
xytext=(0, 10),
|
||||||
|
ha="center",
|
||||||
|
color=_cy["color"],
|
||||||
|
fontsize=12.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set(xlabel=x_label, ylabel=y_label, title=title)
|
||||||
|
|
||||||
|
if legend:
|
||||||
|
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||||
|
|
||||||
|
return fig
|
|
@ -0,0 +1,144 @@
|
||||||
|
from quacc.plot.base import BasePlot
|
||||||
|
from quacc.plot.mpl import MplPlot
|
||||||
|
from quacc.plot.plotly import PlotlyPlot
|
||||||
|
|
||||||
|
__backend: BasePlot = MplPlot()
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend(be, theme=None):
|
||||||
|
match be:
|
||||||
|
case "matplotlib" | "mpl":
|
||||||
|
return MplPlot()
|
||||||
|
case "plotly":
|
||||||
|
return PlotlyPlot(theme=theme)
|
||||||
|
case _:
|
||||||
|
return MplPlot()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_delta(
|
||||||
|
base_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
stdevs=None,
|
||||||
|
pos_class=1,
|
||||||
|
metric="acc",
|
||||||
|
name="default",
|
||||||
|
train_prev=None,
|
||||||
|
legend=True,
|
||||||
|
avg=None,
|
||||||
|
save_fig=False,
|
||||||
|
base_path=None,
|
||||||
|
backend=None,
|
||||||
|
):
|
||||||
|
backend = __backend if backend is None else backend
|
||||||
|
_base_title = "delta_stdev" if stdevs is not None else "delta"
|
||||||
|
if train_prev is not None:
|
||||||
|
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||||
|
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
|
||||||
|
else:
|
||||||
|
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
|
||||||
|
|
||||||
|
x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
|
||||||
|
y_label = f"{metric} error"
|
||||||
|
fig = backend.plot_delta(
|
||||||
|
base_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
stdevs=stdevs,
|
||||||
|
pos_class=pos_class,
|
||||||
|
title=title,
|
||||||
|
x_label=x_label,
|
||||||
|
y_label=y_label,
|
||||||
|
legend=legend,
|
||||||
|
)
|
||||||
|
|
||||||
|
if save_fig:
|
||||||
|
output_path = backend.save_fig(fig, base_path, title)
|
||||||
|
return fig, output_path
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_diagonal(
|
||||||
|
reference,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
pos_class=1,
|
||||||
|
metric="acc",
|
||||||
|
name="default",
|
||||||
|
train_prev=None,
|
||||||
|
legend=True,
|
||||||
|
save_fig=False,
|
||||||
|
base_path=None,
|
||||||
|
backend=None,
|
||||||
|
):
|
||||||
|
backend = __backend if backend is None else backend
|
||||||
|
if train_prev is not None:
|
||||||
|
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||||
|
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
|
||||||
|
else:
|
||||||
|
title = f"diagonal_{name}_{metric}"
|
||||||
|
|
||||||
|
x_label = f"true {metric}"
|
||||||
|
y_label = f"estim. {metric}"
|
||||||
|
fig = backend.plot_diagonal(
|
||||||
|
reference,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
pos_class=pos_class,
|
||||||
|
title=title,
|
||||||
|
x_label=x_label,
|
||||||
|
y_label=y_label,
|
||||||
|
legend=legend,
|
||||||
|
)
|
||||||
|
|
||||||
|
if save_fig:
|
||||||
|
output_path = backend.save_fig(fig, base_path, title)
|
||||||
|
return fig, output_path
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_shift(
|
||||||
|
shift_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
counts=None,
|
||||||
|
pos_class=1,
|
||||||
|
metric="acc",
|
||||||
|
name="default",
|
||||||
|
train_prev=None,
|
||||||
|
legend=True,
|
||||||
|
save_fig=False,
|
||||||
|
base_path=None,
|
||||||
|
backend=None,
|
||||||
|
):
|
||||||
|
backend = __backend if backend is None else backend
|
||||||
|
if train_prev is not None:
|
||||||
|
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||||
|
title = f"shift_{name}_{t_prev_pos}_{metric}"
|
||||||
|
else:
|
||||||
|
title = f"shift_{name}_avg_{metric}"
|
||||||
|
|
||||||
|
x_label = "dataset shift"
|
||||||
|
y_label = f"{metric} error"
|
||||||
|
fig = backend.plot_shift(
|
||||||
|
shift_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
counts=counts,
|
||||||
|
pos_class=pos_class,
|
||||||
|
title=title,
|
||||||
|
x_label=x_label,
|
||||||
|
y_label=y_label,
|
||||||
|
legend=legend,
|
||||||
|
)
|
||||||
|
|
||||||
|
if save_fig:
|
||||||
|
output_path = backend.save_fig(fig, base_path, title)
|
||||||
|
return fig, output_path
|
||||||
|
|
||||||
|
return fig
|
|
@ -0,0 +1,201 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import plotly
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
|
from quacc.plot.base import BasePlot
|
||||||
|
|
||||||
|
|
||||||
|
class PlotlyPlot(BasePlot):
|
||||||
|
__themes = defaultdict(
|
||||||
|
lambda: {
|
||||||
|
"template": "seaborn",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
__themes = __themes | {
|
||||||
|
"dark": {
|
||||||
|
"template": "plotly_dark",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, theme=None):
|
||||||
|
self.theme = PlotlyPlot.__themes[theme]
|
||||||
|
|
||||||
|
def hex_to_rgb(self, hex: str, t: float | None = None):
|
||||||
|
hex = hex.lstrip("#")
|
||||||
|
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
|
||||||
|
if t is not None:
|
||||||
|
rgb.append(t)
|
||||||
|
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
|
||||||
|
|
||||||
|
def get_colors(self, num):
|
||||||
|
match num:
|
||||||
|
case v if v > 10:
|
||||||
|
__colors = plotly.colors.qualitative.Light24
|
||||||
|
case _:
|
||||||
|
__colors = plotly.colors.qualitative.Plotly
|
||||||
|
|
||||||
|
def __generator(cs):
|
||||||
|
while True:
|
||||||
|
for c in cs:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
return __generator(__colors)
|
||||||
|
|
||||||
|
def update_layout(self, fig, title, x_label, y_label):
|
||||||
|
fig.update_layout(
|
||||||
|
title=title,
|
||||||
|
xaxis_title=x_label,
|
||||||
|
yaxis_title=y_label,
|
||||||
|
template=self.theme["template"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_fig(self, fig, base_path, title) -> Path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def plot_delta(
|
||||||
|
self,
|
||||||
|
base_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
stdevs=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="prevs.",
|
||||||
|
y_label="error",
|
||||||
|
legend=True,
|
||||||
|
) -> go.Figure:
|
||||||
|
fig = go.Figure()
|
||||||
|
x = base_prevs[:, pos_class]
|
||||||
|
line_colors = self.get_colors(len(columns))
|
||||||
|
for name, delta in zip(columns, data):
|
||||||
|
color = next(line_colors)
|
||||||
|
_line = [
|
||||||
|
go.Scatter(
|
||||||
|
x=x,
|
||||||
|
y=delta,
|
||||||
|
mode="lines+markers",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color)),
|
||||||
|
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
_error = []
|
||||||
|
if stdevs is not None:
|
||||||
|
_col_idx = np.where(columns == name)[0]
|
||||||
|
stdev = stdevs[_col_idx].flatten()
|
||||||
|
_error = [
|
||||||
|
go.Scatter(
|
||||||
|
x=np.concatenate([x, x[::-1]]),
|
||||||
|
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
|
||||||
|
name=int(_col_idx[0]),
|
||||||
|
fill="toself",
|
||||||
|
fillcolor=self.hex_to_rgb(color, t=0.2),
|
||||||
|
line=dict(color="rgba(255, 255, 255, 0)"),
|
||||||
|
hoverinfo="skip",
|
||||||
|
showlegend=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
fig.add_traces(_line + _error)
|
||||||
|
|
||||||
|
self.update_layout(fig, title, x_label, y_label)
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_diagonal(
|
||||||
|
self,
|
||||||
|
reference,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
) -> go.Figure:
|
||||||
|
fig = go.Figure()
|
||||||
|
x = reference
|
||||||
|
line_colors = self.get_colors(len(columns))
|
||||||
|
|
||||||
|
_edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
|
||||||
|
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
|
||||||
|
|
||||||
|
for name, val in zip(columns, data):
|
||||||
|
color = next(line_colors)
|
||||||
|
slope, interc = np.polyfit(x, val, 1)
|
||||||
|
y_lr = np.array([slope * _x + interc for _x in _lims[0]])
|
||||||
|
fig.add_traces(
|
||||||
|
[
|
||||||
|
go.Scatter(
|
||||||
|
x=x,
|
||||||
|
y=val,
|
||||||
|
customdata=np.stack((val - x,), axis=-1),
|
||||||
|
mode="markers",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color, t=0.5)),
|
||||||
|
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
|
||||||
|
),
|
||||||
|
go.Scatter(
|
||||||
|
x=_lims[0],
|
||||||
|
y=y_lr,
|
||||||
|
mode="lines",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color), width=3),
|
||||||
|
showlegend=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=_lims[0],
|
||||||
|
y=_lims[1],
|
||||||
|
mode="lines",
|
||||||
|
name="reference",
|
||||||
|
showlegend=False,
|
||||||
|
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_layout(fig, title, x_label, y_label)
|
||||||
|
fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_shift(
|
||||||
|
self,
|
||||||
|
shift_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
counts=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
) -> go.Figure:
|
||||||
|
fig = go.Figure()
|
||||||
|
x = shift_prevs[:, pos_class]
|
||||||
|
line_colors = self.get_colors(len(columns))
|
||||||
|
for name, delta in zip(columns, data):
|
||||||
|
col_idx = (columns == name).nonzero()[0][0]
|
||||||
|
color = next(line_colors)
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=x,
|
||||||
|
y=delta,
|
||||||
|
customdata=np.stack((counts[col_idx],), axis=-1),
|
||||||
|
mode="lines+markers",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color)),
|
||||||
|
hovertemplate="shift: %{x}<br>error: %{y}"
|
||||||
|
+ "<br>count: %{customdata[0]}"
|
||||||
|
if counts is not None
|
||||||
|
else "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_layout(fig, title, x_label, y_label)
|
||||||
|
return fig
|
Loading…
Reference in New Issue