From b8e43c02f2d675102934250152fa220631a08f71 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 4 Apr 2024 17:03:52 +0200 Subject: [PATCH] plots refactoring started --- quacc/plot/__init__.py | 4 +- quacc/plot/plotly.py | 487 ++++++++++++++++------------------------- 2 files changed, 189 insertions(+), 302 deletions(-) diff --git a/quacc/plot/__init__.py b/quacc/plot/__init__.py index c16c75e..433e7e1 100644 --- a/quacc/plot/__init__.py +++ b/quacc/plot/__init__.py @@ -1,7 +1,7 @@ -from quacc.plot.plot import ( +from quacc.legacy.plot.plot import ( get_backend, plot_delta, plot_diagonal, - plot_shift, plot_fit_scores, + plot_shift, ) diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py index 52a514d..2fb6978 100644 --- a/quacc/plot/plotly.py +++ b/quacc/plot/plotly.py @@ -1,330 +1,217 @@ -from collections import defaultdict -from pathlib import Path +import os import numpy as np import plotly import plotly.graph_objects as go -from quacc.evaluation.estimators import CE, _renames -from quacc.plot.base import BasePlot - - -class PlotCfg: - def __init__(self, mode, lwidth, font=None, legend=None, template="seaborn"): - self.mode = mode - self.lwidth = lwidth - self.legend = {} if legend is None else legend - self.font = {} if font is None else font - self.template = template - - -web_cfg = PlotCfg("lines+markers", 2) -png_cfg_old = PlotCfg( - "lines", - 5, - legend=dict( - orientation="h", - yanchor="bottom", - xanchor="right", - y=1.02, - x=1, - font=dict(size=24), - ), - font=dict(size=24), - # template="ggplot2", -) -png_cfg = PlotCfg( - "lines", - 5, - legend=dict( - font=dict( - family="DejaVu Sans", - size=24, - ), - ), - font=dict(size=24), - # template="ggplot2", -) - -_cfg = png_cfg - - -class PlotlyPlot(BasePlot): - __themes = defaultdict( - lambda: { - "template": _cfg.template, - } - ) - __themes = __themes | { - "dark": { - "template": "plotly_dark", - }, +MODE = "lines" +L_WIDTH = 5 +LEGEND = { + "font": { + "family": "DejaVu Sans", + "size": 24, } +} +FONT = {"size": 24} +TEMPLATE = "ggplot2" - def __init__(self, theme=None): - self.theme = PlotlyPlot.__themes[theme] - self.rename = True - def hex_to_rgb(self, hex: str, t: float | None = None): - hex = hex.lstrip("#") - rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]] - if t is not None: - rgb.append(t) - return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}" +def _save_or_return(fig, basedir, dataset_name, measure_name, plot_type): + if basedir is not None: + plotsubdir = dataset_name + os.path.join(basedir, "plots", measure_name, plotsubdir, plot_type + ".svg") - def get_colors(self, num): - match num: - case v if v > 10: - __colors = plotly.colors.qualitative.Light24 - case _: - __colors = plotly.colors.qualitative.G10 + return fig - def __generator(cs): - while True: - for c in cs: - yield c - return __generator(__colors) +def _update_layout(fig, title, x_label, y_label, **kwargs): + fig.update_layout( + xaxis_title=x_label, + yaxis_title=y_label, + template=TEMPLATE, + font=FONT, + legend=LEGEND, + **kwargs, + ) - def update_layout(self, fig, title, x_label, y_label): - fig.update_layout( - # title=title, - xaxis_title=x_label, - yaxis_title=y_label, - template=self.theme["template"], - font=_cfg.font, - legend=_cfg.legend, - ) - def save_fig(self, fig, base_path, title) -> Path: - return None +def _hex_to_rgb(self, hex: str, t: float | None = None): + hex = hex.lstrip("#") + rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]] + if t is not None: + rgb.append(t) + return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}" - def rename_plots( - self, - columns, - ): - if not self.rename: - return columns - new_columns = [] - for c in columns: - nc = c - for old, new in _renames.items(): - if c.startswith(old): - nc = new + c[len(old) :] +def _get_colors(self, num): + match num: + case v if v > 10: + __colors = plotly.colors.qualitative.Light24 + case _: + __colors = plotly.colors.qualitative.G10 - new_columns.append(nc) + def __generator(cs): + while True: + for c in cs: + yield c - return np.array(new_columns) + return __generator(__colors) - def plot_delta( - self, - base_prevs, - columns, - data, - *, - stdevs=None, - pos_class=1, - title="default", - x_label="prevs.", - y_label="error", - legend=True, - ) -> go.Figure: - fig = go.Figure() - if isinstance(base_prevs[0], float): - base_prevs = np.around([(1 - bp, bp) for bp in base_prevs], decimals=4) - x = [str(tuple(bp)) for bp in base_prevs] - named_data = {c: d for c, d in zip(columns, data)} - r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))} - line_colors = self.get_colors(len(columns)) - # for name, delta in zip(columns, data): - columns = np.array(CE.name.sort(columns)) - for name in columns: - delta = named_data[name] - r_name = r_columns[name] - color = next(line_colors) - _line = [ + +def _get_ref_limits(true_accs: np.ndarray, estim_accs: dict[str, np.ndarray]): + """get lmits of reference line""" + + _edges = ( + np.min([np.min(true_accs), np.min(estim_accs)]), + np.max([np.max(true_accs), np.max(estim_accs)]), + ) + _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) + + +def plot_diagonal( + method_names, + true_accs, + estim_accs, + *, + measure_name="vanilla_accuracy", + dataset_name=None, + basedir=None, +) -> go.Figure: + fig = go.Figure() + x = true_accs + line_colors = _get_colors(len(method_names)) + _lims = _get_ref_limits(true_accs, estim_accs) + + for name, estim in zip(method_names, estim_accs): + color = next(line_colors) + slope, interc = np.polyfit(x, estim, 1) + fig.add_traces( + [ go.Scatter( x=x, - y=delta, - mode=_cfg.mode, - name=r_name, - line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth), - hovertemplate="prev.: %{x}
error: %{y:,.4f}", + y=estim, + customdata=np.stack((estim - x,), axis=-1), + mode="markers", + name=name, + marker=dict(color=_hex_to_rgb(color, t=0.5)), + hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}", + ), + ] + ) + fig.add_trace( + go.Scatter( + x=_lims[0], + y=_lims[1], + mode="lines", + name="reference", + showlegend=False, + line=dict(color=_hex_to_rgb("#000000"), dash="dash"), + ) + ) + + _update_layout( + fig, + x_label=f"True {measure_name}", + y_label=f"Estimated {measure_name}", + autosize=False, + width=1300, + height=1000, + yaxis_scaleanchor="x", + yaxis_scaleratio=1.0, + yaxis_range=[-0.1, 1.1], + ) + return _save_or_return(fig, basedir, dataset_name, measure_name, "diagonal") + + +def plot_delta( + method_names: list[str], + prevs: np.ndarray, + acc_errs: np.ndarray, + *, + stdevs: np.ndarray | None = None, + prev_name="Test", + measure_name="Vanilla Accuracy", + dataset_name=None, + basedir=None, +) -> go.Figure: + fig = go.Figure() + x = [str(bp) for bp in prevs] + line_colors = _get_colors(len(method_names)) + if stdevs is None: + stdevs = [None] * len(method_names) + for name, delta, stdev in zip(method_names, acc_errs, stdevs): + color = next(line_colors) + _line = [ + go.Scatter( + x=x, + y=delta, + mode=MODE, + name=name, + line=dict(color=_hex_to_rgb(color), width=L_WIDTH), + hovertemplate="prev.: %{x}
error: %{y:,.4f}", + ) + ] + _error = [] + if stdev is not None: + _error = [ + go.Scatter( + x=np.concatenate([x, x[::-1]]), + y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]), + name=name, + fill="toself", + fillcolor=_hex_to_rgb(color, t=0.2), + line=dict(color="rgba(255, 255, 255, 0)"), + hoverinfo="skip", + showlegend=False, ) ] - _error = [] - if stdevs is not None: - _col_idx = np.where(columns == name)[0] - stdev = stdevs[_col_idx].flatten() - _error = [ - go.Scatter( - x=np.concatenate([x, x[::-1]]), - y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]), - name=int(_col_idx[0]), - fill="toself", - fillcolor=self.hex_to_rgb(color, t=0.2), - line=dict(color="rgba(255, 255, 255, 0)"), - hoverinfo="skip", - showlegend=False, - ) - ] - fig.add_traces(_line + _error) + fig.add_traces(_line + _error) - self.update_layout(fig, title, x_label, y_label) - return fig + _update_layout( + fig, + x_label=f"{prev_name} Prevalence", + y_label=f"Prediction Error for {measure_name}", + ) + return _save_or_return( + fig, basedir, dataset_name, measure_name, "delta" if stdevs is None else "stdev" + ) - def plot_diagonal( - self, - reference, - columns, - data, - *, - pos_class=1, - title="default", - x_label="true", - y_label="estim.", - fixed_lim=False, - legend=True, - ) -> go.Figure: - fig = go.Figure() - x = reference - line_colors = self.get_colors(len(columns)) - if fixed_lim: - _lims = np.array([[0.0, 1.0], [0.0, 1.0]]) - else: - _edges = ( - np.min([np.min(x), np.min(data)]), - np.max([np.max(x), np.max(data)]), - ) - _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) - - named_data = {c: d for c, d in zip(columns, data)} - r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))} - columns = np.array(CE.name.sort(columns)) - for name in columns: - val = named_data[name] - r_name = r_columns[name] - color = next(line_colors) - slope, interc = np.polyfit(x, val, 1) - # y_lr = np.array([slope * _x + interc for _x in _lims[0]]) - fig.add_traces( - [ - go.Scatter( - x=x, - y=val, - customdata=np.stack((val - x,), axis=-1), - mode="markers", - name=r_name, - marker=dict(color=self.hex_to_rgb(color, t=0.5)), - hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}", - # showlegend=False, - ), - # go.Scatter( - # x=[x[-1]], - # y=[val[-1]], - # mode="markers", - # marker=dict(color=self.hex_to_rgb(color), size=8), - # name=r_name, - # ), - # go.Scatter( - # x=_lims[0], - # y=y_lr, - # mode="lines", - # name=name, - # line=dict(color=self.hex_to_rgb(color), width=3), - # showlegend=False, - # ), - ] - ) - fig.add_trace( - go.Scatter( - x=_lims[0], - y=_lims[1], - mode="lines", - name="reference", - showlegend=False, - line=dict(color=self.hex_to_rgb("#000000"), dash="dash"), - ) - ) - - self.update_layout(fig, title, x_label, y_label) - fig.update_layout( - autosize=False, - width=1300, - height=1000, - yaxis_scaleanchor="x", - yaxis_scaleratio=1.0, - yaxis_range=[-0.1, 1.1], - ) - return fig - - def plot_shift( - self, - shift_prevs, - columns, - data, - *, - counts=None, - pos_class=1, - title="default", - x_label="true", - y_label="estim.", - legend=True, - ) -> go.Figure: - fig = go.Figure() - # x = shift_prevs[:, pos_class] - x = shift_prevs - line_colors = self.get_colors(len(columns)) - named_data = {c: d for c, d in zip(columns, data)} - r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))} - columns = np.array(CE.name.sort(columns)) - for name in columns: - delta = named_data[name] - r_name = r_columns[name] - col_idx = (columns == name).nonzero()[0][0] - color = next(line_colors) - fig.add_trace( - go.Scatter( - x=x, - y=delta, - customdata=np.stack((counts[col_idx],), axis=-1), - mode=_cfg.mode, - name=r_name, - line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth), - hovertemplate="shift: %{x}
error: %{y}" - + "
count: %{customdata[0]}" - if counts is not None - else "", - ) - ) - - self.update_layout(fig, title, x_label, y_label) - return fig - - def plot_fit_scores( - self, - train_prevs, - scores, - *, - pos_class=1, - title="default", - x_label="prev.", - y_label="position", - legend=True, - ) -> go.Figure: - fig = go.Figure() - # x = train_prevs - x = [str(tuple(bp)) for bp in train_prevs] +def plot_shift( + method_names: list[str], + prevs: np.ndarray, + acc_errs: np.ndarray, + *, + counts: np.ndarray | None = None, + measure_name="Vanilla Accuracy", + dataset_name=None, + basedir=None, +) -> go.Figure: + fig = go.Figure() + x = prevs + line_colors = _get_colors(len(method_names)) + if counts is None: + counts = [None] * len(method_names) + for name, delta, count in zip(method_names, acc_errs, counts): + color = next(line_colors) fig.add_trace( go.Scatter( x=x, - y=scores, - mode="lines+markers", - showlegend=False, - ), + y=delta, + customdata=np.stack((count,), axis=-1), + mode=MODE, + name=name, + line=dict(color=_hex_to_rgb(color), width=L_WIDTH), + hovertemplate="shift: %{x}
error: %{y}" + + "
count: %{customdata[0]}" + if count is not None + else "", + ) ) - self.update_layout(fig, title, x_label, y_label) - return fig + _update_layout( + fig, + x_label="Amount of Prior Probability Shift", + y_label=f"Prediction Error for {measure_name}", + ) + return _save_or_return(fig, basedir, dataset_name, measure_name, "shift")