QuAcc/quacc/plot/plotly.py

210 lines
5.4 KiB
Python

import itertools

import numpy as np
import plotly
import plotly.graph_objects as go

from quacc.plot.utils import _get_ref_limits
# Shared styling applied to every figure produced by this module.
TEMPLATE = "ggplot2"  # Plotly layout template name
FONT = {"size": 24}  # global layout font
LEGEND = dict(font=dict(family="DejaVu Sans", size=24))  # legend font settings

# Defaults for line traces.
MODE = "lines"
L_WIDTH = 5
def _update_layout(fig, x_label, y_label, **kwargs):
    """Apply the module's shared layout styling to ``fig`` in place.

    Sets both axis titles plus the common template, font and legend
    settings (``TEMPLATE``, ``FONT``, ``LEGEND``), then forwards any extra
    keyword arguments to ``fig.update_layout``. Supplying one of the keys
    set here again via ``kwargs`` raises ``TypeError``.
    """
    base_layout = dict(
        xaxis_title=x_label,
        yaxis_title=y_label,
        template=TEMPLATE,
        font=FONT,
        legend=LEGEND,
    )
    fig.update_layout(**base_layout, **kwargs)
def _hex_to_rgb(hex: str, t: float | None = None):
hex = hex.lstrip("#")
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
if t is not None:
rgb.append(t)
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
def _get_colors(num):
    """Return an endless iterator over a Plotly qualitative color palette.

    Args:
        num: Number of series that will be drawn. More than 10 selects the
            ``Light24`` palette; otherwise ``G10`` is used.

    Returns:
        An iterator yielding the palette's color strings forever, restarting
        from the first color once the palette is exhausted (so colors repeat
        when ``num`` exceeds the palette size).
    """
    palette = (
        plotly.colors.qualitative.Light24
        if num > 10
        else plotly.colors.qualitative.G10
    )
    return itertools.cycle(palette)
def plot_diagonal(
    method_names,
    true_accs,
    estim_accs,
    cls_name,
    acc_name,
    dataset_name,
    *,
    basedir=None,
) -> go.Figure:
    """Scatter estimated vs. true accuracy per method, with a dashed reference line.

    Args:
        method_names: One legend label per method.
        true_accs: Per-method sequences of true accuracy values (x axis).
        estim_accs: Per-method sequences of estimated accuracy values
            (y axis), paired element-wise with ``true_accs``.
        cls_name: Currently unused (saving is disabled, see note at the end).
        acc_name: Accuracy measure name, used in both axis titles.
        dataset_name: Currently unused (saving is disabled).
        basedir: Currently unused (saving is disabled).

    Returns:
        The assembled ``go.Figure``. Hovering a point shows the true and
        estimated accuracy and their difference (estimated - true).
    """
    fig = go.Figure()
    line_colors = _get_colors(len(method_names))
    _lims = _get_ref_limits(true_accs, estim_accs)
    for name, x, estim in zip(method_names, true_accs, estim_accs):
        color = next(line_colors)
        # asarray so plain lists also support the element-wise `estim - x`.
        x = np.asarray(x)
        estim = np.asarray(estim)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=estim,
                customdata=np.stack((estim - x,), axis=-1),
                mode="markers",
                name=name,
                marker=dict(color=_hex_to_rgb(color, t=0.5)),
                hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
            )
        )
    # Dashed black reference line spanning the limits from _get_ref_limits.
    fig.add_trace(
        go.Scatter(
            x=_lims[0],
            y=_lims[1],
            mode="lines",
            name="reference",
            showlegend=False,
            line=dict(color=_hex_to_rgb("#000000"), dash="dash"),
        )
    )
    _update_layout(
        fig,
        x_label=f"True {acc_name}",
        y_label=f"Estimated {acc_name}",
        autosize=False,
        width=1300,
        height=1000,
        # Square aspect ratio so the reference line is visually at 45 degrees.
        yaxis_scaleanchor="x",
        yaxis_scaleratio=1.0,
        yaxis_range=[-0.1, 1.1],
    )
    # NOTE: saving to `basedir` (via a `_save_or_return` helper, not defined
    # in this module) is disabled; cls_name/dataset_name/basedir are kept for
    # interface compatibility.
    return fig
def plot_delta(
    method_names: list[str],
    prevs: np.ndarray,
    acc_errs: np.ndarray,
    cls_name,
    acc_name,
    dataset_name,
    prev_name,
    *,
    stdevs: np.ndarray | None = None,
    basedir=None,
) -> go.Figure:
    """Line plot of each method's prediction error against prevalence.

    Args:
        method_names: One legend label per method.
        prevs: Prevalence values; each is converted with ``str`` and used
            as a categorical x tick.
        acc_errs: Per-method error series (y values), aligned with ``prevs``.
        cls_name: Currently unused (saving is disabled, see note at the end).
        acc_name: Accuracy measure name, used in the y-axis title.
        dataset_name: Currently unused (saving is disabled).
        prev_name: Label used in the x-axis title.
        stdevs: Optional per-method standard deviations. When given, a
            shaded band ``error +/- stdev`` is drawn around each line.
        basedir: Currently unused (saving is disabled).

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()
    x = [str(bp) for bp in prevs]
    # Closed outline for the band: forward along x, then back in reverse.
    x_band = x + x[::-1]
    line_colors = _get_colors(len(method_names))
    if stdevs is None:
        stdevs = [None] * len(method_names)
    for name, delta, stdev in zip(method_names, acc_errs, stdevs):
        color = next(line_colors)
        # asarray so plain lists also support `delta -/+ stdev` below.
        delta = np.asarray(delta)
        traces = [
            go.Scatter(
                x=x,
                y=delta,
                mode=MODE,
                name=name,
                # Shared group: toggling the legend entry also hides the band.
                legendgroup=name,
                line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
                hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
            )
        ]
        if stdev is not None:
            stdev = np.asarray(stdev)
            traces.append(
                go.Scatter(
                    x=x_band,
                    y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
                    name=name,
                    legendgroup=name,
                    fill="toself",
                    fillcolor=_hex_to_rgb(color, t=0.2),
                    # Invisible outline: only the translucent fill is shown.
                    line=dict(color="rgba(255, 255, 255, 0)"),
                    hoverinfo="skip",
                    showlegend=False,
                )
            )
        fig.add_traces(traces)
    _update_layout(
        fig,
        x_label=f"{prev_name} Prevalence",
        y_label=f"Prediction Error for {acc_name}",
    )
    # NOTE: saving to `basedir` (via a `_save_or_return` helper, not defined
    # in this module) is disabled; cls_name/dataset_name/basedir are kept for
    # interface compatibility. If re-enabled, decide "delta" vs "stdev"
    # before `stdevs` is replaced by the list of Nones above.
    return fig
def plot_shift(
    method_names: list[str],
    prevs: np.ndarray,
    acc_errs: np.ndarray,
    cls_name,
    acc_name,
    dataset_name,
    *,
    counts: np.ndarray | None = None,
    basedir=None,
) -> go.Figure:
    """Line plot of each method's prediction error against the amount of shift.

    Args:
        method_names: One legend label per method.
        prevs: Shift amounts used as x values (shared by all methods).
        acc_errs: Per-method error series (y values), aligned with ``prevs``.
        cls_name: Currently unused (saving is disabled, see note at the end).
        acc_name: Accuracy measure name, used in the y-axis title.
        dataset_name: Currently unused (saving is disabled).
        counts: Optional per-method sample counts per x value. When given,
            they are shown in the hover label.
        basedir: Currently unused (saving is disabled).

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()
    x = prevs
    line_colors = _get_colors(len(method_names))
    if counts is None:
        counts = [None] * len(method_names)
    for name, delta, count in zip(method_names, acc_errs, counts):
        color = next(line_colors)
        # Always show shift and error; append the count only when available.
        # (The previous inline `a + b if cond else ""` bound the conditional
        # over the whole concatenation, blanking the template when count was None.)
        hovertemplate = "shift: %{x}<br>error: %{y}"
        customdata = None
        if count is not None:
            customdata = np.stack((count,), axis=-1)
            hovertemplate += "<br>count: %{customdata[0]}"
        fig.add_trace(
            go.Scatter(
                x=x,
                y=delta,
                customdata=customdata,
                mode=MODE,
                name=name,
                line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
                hovertemplate=hovertemplate,
            )
        )
    _update_layout(
        fig,
        x_label="Amount of Prior Probability Shift",
        y_label=f"Prediction Error for {acc_name}",
    )
    # NOTE: saving to `basedir` (via a `_save_or_return` helper, not defined
    # in this module) is disabled; cls_name/dataset_name/basedir are kept for
    # interface compatibility.
    return fig