plotly plot backend added
This commit is contained in:
parent
f0bfb2e039
commit
c670f48b5b
265
quacc/plot.py
265
quacc/plot.py
|
@ -1,265 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
from cycler import cycler
|
|
||||||
|
|
||||||
from quacc import utils
|
|
||||||
|
|
||||||
matplotlib.use("agg")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_markers(n: int):
|
|
||||||
ls = "ovx+sDph*^1234X><.Pd"
|
|
||||||
if n > len(ls):
|
|
||||||
ls = ls * (n / len(ls) + 1)
|
|
||||||
return list(ls)[:n]
|
|
||||||
|
|
||||||
|
|
||||||
def plot_delta(
|
|
||||||
base_prevs,
|
|
||||||
columns,
|
|
||||||
data,
|
|
||||||
*,
|
|
||||||
stdevs=None,
|
|
||||||
pos_class=1,
|
|
||||||
metric="acc",
|
|
||||||
name="default",
|
|
||||||
train_prev=None,
|
|
||||||
legend=True,
|
|
||||||
avg=None,
|
|
||||||
return_fig=False,
|
|
||||||
base_path=None,
|
|
||||||
) -> Path:
|
|
||||||
_base_title = "delta_stdev" if stdevs is not None else "delta"
|
|
||||||
if train_prev is not None:
|
|
||||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
|
||||||
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
|
|
||||||
else:
|
|
||||||
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
|
|
||||||
|
|
||||||
if base_path is None:
|
|
||||||
base_path = utils.get_quacc_home() / "plots"
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.set_aspect("auto")
|
|
||||||
ax.grid()
|
|
||||||
|
|
||||||
NUM_COLORS = len(data)
|
|
||||||
cm = plt.get_cmap("tab10")
|
|
||||||
if NUM_COLORS > 10:
|
|
||||||
cm = plt.get_cmap("tab20")
|
|
||||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
|
||||||
|
|
||||||
base_prevs = base_prevs[:, pos_class]
|
|
||||||
for method, deltas, _cy in zip(columns, data, cy):
|
|
||||||
ax.plot(
|
|
||||||
base_prevs,
|
|
||||||
deltas,
|
|
||||||
label=method,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="-",
|
|
||||||
marker="o",
|
|
||||||
markersize=3,
|
|
||||||
zorder=2,
|
|
||||||
)
|
|
||||||
if stdevs is not None:
|
|
||||||
_col_idx = np.where(columns == method)[0]
|
|
||||||
stdev = stdevs[_col_idx].flatten()
|
|
||||||
nn_idx = np.intersect1d(
|
|
||||||
np.where(deltas != np.nan)[0],
|
|
||||||
np.where(stdev != np.nan)[0],
|
|
||||||
)
|
|
||||||
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
|
|
||||||
ax.fill_between(
|
|
||||||
_bps,
|
|
||||||
_ds - _st,
|
|
||||||
_ds + _st,
|
|
||||||
color=_cy["color"],
|
|
||||||
alpha=0.25,
|
|
||||||
)
|
|
||||||
|
|
||||||
x_label = "test" if avg is None or avg == "train" else "train"
|
|
||||||
ax.set(
|
|
||||||
xlabel=f"{x_label} prevalence",
|
|
||||||
ylabel=metric,
|
|
||||||
title=title,
|
|
||||||
)
|
|
||||||
|
|
||||||
if legend:
|
|
||||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
|
||||||
|
|
||||||
if return_fig:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
output_path = base_path / f"{title}.png"
|
|
||||||
fig.savefig(output_path, bbox_inches="tight")
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def plot_diagonal(
|
|
||||||
reference,
|
|
||||||
columns,
|
|
||||||
data,
|
|
||||||
*,
|
|
||||||
pos_class=1,
|
|
||||||
metric="acc",
|
|
||||||
name="default",
|
|
||||||
train_prev=None,
|
|
||||||
legend=True,
|
|
||||||
return_fig=False,
|
|
||||||
base_path=None,
|
|
||||||
):
|
|
||||||
if train_prev is not None:
|
|
||||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
|
||||||
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
|
|
||||||
else:
|
|
||||||
title = f"diagonal_{name}_{metric}"
|
|
||||||
|
|
||||||
if base_path is None:
|
|
||||||
base_path = utils.get_quacc_home() / "plots"
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.set_aspect("auto")
|
|
||||||
ax.grid()
|
|
||||||
ax.set_aspect("equal")
|
|
||||||
|
|
||||||
NUM_COLORS = len(data)
|
|
||||||
cm = plt.get_cmap("tab10")
|
|
||||||
if NUM_COLORS > 10:
|
|
||||||
cm = plt.get_cmap("tab20")
|
|
||||||
cy = cycler(
|
|
||||||
color=[cm(i) for i in range(NUM_COLORS)],
|
|
||||||
marker=_get_markers(NUM_COLORS),
|
|
||||||
)
|
|
||||||
|
|
||||||
reference = np.array(reference)
|
|
||||||
x_ticks = np.unique(reference)
|
|
||||||
x_ticks.sort()
|
|
||||||
|
|
||||||
for deltas, _cy in zip(data, cy):
|
|
||||||
ax.plot(
|
|
||||||
reference,
|
|
||||||
deltas,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="None",
|
|
||||||
marker=_cy["marker"],
|
|
||||||
markersize=3,
|
|
||||||
zorder=2,
|
|
||||||
alpha=0.25,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ensure limits are equal for both axes
|
|
||||||
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
|
|
||||||
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
|
|
||||||
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
|
|
||||||
|
|
||||||
for method, deltas, _cy in zip(columns, data, cy):
|
|
||||||
slope, interc = np.polyfit(reference, deltas, 1)
|
|
||||||
y_lr = np.array([slope * x + interc for x in _lims])
|
|
||||||
ax.plot(
|
|
||||||
_lims,
|
|
||||||
y_lr,
|
|
||||||
label=method,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="-",
|
|
||||||
markersize="0",
|
|
||||||
zorder=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# plot reference line
|
|
||||||
ax.plot(
|
|
||||||
_lims,
|
|
||||||
_lims,
|
|
||||||
color="black",
|
|
||||||
linestyle="--",
|
|
||||||
markersize=0,
|
|
||||||
zorder=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
|
|
||||||
|
|
||||||
if legend:
|
|
||||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
|
||||||
|
|
||||||
if return_fig:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
output_path = base_path / f"{title}.png"
|
|
||||||
fig.savefig(output_path, bbox_inches="tight")
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def plot_shift(
|
|
||||||
shift_prevs,
|
|
||||||
columns,
|
|
||||||
data,
|
|
||||||
*,
|
|
||||||
counts=None,
|
|
||||||
pos_class=1,
|
|
||||||
metric="acc",
|
|
||||||
name="default",
|
|
||||||
train_prev=None,
|
|
||||||
legend=True,
|
|
||||||
return_fig=False,
|
|
||||||
base_path=None,
|
|
||||||
) -> Path:
|
|
||||||
if train_prev is not None:
|
|
||||||
t_prev_pos = int(round(train_prev[pos_class] * 100))
|
|
||||||
title = f"shift_{name}_{t_prev_pos}_{metric}"
|
|
||||||
else:
|
|
||||||
title = f"shift_{name}_avg_{metric}"
|
|
||||||
|
|
||||||
if base_path is None:
|
|
||||||
base_path = utils.get_quacc_home() / "plots"
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.set_aspect("auto")
|
|
||||||
ax.grid()
|
|
||||||
|
|
||||||
NUM_COLORS = len(data)
|
|
||||||
cm = plt.get_cmap("tab10")
|
|
||||||
if NUM_COLORS > 10:
|
|
||||||
cm = plt.get_cmap("tab20")
|
|
||||||
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
|
|
||||||
|
|
||||||
shift_prevs = shift_prevs[:, pos_class]
|
|
||||||
for method, shifts, _cy in zip(columns, data, cy):
|
|
||||||
ax.plot(
|
|
||||||
shift_prevs,
|
|
||||||
shifts,
|
|
||||||
label=method,
|
|
||||||
color=_cy["color"],
|
|
||||||
linestyle="-",
|
|
||||||
marker="o",
|
|
||||||
markersize=3,
|
|
||||||
zorder=2,
|
|
||||||
)
|
|
||||||
if counts is not None:
|
|
||||||
_col_idx = np.where(columns == method)[0]
|
|
||||||
count = counts[_col_idx].flatten()
|
|
||||||
for prev, shift, cnt in zip(shift_prevs, shifts, count):
|
|
||||||
label = f"{cnt}"
|
|
||||||
plt.annotate(
|
|
||||||
label,
|
|
||||||
(prev, shift),
|
|
||||||
textcoords="offset points",
|
|
||||||
xytext=(0, 10),
|
|
||||||
ha="center",
|
|
||||||
color=_cy["color"],
|
|
||||||
fontsize=12.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set(xlabel="dataset shift", ylabel=metric, title=title)
|
|
||||||
|
|
||||||
if legend:
|
|
||||||
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
|
||||||
|
|
||||||
if return_fig:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
output_path = base_path / f"{title}.png"
|
|
||||||
fig.savefig(output_path, bbox_inches="tight")
|
|
||||||
|
|
||||||
return output_path
|
|
|
@ -0,0 +1,201 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import plotly
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
|
from quacc.plot.base import BasePlot
|
||||||
|
|
||||||
|
|
||||||
|
class PlotlyPlot(BasePlot):
|
||||||
|
__themes = defaultdict(
|
||||||
|
lambda: {
|
||||||
|
"template": "seaborn",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
__themes = __themes | {
|
||||||
|
"dark": {
|
||||||
|
"template": "plotly_dark",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, theme=None):
|
||||||
|
self.theme = PlotlyPlot.__themes[theme]
|
||||||
|
|
||||||
|
def hex_to_rgb(self, hex: str, t: float | None = None):
|
||||||
|
hex = hex.lstrip("#")
|
||||||
|
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
|
||||||
|
if t is not None:
|
||||||
|
rgb.append(t)
|
||||||
|
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
|
||||||
|
|
||||||
|
def get_colors(self, num):
|
||||||
|
match num:
|
||||||
|
case v if v > 10:
|
||||||
|
__colors = plotly.colors.qualitative.Light24
|
||||||
|
case _:
|
||||||
|
__colors = plotly.colors.qualitative.Plotly
|
||||||
|
|
||||||
|
def __generator(cs):
|
||||||
|
while True:
|
||||||
|
for c in cs:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
return __generator(__colors)
|
||||||
|
|
||||||
|
def update_layout(self, fig, title, x_label, y_label):
|
||||||
|
fig.update_layout(
|
||||||
|
title=title,
|
||||||
|
xaxis_title=x_label,
|
||||||
|
yaxis_title=y_label,
|
||||||
|
template=self.theme["template"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_fig(self, fig, base_path, title) -> Path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def plot_delta(
|
||||||
|
self,
|
||||||
|
base_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
stdevs=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="prevs.",
|
||||||
|
y_label="error",
|
||||||
|
legend=True,
|
||||||
|
) -> go.Figure:
|
||||||
|
fig = go.Figure()
|
||||||
|
x = base_prevs[:, pos_class]
|
||||||
|
line_colors = self.get_colors(len(columns))
|
||||||
|
for name, delta in zip(columns, data):
|
||||||
|
color = next(line_colors)
|
||||||
|
_line = [
|
||||||
|
go.Scatter(
|
||||||
|
x=x,
|
||||||
|
y=delta,
|
||||||
|
mode="lines+markers",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color)),
|
||||||
|
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
_error = []
|
||||||
|
if stdevs is not None:
|
||||||
|
_col_idx = np.where(columns == name)[0]
|
||||||
|
stdev = stdevs[_col_idx].flatten()
|
||||||
|
_error = [
|
||||||
|
go.Scatter(
|
||||||
|
x=np.concatenate([x, x[::-1]]),
|
||||||
|
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
|
||||||
|
name=int(_col_idx[0]),
|
||||||
|
fill="toself",
|
||||||
|
fillcolor=self.hex_to_rgb(color, t=0.2),
|
||||||
|
line=dict(color="rgba(255, 255, 255, 0)"),
|
||||||
|
hoverinfo="skip",
|
||||||
|
showlegend=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
fig.add_traces(_line + _error)
|
||||||
|
|
||||||
|
self.update_layout(fig, title, x_label, y_label)
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_diagonal(
|
||||||
|
self,
|
||||||
|
reference,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
) -> go.Figure:
|
||||||
|
fig = go.Figure()
|
||||||
|
x = reference
|
||||||
|
line_colors = self.get_colors(len(columns))
|
||||||
|
|
||||||
|
_edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
|
||||||
|
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
|
||||||
|
|
||||||
|
for name, val in zip(columns, data):
|
||||||
|
color = next(line_colors)
|
||||||
|
slope, interc = np.polyfit(x, val, 1)
|
||||||
|
y_lr = np.array([slope * _x + interc for _x in _lims[0]])
|
||||||
|
fig.add_traces(
|
||||||
|
[
|
||||||
|
go.Scatter(
|
||||||
|
x=x,
|
||||||
|
y=val,
|
||||||
|
customdata=np.stack((val - x,), axis=-1),
|
||||||
|
mode="markers",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color, t=0.5)),
|
||||||
|
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
|
||||||
|
),
|
||||||
|
go.Scatter(
|
||||||
|
x=_lims[0],
|
||||||
|
y=y_lr,
|
||||||
|
mode="lines",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color), width=3),
|
||||||
|
showlegend=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=_lims[0],
|
||||||
|
y=_lims[1],
|
||||||
|
mode="lines",
|
||||||
|
name="reference",
|
||||||
|
showlegend=False,
|
||||||
|
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_layout(fig, title, x_label, y_label)
|
||||||
|
fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_shift(
|
||||||
|
self,
|
||||||
|
shift_prevs,
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
*,
|
||||||
|
counts=None,
|
||||||
|
pos_class=1,
|
||||||
|
title="default",
|
||||||
|
x_label="true",
|
||||||
|
y_label="estim.",
|
||||||
|
legend=True,
|
||||||
|
) -> go.Figure:
|
||||||
|
fig = go.Figure()
|
||||||
|
x = shift_prevs[:, pos_class]
|
||||||
|
line_colors = self.get_colors(len(columns))
|
||||||
|
for name, delta in zip(columns, data):
|
||||||
|
col_idx = (columns == name).nonzero()[0][0]
|
||||||
|
color = next(line_colors)
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=x,
|
||||||
|
y=delta,
|
||||||
|
customdata=np.stack((counts[col_idx],), axis=-1),
|
||||||
|
mode="lines+markers",
|
||||||
|
name=name,
|
||||||
|
line=dict(color=self.hex_to_rgb(color)),
|
||||||
|
hovertemplate="shift: %{x}<br>error: %{y}"
|
||||||
|
+ "<br>count: %{customdata[0]}"
|
||||||
|
if counts is not None
|
||||||
|
else "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_layout(fig, title, x_label, y_label)
|
||||||
|
return fig
|
Loading…
Reference in New Issue