Merge branch 'dash'

Lorenzo Volpi 2023-11-30 18:30:44 +01:00
commit 44d820d4ab
17 changed files with 1592 additions and 587 deletions

.gitignore (vendored) · 2 changes

@@ -6,11 +6,13 @@ quavenv/*
 __pycache__/*
 baselines/__pycache__/*
 baselines/densratio/__pycache__/*
+qcdash/__pycache__/*
 qcpanel/__pycache__/*
 quacc/__pycache__/*
 quacc/evaluation/__pycache__/*
 quacc/method/__pycache__/*
 quacc/quantification/__pycache__/*
+quacc/plot/__pycache__/*
 tests/__pycache__/*
 tests/*/__pycache__/*
 tests/*/*/__pycache__/*

poetry.lock (generated) · 417 changes
File diff suppressed because it is too large.

pyproject.toml

@@ -18,6 +18,7 @@ abstention = "^0.1.3.1"
 main = "quacc.main:main"
 run = "run:run"
 panel = "qcpanel.run:run"
+dash = "qcdash.app:run"
 sync_up = "remote:sync_code"
 sync_down = "remote:sync_output"
 merge_data = "merge_data:run"
@@ -27,6 +28,7 @@ poetry_command = ""
 [tool.poe.tasks]
 ilona = "ssh volpi@ilona.isti.cnr.it"
+dash = "gunicorn qcdash.app:server -b ilona.isti.cnr.it:33421"
 [tool.poe.tasks.logr]
 shell = """
@@ -48,6 +50,9 @@ ipympl = "^0.9.3"
 ipykernel = "^6.26.0"
 ipywidgets-bokeh = "^1.5.0"
 pandas-stubs = "^2.1.1.230928"
+dash = "^2.14.1"
+dash-bootstrap-components = "^1.5.0"
+gunicorn = "^21.2.0"
 [tool.pytest.ini_options]
 addopts = "--cov=quacc --capture=tee-sys"

qcdash/__init__.py (new file, empty)

qcdash/app.py (new file) · 413 lines

@@ -0,0 +1,413 @@
import json
import os
from collections import defaultdict
from json import JSONDecodeError
from pathlib import Path
from typing import List
from urllib.parse import parse_qsl, quote, urlencode, urlparse
import dash_bootstrap_components as dbc
import numpy as np
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
from dash.dash_table.Format import Format, Scheme
from quacc import plot
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import ttest_rel
backend = plot.get_backend("plotly")
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes
def get_datasets(root: str | Path) -> List[DatasetReport]:
def load_dataset(dataset):
dataset = Path(dataset)
return DatasetReport.unpickle(dataset)
def explore_datasets(root: str | Path) -> List[Path]:
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
dr_paths = []
for f in os.listdir(root):
if (root / f).is_dir():
dr_paths += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
dr_paths.append(root / f)
return dr_paths
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
estimators = CE.name[estimators]
match (view, mode):
case ("avg", _):
return dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="plotly",
save_fig=False,
backend=backend,
)
case (_, _):
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
return cr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="plotly",
save_fig=False,
backend=backend,
)
def get_table(dr: DatasetReport, metric, estimators, view, mode):
estimators = CE.name[estimators]
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
match (view, mode):
case ("avg", "train_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
case ("avg", "test_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
case ("avg", "shift_table"):
return (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
case ("avg", "stats_table"):
return ttest_rel(dr, metric=metric, estimators=estimators)
case (_, "train_table"):
cr = dr.crs[_prevs.index(view)]
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
case (_, "shift_table"):
cr = dr.crs[_prevs.index(view)]
return (
cr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
def get_DataTable(df):
_primary = "#0d6efd"
if df.empty:
return None
df = df.reset_index()
columns = {
c: dict(
id=c,
name=c,
type="numeric",
format=Format(precision=6, scheme=Scheme.exponent),
)
for c in df.columns
}
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
columns = list(columns.values())
data = df.to_dict("records")
return html.Div(
[
dash_table.DataTable(
data=data,
columns=columns,
id="table1",
style_cell={
"padding": "0 12px",
"border": "0",
"border-bottom": f"1px solid {_primary}",
},
style_table={
"margin": "6vh 15px",
"padding": "15px",
"maxWidth": "80vw",
"overflowX": "auto",
"border": f"0px solid {_primary}",
"border-radius": "6px",
},
)
],
style={
"display": "flex",
"flex-direction": "column",
# "justify-content": "center",
"align-items": "center",
"height": "100vh",
},
)
def get_Graph(fig):
if fig is None:
return None
return dcc.Graph(
id="graph1",
figure=fig,
style={
"margin": 0,
"height": "100vh",
},
)
datasets = get_datasets("output")
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# app.config.suppress_callback_exceptions = True
sidebar_style = {
"top": 0,
"left": 0,
"bottom": 0,
"padding": "1vw",
"padding-top": "2vw",
"margin": "0px",
"flex": 1,
"overflow": "scroll",
"height": "100vh",
}
content_style = {
"flex": 5,
"maxWidth": "84vw",
}
def parse_href(href: str):
parse_result = urlparse(href)
params = parse_qsl(parse_result.query)
return dict(params)
def get_sidebar():
return [
html.H4("Parameters:", style={"margin-bottom": "1vw"}),
dbc.Select(
# options=list(datasets.keys()),
# value=list(datasets.keys())[0],
id="dataset",
),
dbc.Select(
id="metric",
style={"margin-top": "1vh"},
),
html.Div(
[
dbc.RadioItems(
id="view",
class_name="btn-group mt-3",
input_class_name="btn-check",
label_class_name="btn btn-outline-primary",
label_checked_class_name="active",
),
dbc.RadioItems(
id="mode",
class_name="btn-group mt-3",
input_class_name="btn-check",
label_class_name="btn btn-outline-primary",
label_checked_class_name="active",
),
],
className="radio-group-v d-flex justify-content-around",
),
html.Div(
[
dbc.Checklist(
id="estimators",
className="btn-group mt-3",
inputClassName="btn-check",
labelClassName="btn btn-outline-primary",
labelCheckedClassName="active",
),
],
className="radio-group-wide",
),
]
app.layout = html.Div(
[
dcc.Interval(id="reload", interval=10 * 60 * 1000),
dcc.Location(id="url", refresh=False),
html.Div(
[
html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
html.Div(id="app_content", style=content_style),
],
id="page_layout",
style={"display": "flex", "flexDirection": "row"},
),
]
)
server = app.server
def apply_param(href, triggered_id, id, curr):
match triggered_id:
case "url":
params = parse_href(href)
return params.get(id, None)
case _:
return curr
@callback(
Output("dataset", "value"),
Output("dataset", "options"),
Input("url", "href"),
Input("reload", "n_intervals"),
State("dataset", "value"),
)
def update_dataset(href, n_intervals, dataset):
match ctx.triggered_id:
case "reload":
new_datasets = get_datasets("output")
global datasets
datasets = new_datasets
req_dataset = dataset
case "url":
params = parse_href(href)
req_dataset = params.get("dataset", None)
available_datasets = list(datasets.keys())
new_dataset = (
req_dataset if req_dataset in available_datasets else available_datasets[0]
)
return new_dataset, available_datasets
@callback(
Output("metric", "options"),
Output("metric", "value"),
Input("url", "href"),
Input("dataset", "value"),
State("metric", "value"),
)
def update_metrics(href, dataset, curr_metric):
dr = datasets[dataset]
old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
return valid_metrics, new_metric
@callback(
Output("estimators", "options"),
Output("estimators", "value"),
Input("url", "href"),
Input("dataset", "value"),
Input("metric", "value"),
State("estimators", "value"),
)
def update_estimators(href, dataset, metric, curr_estimators):
dr = datasets[dataset]
old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
if isinstance(old_estimators, str):
try:
old_estimators = json.loads(old_estimators)
except JSONDecodeError:
old_estimators = []
valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
new_estimators = valid_estimators[
np.isin(valid_estimators, old_estimators)
].tolist()
return valid_estimators, new_estimators
@callback(
Output("view", "options"),
Output("view", "value"),
Input("url", "href"),
Input("dataset", "value"),
State("view", "value"),
)
def update_view(href, dataset, curr_view):
dr = datasets[dataset]
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
new_view = old_view if old_view in valid_views else valid_views[0]
return valid_views, new_view
@callback(
Output("mode", "options"),
Output("mode", "value"),
Input("url", "href"),
Input("view", "value"),
State("mode", "value"),
)
def update_mode(href, view, curr_mode):
old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
valid_modes = valid_plot_modes[view]
new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
return valid_modes, new_mode
@callback(
Output("app_content", "children"),
Output("url", "search"),
Input("dataset", "value"),
Input("metric", "value"),
Input("estimators", "value"),
Input("view", "value"),
Input("mode", "value"),
)
def update_content(dataset, metric, estimators, view, mode):
search = urlencode(
dict(
dataset=dataset,
metric=metric,
estimators=json.dumps(estimators),
view=view,
mode=mode,
),
quote_via=quote,
)
dr = datasets[dataset]
match mode:
case m if m.endswith("table"):
df = get_table(
dr=dr,
metric=metric,
estimators=estimators,
view=view,
mode=mode,
)
dt = get_DataTable(df)
app_content = [] if dt is None else [dt]
case _:
fig = get_fig(
dr=dr,
metric=metric,
estimators=estimators,
view=view,
mode=mode,
)
g = get_Graph(fig)
app_content = [] if g is None else [g]
return app_content, f"?{search}"
def run():
app.run(debug=True)
if __name__ == "__main__":
run()

(new file: CSS stylesheet) · 86 lines

@@ -0,0 +1,86 @@
/* restyle radio items */
.radio-group .form-check {
padding-left: 0;
}
.radio-group .btn-group > .form-check:not(:last-child) > .btn {
border-top-right-radius: 0;
border-bottom-right-radius: 0;
}
.radio-group .btn-group > .form-check:not(:first-child) > .btn {
border-top-left-radius: 0;
border-bottom-left-radius: 0;
margin-left: -1px;
}
.radio-group-v{
padding: 0 10px;
}
.radio-group-v .form-check {
padding-top: 0;
padding-left: 2px;
}
.radio-group-v .btn-group {
flex-direction: column;
justify-content: center;
align-items: stretch;
flex-grow: 1;
}
.radio-group-v > .btn-group:first-child {
flex-grow: 0;
}
.radio-group-v > .btn-group:last-child {
margin-left: 20px;
}
.radio-group-v .btn-group > .form-check:not(:last-child) > .btn {
border-bottom-right-radius: 0;
border-bottom-left-radius: 0;
}
.radio-group-v .btn-group > .form-check:not(:first-child) > .btn {
border-top-right-radius: 0;
border-top-left-radius: 0;
margin-top: -3px;
}
.radio-group-v .btn-group .btn{
width: 100%;
}
.radio-group-wide .form-check {
padding-left: 0px;
}
.radio-group-wide .btn-group{
flex-wrap: wrap;
}
.radio-group-wide .btn-group .form-check{
flex: 1;
margin-top: -3px;
margin-left: -1px;
}
.radio-group-wide .btn-group .form-check .btn{
width: 100%;
border-radius: 0;
}
.radio-group-wide .btn-group .form-check:first-child > .btn{
border-top-left-radius: 10px;
}
.radio-group-wide .btn-group .form-check:last-child > .btn{
border-bottom-right-radius: 10px;
}
div#app-sidebar{
border-right: 2px solid var(--primary);
}

(deleted file)

@@ -1,290 +0,0 @@
import argparse
import os
from pathlib import Path
import panel as pn
import param
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport
pn.extension(design="bootstrap")
def create_cr_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
prev=None,
):
idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev)
cr = dr.crs[idx]
estimators = CE.name[estimators]
_dpi = 112
return pn.pane.Matplotlib(
cr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
sizing_mode="scale_height",
# sizing_mode="scale_both",
)
def create_avg_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
prev=None,
):
estimators = CE.name[estimators]
return pn.pane.Matplotlib(
dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
sizing_mode="scale_height",
# sizing_mode="scale_both",
)
def build_cr_tab(dr: DatasetReport):
_data = dr.data()
_metrics = _data.columns.unique(0)
_estimators = _data.columns.unique(1)
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
metric_widget = pn.widgets.Select(
name="metric",
value="acc",
options=valid_metrics,
align="center",
)
valid_estimators = [e for e in _estimators if e != "ref"]
estimators_widget = pn.widgets.CheckButtonGroup(
name="estimators",
options=valid_estimators,
value=valid_estimators,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"]
plot_mode_widget = pn.widgets.RadioButtonGroup(
name="mode",
value=valid_plot_modes[0],
options=valid_plot_modes,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
prevs_widget = pn.widgets.RadioButtonGroup(
name="train prevalence",
value=valid_prevs[0],
options=valid_prevs,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
)
plot_pane = pn.bind(
create_cr_plots,
dr=dr,
mode=plot_mode_widget,
metric=metric_widget,
estimators=estimators_widget,
prev=prevs_widget,
)
return pn.Row(
pn.Spacer(width=20),
pn.Column(
metric_widget,
pn.Row(
prevs_widget,
plot_mode_widget,
),
estimators_widget,
align="center",
),
pn.Spacer(sizing_mode="scale_width"),
plot_pane,
)
def build_avg_tab(dr: DatasetReport):
_data = dr.data()
_metrics = _data.columns.unique(0)
_estimators = _data.columns.unique(1)
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
metric_widget = pn.widgets.Select(
name="metric",
value="acc",
options=valid_metrics,
align="center",
)
valid_estimators = [e for e in _estimators if e != "ref"]
estimators_widget = pn.widgets.CheckButtonGroup(
name="estimators",
options=valid_estimators,
value=valid_estimators,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_plot_modes = [
"delta_train",
"stdev_train",
"delta_test",
"stdev_test",
"shift",
]
plot_mode_widget = pn.widgets.RadioButtonGroup(
name="mode",
value=valid_plot_modes[0],
options=valid_plot_modes,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
plot_pane = pn.bind(
create_avg_plots,
dr=dr,
mode=plot_mode_widget,
metric=metric_widget,
estimators=estimators_widget,
)
return pn.Row(
pn.Spacer(width=20),
pn.Column(
metric_widget,
plot_mode_widget,
estimators_widget,
align="center",
),
pn.Spacer(sizing_mode="scale_width"),
plot_pane,
)
def build_dataset(dataset_path: Path):
dr: DatasetReport = DatasetReport.unpickle(dataset_path)
prevs_tab = ("train prevs.", build_cr_tab(dr))
avg_tab = ("avg", build_avg_tab(dr))
app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False)
app.servable()
return app
def explore_datasets(root: Path | str):
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
drs = []
for f in os.listdir(root):
if (root / f).is_dir():
drs += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
drs.append((root, build_dataset(root / f)))
# drs.append((str(root),))
return drs
class PlotSelector(param.Parameterized):
metric = param.Selector(objects=["acc", "f1"])
view = param.Selector(objects=["train prevs", "avg"])
def plot_selector_widget():
return pn.Param(
PlotSelector.param,
widgets={
"metric": pn.widgets.Select,
"view": pn.widgets.Select,
},
)
def serve(address="localhost"):
# app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle"))
__base_path = "output"
__tabs = sorted(
explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0])
)
__tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs]
if len(__tabs) > 0:
app = pn.Tabs(
objects=__tabs,
tabs_location="left",
dynamic=False,
)
else:
app = pn.Column(
pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}),
align="center",
)
__port = 33420
__allowed = [address]
if address == "localhost":
__allowed.append("127.0.0.1")
pn.serve(
app,
autoreload=True,
port=__port,
show=False,
address=address,
websocket_origin=[f"{_a}:{__port}" for _a in __allowed],
)
def run():
parser = argparse.ArgumentParser()
parser.add_argument(
"--address",
action="store",
dest="address",
default="localhost",
)
args = parser.parse_args()
serve(address=args.address)

qcpanel/run.py

@@ -1,10 +1,11 @@
 import argparse
 import panel as pn
+from panel.theme.fast import FastDarkTheme, FastDefaultTheme
 from qcpanel.viewer import QuaccTestViewer
-# pn.config.design = pn.theme.Bootstrap
+# pn.config.design = Fast
 # pn.config.theme = "dark"
 pn.config.notifications = True
@@ -59,8 +60,8 @@ def app_instance():
 ],
 main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
 modal=[qtv.modal_pane],
-theme=pn.theme.DarkTheme,
-theme_toggle=False,
+# theme=FastDefaultTheme,
+theme_toggle=True,
 )
 app.servable()

qcpanel/viewer.py

@@ -52,7 +52,7 @@ def create_plots(
 metric=metric,
 estimators=estimators,
 conf="panel",
-return_fig=True,
+save_fig=False,
 )
 return (
 pn.pane.Matplotlib(
@@ -91,7 +91,7 @@ def create_plots(
 metric=metric,
 estimators=estimators,
 conf="panel",
-return_fig=True,
+save_fig=False,
 )
 return (
 pn.pane.Matplotlib(

quacc/evaluation/report.py

@@ -6,7 +6,7 @@ from typing import List, Tuple
 import numpy as np
 import pandas as pd
-from quacc import plot
+import quacc.plot as plot
 from quacc.utils import fmt_line_md
@@ -215,16 +215,17 @@ class CompReport:
 def get_plots(
 self,
-mode="delta",
+mode="delta_train",
 metric="acc",
 estimators=None,
 conf="default",
-return_fig=False,
+save_fig=True,
 base_path=None,
+backend=None,
 ) -> List[Tuple[str, Path]]:
 if mode == "delta_train":
 avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
-if avg_data.empty is True:
+if avg_data.empty:
 return None
 return plot.plot_delta(
@@ -234,8 +235,9 @@ class CompReport:
 metric=metric,
 name=conf,
 train_prev=self.train_prev,
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 elif mode == "stdev_train":
 avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
@@ -251,8 +253,9 @@ class CompReport:
 name=conf,
 train_prev=self.train_prev,
 stdevs=st_data.T.to_numpy(),
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 elif mode == "diagonal":
 f_data = self.data(metric=metric + "_score", estimators=estimators)
@@ -268,8 +271,9 @@ class CompReport:
 metric=metric,
 name=conf,
 train_prev=self.train_prev,
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 elif mode == "shift":
 _shift_data = self.shift_data(metric=metric, estimators=estimators)
@@ -290,8 +294,9 @@ class CompReport:
 name=conf,
 train_prev=self.train_prev,
 counts=shift_counts.T.to_numpy(),
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 def to_md(
@@ -323,11 +328,12 @@ class CompReport:
 plot_modes = [m for m in modes if not m.endswith("table")]
 for mode in plot_modes:
 res += f"### {mode}\n"
-op = self.get_plots(
+_, op = self.get_plots(
 mode=mode,
 metric=metric,
 estimators=estimators,
 conf=conf,
+save_fig=True,
 base_path=plot_path,
 )
 res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n"
@@ -424,12 +430,15 @@ class DatasetReport:
 metric="acc",
 estimators=None,
 conf="default",
-return_fig=False,
+save_fig=True,
 base_path=None,
+backend=None,
 ):
 if mode == "delta_train":
 _data = self.data(metric, estimators) if data is None else data
 avg_on_train = _data.groupby(level=1).mean()
+if avg_on_train.empty:
+return None
 prevs_on_train = np.sort(avg_on_train.index.unique(0))
 return plot.plot_delta(
 base_prevs=np.around(
@@ -441,12 +450,15 @@ class DatasetReport:
 name=conf,
 train_prev=None,
 avg="train",
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 elif mode == "stdev_train":
 _data = self.data(metric, estimators) if data is None else data
 avg_on_train = _data.groupby(level=1).mean()
+if avg_on_train.empty:
+return None
 prevs_on_train = np.sort(avg_on_train.index.unique(0))
 stdev_on_train = _data.groupby(level=1).std()
 return plot.plot_delta(
@@ -460,12 +472,15 @@ class DatasetReport:
 train_prev=None,
 stdevs=stdev_on_train.T.to_numpy(),
 avg="train",
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 elif mode == "delta_test":
 _data = self.data(metric, estimators) if data is None else data
 avg_on_test = _data.groupby(level=0).mean()
+if avg_on_test.empty:
+return None
 prevs_on_test = np.sort(avg_on_test.index.unique(0))
 return plot.plot_delta(
 base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
@@ -475,12 +490,15 @@ class DatasetReport:
 name=conf,
 train_prev=None,
 avg="test",
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 elif mode == "stdev_test":
 _data = self.data(metric, estimators) if data is None else data
 avg_on_test = _data.groupby(level=0).mean()
+if avg_on_test.empty:
+return None
 prevs_on_test = np.sort(avg_on_test.index.unique(0))
 stdev_on_test = _data.groupby(level=0).std()
 return plot.plot_delta(
@@ -492,12 +510,15 @@ class DatasetReport:
 train_prev=None,
 stdevs=stdev_on_test.T.to_numpy(),
 avg="test",
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 elif mode == "shift":
 _shift_data = self.shift_data(metric, estimators) if data is None else data
 avg_shift = _shift_data.groupby(level=0).mean()
+if avg_shift.empty:
+return None
 count_shift = _shift_data.groupby(level=0).count()
 prevs_shift = np.sort(avg_shift.index.unique(0))
 return plot.plot_shift(
@@ -508,8 +529,9 @@ class DatasetReport:
 name=conf,
 train_prev=None,
 counts=count_shift.T.to_numpy(),
-return_fig=return_fig,
+save_fig=save_fig,
 base_path=base_path,
+backend=backend,
 )
 def to_md(
@@ -545,24 +567,26 @@ class DatasetReport:
 res += avg_on_train_tbl.to_html() + "\n\n"
 if "delta_train" in dr_modes:
-delta_op = self.get_plots(
+_, delta_op = self.get_plots(
 data=_data,
 mode="delta_train",
 metric=metric,
 estimators=estimators,
 conf=conf,
 base_path=plot_path,
+save_fig=True,
 )
 res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
 if "stdev_train" in dr_modes:
-delta_stdev_op = self.get_plots(
+_, delta_stdev_op = self.get_plots(
 data=_data,
 mode="stdev_train",
 metric=metric,
 estimators=estimators,
 conf=conf,
 base_path=plot_path,
+save_fig=True,
 )
 res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@@ -575,24 +599,26 @@ class DatasetReport:
 res += avg_on_test_tbl.to_html() + "\n\n"
 if "delta_test" in dr_modes:
-delta_op = self.get_plots(
+_, delta_op = self.get_plots(
 data=_data,
 mode="delta_test",
 metric=metric,
 estimators=estimators,
 conf=conf,
 base_path=plot_path,
+save_fig=True,
 )
 res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
 if "stdev_test" in dr_modes:
-delta_stdev_op = self.get_plots(
+_, delta_stdev_op = self.get_plots(
 data=_data,
 mode="stdev_test",
 metric=metric,
 estimators=estimators,
 conf=conf,
 base_path=plot_path,
+save_fig=True,
 )
 res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@@ -605,13 +631,14 @@ class DatasetReport:
 res += shift_on_train_tbl.to_html() + "\n\n"
 if "shift" in dr_modes:
-shift_op = self.get_plots(
+_, shift_op = self.get_plots(
 data=_shift_data,
 mode="shift",
 metric=metric,
 estimators=estimators,
 conf=conf,
 base_path=plot_path,
+save_fig=True,
 )
 res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n"

quacc/plot.py (deleted)

@@ -1,265 +0,0 @@
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from quacc import utils
matplotlib.use("agg")
def _get_markers(n: int):
ls = "ovx+sDph*^1234X><.Pd"
if n > len(ls):
ls = ls * (n / len(ls) + 1)
return list(ls)[:n]
def plot_delta(
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
avg=None,
return_fig=False,
base_path=None,
) -> Path:
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
else:
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class]
for method, deltas, _cy in zip(columns, data, cy):
ax.plot(
base_prevs,
deltas,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if stdevs is not None:
_col_idx = np.where(columns == method)[0]
stdev = stdevs[_col_idx].flatten()
nn_idx = np.intersect1d(
np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0],
)
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between(
_bps,
_ds - _st,
_ds + _st,
color=_cy["color"],
alpha=0.25,
)
x_label = "test" if avg is None or avg == "train" else "train"
ax.set(
xlabel=f"{x_label} prevalence",
ylabel=metric,
title=title,
)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
if return_fig:
return fig
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_diagonal(
reference,
columns,
data,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
return_fig=False,
base_path=None,
):
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
else:
title = f"diagonal_{name}_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
ax.set_aspect("equal")
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(
color=[cm(i) for i in range(NUM_COLORS)],
marker=_get_markers(NUM_COLORS),
)
reference = np.array(reference)
x_ticks = np.unique(reference)
x_ticks.sort()
for deltas, _cy in zip(data, cy):
ax.plot(
reference,
deltas,
color=_cy["color"],
linestyle="None",
marker=_cy["marker"],
markersize=3,
zorder=2,
alpha=0.25,
)
# ensure limits are equal for both axes
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
for method, deltas, _cy in zip(columns, data, cy):
slope, interc = np.polyfit(reference, deltas, 1)
y_lr = np.array([slope * x + interc for x in _lims])
ax.plot(
_lims,
y_lr,
label=method,
color=_cy["color"],
linestyle="-",
markersize="0",
zorder=1,
)
# plot reference line
ax.plot(
_lims,
_lims,
color="black",
linestyle="--",
markersize=0,
zorder=1,
)
ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
if return_fig:
return fig
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_shift(
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
return_fig=False,
base_path=None,
) -> Path:
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"shift_{name}_{t_prev_pos}_{metric}"
else:
title = f"shift_{name}_avg_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class]
for method, shifts, _cy in zip(columns, data, cy):
ax.plot(
shift_prevs,
shifts,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if counts is not None:
_col_idx = np.where(columns == method)[0]
count = counts[_col_idx].flatten()
for prev, shift, cnt in zip(shift_prevs, shifts, count):
label = f"{cnt}"
plt.annotate(
label,
(prev, shift),
textcoords="offset points",
xytext=(0, 10),
ha="center",
color=_cy["color"],
fontsize=12.0,
)
ax.set(xlabel="dataset shift", ylabel=metric, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
if return_fig:
return fig
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path

quacc/plot/__init__.py (new file) · 1 line

@@ -0,0 +1 @@
from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift

quacc/plot/base.py (new file) · 54 lines

@@ -0,0 +1,54 @@
from pathlib import Path
class BasePlot:
@classmethod
def save_fig(cls, fig, base_path, title) -> Path:
...
@classmethod
def plot_diagonal(
cls,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...
@classmethod
def plot_delta(
cls,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
...
@classmethod
def plot_shift(
cls,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...

quacc/plot/mpl.py (new file) · 222 lines

@@ -0,0 +1,222 @@
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from quacc import utils
from quacc.plot.base import BasePlot
matplotlib.use("agg")
class MplPlot(BasePlot):
def _get_markers(self, n: int):
ls = "ovx+sDph*^1234X><.Pd"
if n > len(ls):
ls = ls * (n / len(ls) + 1)
return list(ls)[:n]
def save_fig(self, fig, base_path, title) -> Path:
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_delta(
self,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class]
for method, deltas, _cy in zip(columns, data, cy):
ax.plot(
base_prevs,
deltas,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if stdevs is not None:
_col_idx = np.where(columns == method)[0]
stdev = stdevs[_col_idx].flatten()
nn_idx = np.intersect1d(
np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0],
)
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between(
_bps,
_ds - _st,
_ds + _st,
color=_cy["color"],
alpha=0.25,
)
ax.set(
xlabel=f"{x_label} prevalence",
ylabel=y_label,
title=title,
)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_diagonal(
self,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
ax.set_aspect("equal")
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(
color=[cm(i) for i in range(NUM_COLORS)],
marker=self._get_markers(NUM_COLORS),
)
reference = np.array(reference)
x_ticks = np.unique(reference)
x_ticks.sort()
for deltas, _cy in zip(data, cy):
ax.plot(
reference,
deltas,
color=_cy["color"],
linestyle="None",
marker=_cy["marker"],
markersize=3,
zorder=2,
alpha=0.25,
)
# ensure limits are equal for both axes
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
for method, deltas, _cy in zip(columns, data, cy):
slope, interc = np.polyfit(reference, deltas, 1)
y_lr = np.array([slope * x + interc for x in _lims])
ax.plot(
_lims,
y_lr,
label=method,
color=_cy["color"],
linestyle="-",
markersize="0",
zorder=1,
)
# plot reference line
ax.plot(
_lims,
_lims,
color="black",
linestyle="--",
markersize=0,
zorder=1,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_shift(
self,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class]
for method, shifts, _cy in zip(columns, data, cy):
ax.plot(
shift_prevs,
shifts,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if counts is not None:
_col_idx = np.where(columns == method)[0]
count = counts[_col_idx].flatten()
for prev, shift, cnt in zip(shift_prevs, shifts, count):
label = f"{cnt}"
plt.annotate(
label,
(prev, shift),
textcoords="offset points",
xytext=(0, 10),
ha="center",
color=_cy["color"],
fontsize=12.0,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig

quacc/plot/plot.py (new file) · 144 lines

@@ -0,0 +1,144 @@
from quacc.plot.base import BasePlot
from quacc.plot.mpl import MplPlot
from quacc.plot.plotly import PlotlyPlot
__backend: BasePlot = MplPlot()
def get_backend(be, theme=None):
match be:
case "matplotlib" | "mpl":
return MplPlot()
case "plotly":
return PlotlyPlot(theme=theme)
case _:
return MplPlot()
def plot_delta(
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
avg=None,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
else:
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
y_label = f"{metric} error"
fig = backend.plot_delta(
base_prevs,
columns,
data,
stdevs=stdevs,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_diagonal(
reference,
columns,
data,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
else:
title = f"diagonal_{name}_{metric}"
x_label = f"true {metric}"
y_label = f"estim. {metric}"
fig = backend.plot_diagonal(
reference,
columns,
data,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_shift(
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"shift_{name}_{t_prev_pos}_{metric}"
else:
title = f"shift_{name}_avg_{metric}"
x_label = "dataset shift"
y_label = f"{metric} error"
fig = backend.plot_shift(
shift_prevs,
columns,
data,
counts=counts,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig

quacc/plot/plotly.py (new file) · 201 lines

@@ -0,0 +1,201 @@
from collections import defaultdict
from pathlib import Path
import numpy as np
import plotly
import plotly.graph_objects as go
from quacc.plot.base import BasePlot
class PlotlyPlot(BasePlot):
__themes = defaultdict(
lambda: {
"template": "seaborn",
}
)
__themes = __themes | {
"dark": {
"template": "plotly_dark",
},
}
def __init__(self, theme=None):
self.theme = PlotlyPlot.__themes[theme]
def hex_to_rgb(self, hex: str, t: float | None = None):
hex = hex.lstrip("#")
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
if t is not None:
rgb.append(t)
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
def get_colors(self, num):
match num:
case v if v > 10:
__colors = plotly.colors.qualitative.Light24
case _:
__colors = plotly.colors.qualitative.Plotly
def __generator(cs):
while True:
for c in cs:
yield c
return __generator(__colors)
def update_layout(self, fig, title, x_label, y_label):
fig.update_layout(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
template=self.theme["template"],
)
def save_fig(self, fig, base_path, title) -> Path:
return None
def plot_delta(
self,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
) -> go.Figure:
fig = go.Figure()
x = base_prevs[:, pos_class]
line_colors = self.get_colors(len(columns))
for name, delta in zip(columns, data):
color = next(line_colors)
_line = [
go.Scatter(
x=x,
y=delta,
mode="lines+markers",
name=name,
line=dict(color=self.hex_to_rgb(color)),
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
)
]
_error = []
if stdevs is not None:
_col_idx = np.where(columns == name)[0]
stdev = stdevs[_col_idx].flatten()
_error = [
go.Scatter(
x=np.concatenate([x, x[::-1]]),
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
name=int(_col_idx[0]),
fill="toself",
fillcolor=self.hex_to_rgb(color, t=0.2),
line=dict(color="rgba(255, 255, 255, 0)"),
hoverinfo="skip",
showlegend=False,
)
]
fig.add_traces(_line + _error)
self.update_layout(fig, title, x_label, y_label)
return fig
def plot_diagonal(
self,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
) -> go.Figure:
fig = go.Figure()
x = reference
line_colors = self.get_colors(len(columns))
_edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
for name, val in zip(columns, data):
color = next(line_colors)
slope, interc = np.polyfit(x, val, 1)
y_lr = np.array([slope * _x + interc for _x in _lims[0]])
fig.add_traces(
[
go.Scatter(
x=x,
y=val,
customdata=np.stack((val - x,), axis=-1),
mode="markers",
name=name,
line=dict(color=self.hex_to_rgb(color, t=0.5)),
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
),
go.Scatter(
x=_lims[0],
y=y_lr,
mode="lines",
name=name,
line=dict(color=self.hex_to_rgb(color), width=3),
showlegend=False,
),
]
)
fig.add_trace(
go.Scatter(
x=_lims[0],
y=_lims[1],
mode="lines",
name="reference",
showlegend=False,
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
)
)
self.update_layout(fig, title, x_label, y_label)
fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
return fig
def plot_shift(
self,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
) -> go.Figure:
fig = go.Figure()
x = shift_prevs[:, pos_class]
line_colors = self.get_colors(len(columns))
for name, delta in zip(columns, data):
col_idx = (columns == name).nonzero()[0][0]
color = next(line_colors)
fig.add_trace(
go.Scatter(
x=x,
y=delta,
customdata=np.stack((counts[col_idx],), axis=-1),
mode="lines+markers",
name=name,
line=dict(color=self.hex_to_rgb(color)),
hovertemplate="shift: %{x}<br>error: %{y}"
+ "<br>count: %{customdata[0]}"
if counts is not None
else "",
)
)
self.update_layout(fig, title, x_label, y_label)
return fig

remote.py

@@ -24,6 +24,7 @@ __to_sync_up = {
 "quacc",
 "baselines",
 "qcpanel",
+"qcdash",
 ],
 "file": [
 "conf.yaml",