plotly plot backend added
This commit is contained in:
parent
f0bfb2e039
commit
c670f48b5b
265
quacc/plot.py
265
quacc/plot.py
|
@ -1,265 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from cycler import cycler
|
||||
|
||||
from quacc import utils
|
||||
|
||||
matplotlib.use("agg")
|
||||
|
||||
|
||||
def _get_markers(n: int):
|
||||
ls = "ovx+sDph*^1234X><.Pd"
|
||||
if n > len(ls):
|
||||
ls = ls * (n / len(ls) + 1)
|
||||
return list(ls)[:n]
|
||||
|
||||
|
||||
def plot_delta(
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
avg=None,
|
||||
return_fig=False,
|
||||
base_path=None,
|
||||
) -> Path:
|
||||
_base_title = "delta_stdev" if stdevs is not None else "delta"
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
|
||||
|
||||
if base_path is None:
|
||||
base_path = utils.get_quacc_home() / "plots"
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||
|
||||
base_prevs = base_prevs[:, pos_class]
|
||||
for method, deltas, _cy in zip(columns, data, cy):
|
||||
ax.plot(
|
||||
base_prevs,
|
||||
deltas,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
marker="o",
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
)
|
||||
if stdevs is not None:
|
||||
_col_idx = np.where(columns == method)[0]
|
||||
stdev = stdevs[_col_idx].flatten()
|
||||
nn_idx = np.intersect1d(
|
||||
np.where(deltas != np.nan)[0],
|
||||
np.where(stdev != np.nan)[0],
|
||||
)
|
||||
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
|
||||
ax.fill_between(
|
||||
_bps,
|
||||
_ds - _st,
|
||||
_ds + _st,
|
||||
color=_cy["color"],
|
||||
alpha=0.25,
|
||||
)
|
||||
|
||||
x_label = "test" if avg is None or avg == "train" else "train"
|
||||
ax.set(
|
||||
xlabel=f"{x_label} prevalence",
|
||||
ylabel=metric,
|
||||
title=title,
|
||||
)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if return_fig:
|
||||
return fig
|
||||
|
||||
output_path = base_path / f"{title}.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
return output_path
|
||||
|
||||
|
||||
def plot_diagonal(
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
return_fig=False,
|
||||
base_path=None,
|
||||
):
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"diagonal_{name}_{metric}"
|
||||
|
||||
if base_path is None:
|
||||
base_path = utils.get_quacc_home() / "plots"
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
ax.set_aspect("equal")
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(
|
||||
color=[cm(i) for i in range(NUM_COLORS)],
|
||||
marker=_get_markers(NUM_COLORS),
|
||||
)
|
||||
|
||||
reference = np.array(reference)
|
||||
x_ticks = np.unique(reference)
|
||||
x_ticks.sort()
|
||||
|
||||
for deltas, _cy in zip(data, cy):
|
||||
ax.plot(
|
||||
reference,
|
||||
deltas,
|
||||
color=_cy["color"],
|
||||
linestyle="None",
|
||||
marker=_cy["marker"],
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
alpha=0.25,
|
||||
)
|
||||
|
||||
# ensure limits are equal for both axes
|
||||
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
|
||||
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
|
||||
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
|
||||
|
||||
for method, deltas, _cy in zip(columns, data, cy):
|
||||
slope, interc = np.polyfit(reference, deltas, 1)
|
||||
y_lr = np.array([slope * x + interc for x in _lims])
|
||||
ax.plot(
|
||||
_lims,
|
||||
y_lr,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
markersize="0",
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
# plot reference line
|
||||
ax.plot(
|
||||
_lims,
|
||||
_lims,
|
||||
color="black",
|
||||
linestyle="--",
|
||||
markersize=0,
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if return_fig:
|
||||
return fig
|
||||
|
||||
output_path = base_path / f"{title}.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
return output_path
|
||||
|
||||
|
||||
def plot_shift(
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
metric="acc",
|
||||
name="default",
|
||||
train_prev=None,
|
||||
legend=True,
|
||||
return_fig=False,
|
||||
base_path=None,
|
||||
) -> Path:
|
||||
if train_prev is not None:
|
||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
||||
title = f"shift_{name}_{t_prev_pos}_{metric}"
|
||||
else:
|
||||
title = f"shift_{name}_avg_{metric}"
|
||||
|
||||
if base_path is None:
|
||||
base_path = utils.get_quacc_home() / "plots"
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_aspect("auto")
|
||||
ax.grid()
|
||||
|
||||
NUM_COLORS = len(data)
|
||||
cm = plt.get_cmap("tab10")
|
||||
if NUM_COLORS > 10:
|
||||
cm = plt.get_cmap("tab20")
|
||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
||||
|
||||
shift_prevs = shift_prevs[:, pos_class]
|
||||
for method, shifts, _cy in zip(columns, data, cy):
|
||||
ax.plot(
|
||||
shift_prevs,
|
||||
shifts,
|
||||
label=method,
|
||||
color=_cy["color"],
|
||||
linestyle="-",
|
||||
marker="o",
|
||||
markersize=3,
|
||||
zorder=2,
|
||||
)
|
||||
if counts is not None:
|
||||
_col_idx = np.where(columns == method)[0]
|
||||
count = counts[_col_idx].flatten()
|
||||
for prev, shift, cnt in zip(shift_prevs, shifts, count):
|
||||
label = f"{cnt}"
|
||||
plt.annotate(
|
||||
label,
|
||||
(prev, shift),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 10),
|
||||
ha="center",
|
||||
color=_cy["color"],
|
||||
fontsize=12.0,
|
||||
)
|
||||
|
||||
ax.set(xlabel="dataset shift", ylabel=metric, title=title)
|
||||
|
||||
if legend:
|
||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if return_fig:
|
||||
return fig
|
||||
|
||||
output_path = base_path / f"{title}.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
|
||||
return output_path
|
|
@ -0,0 +1,201 @@
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import plotly
|
||||
import plotly.graph_objects as go
|
||||
|
||||
from quacc.plot.base import BasePlot
|
||||
|
||||
|
||||
class PlotlyPlot(BasePlot):
|
||||
__themes = defaultdict(
|
||||
lambda: {
|
||||
"template": "seaborn",
|
||||
}
|
||||
)
|
||||
__themes = __themes | {
|
||||
"dark": {
|
||||
"template": "plotly_dark",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, theme=None):
|
||||
self.theme = PlotlyPlot.__themes[theme]
|
||||
|
||||
def hex_to_rgb(self, hex: str, t: float | None = None):
|
||||
hex = hex.lstrip("#")
|
||||
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
|
||||
if t is not None:
|
||||
rgb.append(t)
|
||||
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
|
||||
|
||||
def get_colors(self, num):
|
||||
match num:
|
||||
case v if v > 10:
|
||||
__colors = plotly.colors.qualitative.Light24
|
||||
case _:
|
||||
__colors = plotly.colors.qualitative.Plotly
|
||||
|
||||
def __generator(cs):
|
||||
while True:
|
||||
for c in cs:
|
||||
yield c
|
||||
|
||||
return __generator(__colors)
|
||||
|
||||
def update_layout(self, fig, title, x_label, y_label):
|
||||
fig.update_layout(
|
||||
title=title,
|
||||
xaxis_title=x_label,
|
||||
yaxis_title=y_label,
|
||||
template=self.theme["template"],
|
||||
)
|
||||
|
||||
def save_fig(self, fig, base_path, title) -> Path:
|
||||
return None
|
||||
|
||||
def plot_delta(
|
||||
self,
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="prevs.",
|
||||
y_label="error",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = base_prevs[:, pos_class]
|
||||
line_colors = self.get_colors(len(columns))
|
||||
for name, delta in zip(columns, data):
|
||||
color = next(line_colors)
|
||||
_line = [
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=delta,
|
||||
mode="lines+markers",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color)),
|
||||
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
|
||||
)
|
||||
]
|
||||
_error = []
|
||||
if stdevs is not None:
|
||||
_col_idx = np.where(columns == name)[0]
|
||||
stdev = stdevs[_col_idx].flatten()
|
||||
_error = [
|
||||
go.Scatter(
|
||||
x=np.concatenate([x, x[::-1]]),
|
||||
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
|
||||
name=int(_col_idx[0]),
|
||||
fill="toself",
|
||||
fillcolor=self.hex_to_rgb(color, t=0.2),
|
||||
line=dict(color="rgba(255, 255, 255, 0)"),
|
||||
hoverinfo="skip",
|
||||
showlegend=False,
|
||||
)
|
||||
]
|
||||
fig.add_traces(_line + _error)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
return fig
|
||||
|
||||
def plot_diagonal(
|
||||
self,
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = reference
|
||||
line_colors = self.get_colors(len(columns))
|
||||
|
||||
_edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
|
||||
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
|
||||
|
||||
for name, val in zip(columns, data):
|
||||
color = next(line_colors)
|
||||
slope, interc = np.polyfit(x, val, 1)
|
||||
y_lr = np.array([slope * _x + interc for _x in _lims[0]])
|
||||
fig.add_traces(
|
||||
[
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=val,
|
||||
customdata=np.stack((val - x,), axis=-1),
|
||||
mode="markers",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color, t=0.5)),
|
||||
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
|
||||
),
|
||||
go.Scatter(
|
||||
x=_lims[0],
|
||||
y=y_lr,
|
||||
mode="lines",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color), width=3),
|
||||
showlegend=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=_lims[0],
|
||||
y=_lims[1],
|
||||
mode="lines",
|
||||
name="reference",
|
||||
showlegend=False,
|
||||
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
|
||||
)
|
||||
)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
|
||||
return fig
|
||||
|
||||
def plot_shift(
|
||||
self,
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = shift_prevs[:, pos_class]
|
||||
line_colors = self.get_colors(len(columns))
|
||||
for name, delta in zip(columns, data):
|
||||
col_idx = (columns == name).nonzero()[0][0]
|
||||
color = next(line_colors)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=delta,
|
||||
customdata=np.stack((counts[col_idx],), axis=-1),
|
||||
mode="lines+markers",
|
||||
name=name,
|
||||
line=dict(color=self.hex_to_rgb(color)),
|
||||
hovertemplate="shift: %{x}<br>error: %{y}"
|
||||
+ "<br>count: %{customdata[0]}"
|
||||
if counts is not None
|
||||
else "",
|
||||
)
|
||||
)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
return fig
|
Loading…
Reference in New Issue