From c670f48b5b5c8d8226560d36e88cb548d71e00a3 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Wed, 29 Nov 2023 03:56:01 +0100 Subject: [PATCH] plotly plot backend added --- quacc/plot.py | 265 ------------------------------------------- quacc/plot/plotly.py | 201 ++++++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+), 265 deletions(-) delete mode 100644 quacc/plot.py create mode 100644 quacc/plot/plotly.py diff --git a/quacc/plot.py b/quacc/plot.py deleted file mode 100644 index e0bbefa..0000000 --- a/quacc/plot.py +++ /dev/null @@ -1,265 +0,0 @@ -from pathlib import Path - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from cycler import cycler - -from quacc import utils - -matplotlib.use("agg") - - -def _get_markers(n: int): - ls = "ovx+sDph*^1234X><.Pd" - if n > len(ls): - ls = ls * (n / len(ls) + 1) - return list(ls)[:n] - - -def plot_delta( - base_prevs, - columns, - data, - *, - stdevs=None, - pos_class=1, - metric="acc", - name="default", - train_prev=None, - legend=True, - avg=None, - return_fig=False, - base_path=None, -) -> Path: - _base_title = "delta_stdev" if stdevs is not None else "delta" - if train_prev is not None: - t_prev_pos = int(round(train_prev[pos_class] * 100)) - title = f"{_base_title}_{name}_{t_prev_pos}_{metric}" - else: - title = f"{_base_title}_{name}_avg_{avg}_{metric}" - - if base_path is None: - base_path = utils.get_quacc_home() / "plots" - - fig, ax = plt.subplots() - ax.set_aspect("auto") - ax.grid() - - NUM_COLORS = len(data) - cm = plt.get_cmap("tab10") - if NUM_COLORS > 10: - cm = plt.get_cmap("tab20") - cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) - - base_prevs = base_prevs[:, pos_class] - for method, deltas, _cy in zip(columns, data, cy): - ax.plot( - base_prevs, - deltas, - label=method, - color=_cy["color"], - linestyle="-", - marker="o", - markersize=3, - zorder=2, - ) - if stdevs is not None: - _col_idx = np.where(columns == method)[0] - stdev = stdevs[_col_idx].flatten() - nn_idx = np.intersect1d( - np.where(deltas != np.nan)[0], - np.where(stdev != np.nan)[0], - ) - _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx] - ax.fill_between( - _bps, - _ds - _st, - _ds + _st, - color=_cy["color"], - alpha=0.25, - ) - - x_label = "test" if avg is None or avg == "train" else "train" - ax.set( - xlabel=f"{x_label} prevalence", - ylabel=metric, - title=title, - ) - - if legend: - ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) - - if return_fig: - return fig - - output_path = base_path / f"{title}.png" - fig.savefig(output_path, bbox_inches="tight") - return output_path - - -def plot_diagonal( - reference, - columns, - data, - *, - pos_class=1, - metric="acc", - name="default", - train_prev=None, - legend=True, - return_fig=False, - base_path=None, -): - if train_prev is not None: - t_prev_pos = int(round(train_prev[pos_class] * 100)) - title = f"diagonal_{name}_{t_prev_pos}_{metric}" - else: - title = f"diagonal_{name}_{metric}" - - if base_path is None: - base_path = utils.get_quacc_home() / "plots" - - fig, ax = plt.subplots() - ax.set_aspect("auto") - ax.grid() - ax.set_aspect("equal") - - NUM_COLORS = len(data) - cm = plt.get_cmap("tab10") - if NUM_COLORS > 10: - cm = plt.get_cmap("tab20") - cy = cycler( - color=[cm(i) for i in range(NUM_COLORS)], - marker=_get_markers(NUM_COLORS), - ) - - reference = np.array(reference) - x_ticks = np.unique(reference) - x_ticks.sort() - - for deltas, _cy in zip(data, cy): - ax.plot( - reference, - deltas, - color=_cy["color"], - linestyle="None", - marker=_cy["marker"], - markersize=3, - zorder=2, - alpha=0.25, - ) - - # ensure limits are equal for both axes - _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1) - _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)]) - ax.set(xlim=tuple(_lims), ylim=tuple(_lims)) - - for method, deltas, _cy in zip(columns, data, cy): - slope, interc = np.polyfit(reference, deltas, 1) - y_lr = np.array([slope * x + interc for x in _lims]) - ax.plot( - _lims, - y_lr, - label=method, - color=_cy["color"], - linestyle="-", - markersize="0", - zorder=1, - ) - - # plot reference line - ax.plot( - _lims, - _lims, - color="black", - linestyle="--", - markersize=0, - zorder=1, - ) - - ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title) - - if legend: - ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) - - if return_fig: - return fig - - output_path = base_path / f"{title}.png" - fig.savefig(output_path, bbox_inches="tight") - return output_path - - -def plot_shift( - shift_prevs, - columns, - data, - *, - counts=None, - pos_class=1, - metric="acc", - name="default", - train_prev=None, - legend=True, - return_fig=False, - base_path=None, -) -> Path: - if train_prev is not None: - t_prev_pos = int(round(train_prev[pos_class] * 100)) - title = f"shift_{name}_{t_prev_pos}_{metric}" - else: - title = f"shift_{name}_avg_{metric}" - - if base_path is None: - base_path = utils.get_quacc_home() / "plots" - - fig, ax = plt.subplots() - ax.set_aspect("auto") - ax.grid() - - NUM_COLORS = len(data) - cm = plt.get_cmap("tab10") - if NUM_COLORS > 10: - cm = plt.get_cmap("tab20") - cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) - - shift_prevs = shift_prevs[:, pos_class] - for method, shifts, _cy in zip(columns, data, cy): - ax.plot( - shift_prevs, - shifts, - label=method, - color=_cy["color"], - linestyle="-", - marker="o", - markersize=3, - zorder=2, - ) - if counts is not None: - _col_idx = np.where(columns == method)[0] - count = counts[_col_idx].flatten() - for prev, shift, cnt in zip(shift_prevs, shifts, count): - label = f"{cnt}" - plt.annotate( - label, - (prev, shift), - textcoords="offset points", - xytext=(0, 10), - ha="center", - color=_cy["color"], - fontsize=12.0, - ) - - ax.set(xlabel="dataset shift", ylabel=metric, title=title) - - if legend: - ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) - - if return_fig: - return fig - - output_path = base_path / f"{title}.png" - fig.savefig(output_path, bbox_inches="tight") - - return output_path diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py new file mode 100644 index 0000000..074c277 --- /dev/null +++ b/quacc/plot/plotly.py @@ -0,0 +1,201 @@ +from collections import defaultdict +from pathlib import Path + +import numpy as np +import plotly +import plotly.graph_objects as go + +from quacc.plot.base import BasePlot + + +class PlotlyPlot(BasePlot): + __themes = defaultdict( + lambda: { + "template": "seaborn", + } + ) + __themes = __themes | { + "dark": { + "template": "plotly_dark", + }, + } + + def __init__(self, theme=None): + self.theme = PlotlyPlot.__themes[theme] + + def hex_to_rgb(self, hex: str, t: float | None = None): + hex = hex.lstrip("#") + rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]] + if t is not None: + rgb.append(t) + return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}" + + def get_colors(self, num): + match num: + case v if v > 10: + __colors = plotly.colors.qualitative.Light24 + case _: + __colors = plotly.colors.qualitative.Plotly + + def __generator(cs): + while True: + for c in cs: + yield c + + return __generator(__colors) + + def update_layout(self, fig, title, x_label, y_label): + fig.update_layout( + title=title, + xaxis_title=x_label, + yaxis_title=y_label, + template=self.theme["template"], + ) + + def save_fig(self, fig, base_path, title) -> Path: + return None + + def plot_delta( + self, + base_prevs, + columns, + data, + *, + stdevs=None, + pos_class=1, + title="default", + x_label="prevs.", + y_label="error", + legend=True, + ) -> go.Figure: + fig = go.Figure() + x = base_prevs[:, pos_class] + line_colors = self.get_colors(len(columns)) + for name, delta in zip(columns, data): + color = next(line_colors) + _line = [ + go.Scatter( + x=x, + y=delta, + mode="lines+markers", + name=name, + line=dict(color=self.hex_to_rgb(color)), + hovertemplate="prev.: %{x}
error: %{y:,.4f}", + ) + ] + _error = [] + if stdevs is not None: + _col_idx = np.where(columns == name)[0] + stdev = stdevs[_col_idx].flatten() + _error = [ + go.Scatter( + x=np.concatenate([x, x[::-1]]), + y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]), + name=int(_col_idx[0]), + fill="toself", + fillcolor=self.hex_to_rgb(color, t=0.2), + line=dict(color="rgba(255, 255, 255, 0)"), + hoverinfo="skip", + showlegend=False, + ) + ] + fig.add_traces(_line + _error) + + self.update_layout(fig, title, x_label, y_label) + return fig + + def plot_diagonal( + self, + reference, + columns, + data, + *, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ) -> go.Figure: + fig = go.Figure() + x = reference + line_colors = self.get_colors(len(columns)) + + _edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)])) + _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) + + for name, val in zip(columns, data): + color = next(line_colors) + slope, interc = np.polyfit(x, val, 1) + y_lr = np.array([slope * _x + interc for _x in _lims[0]]) + fig.add_traces( + [ + go.Scatter( + x=x, + y=val, + customdata=np.stack((val - x,), axis=-1), + mode="markers", + name=name, + line=dict(color=self.hex_to_rgb(color, t=0.5)), + hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}", + ), + go.Scatter( + x=_lims[0], + y=y_lr, + mode="lines", + name=name, + line=dict(color=self.hex_to_rgb(color), width=3), + showlegend=False, + ), + ] + ) + fig.add_trace( + go.Scatter( + x=_lims[0], + y=_lims[1], + mode="lines", + name="reference", + showlegend=False, + line=dict(color=self.hex_to_rgb("#000000"), dash="dash"), + ) + ) + + self.update_layout(fig, title, x_label, y_label) + fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0) + return fig + + def plot_shift( + self, + shift_prevs, + columns, + data, + *, + counts=None, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ) -> go.Figure: + fig = go.Figure() + x = shift_prevs[:, pos_class] + line_colors = self.get_colors(len(columns)) + for name, delta in zip(columns, data): + col_idx = (columns == name).nonzero()[0][0] + color = next(line_colors) + fig.add_trace( + go.Scatter( + x=x, + y=delta, + customdata=np.stack((counts[col_idx],), axis=-1), + mode="lines+markers", + name=name, + line=dict(color=self.hex_to_rgb(color)), + hovertemplate="shift: %{x}
error: %{y}" + + "
count: %{customdata[0]}" + if counts is not None + else "", + ) + ) + + self.update_layout(fig, title, x_label, y_label) + return fig