"""Plotly-based implementation of :class:`BasePlot`."""

import itertools
from collections import defaultdict
from pathlib import Path

import numpy as np
import plotly
import plotly.graph_objects as go

from quacc.plot.base import BasePlot


class PlotCfg:
    """Styling options shared by the figures built in this module.

    Args:
        mode: Plotly scatter ``mode`` used for the line traces
            (e.g. ``"lines"`` or ``"lines+markers"``).
        lwidth: Width of the line traces.
        font: Layout ``font`` settings; ``None`` means an empty dict.
        legend: Layout ``legend`` settings; ``None`` means an empty dict.
        template: Plotly template name used by the default theme.
    """

    def __init__(self, mode, lwidth, font=None, legend=None, template="seaborn"):
        self.mode = mode
        self.lwidth = lwidth
        # Fresh dicts per instance so configs never share mutable state.
        self.legend = {} if legend is None else legend
        self.font = {} if font is None else font
        self.template = template


# Lighter styling (thin lines with markers); presumably meant for interactive
# web output. Not referenced elsewhere in this module.
web_cfg = PlotCfg("lines+markers", 2)

# Heavier styling (thick lines, 24pt text) with a horizontal legend anchored
# just above the plot area at its right edge; presumably for PNG export.
png_cfg = PlotCfg(
    "lines",
    5,
    legend=dict(
        orientation="h",
        yanchor="bottom",
        xanchor="right",
        y=1.02,
        x=1,
        font=dict(size=24),
    ),
    font=dict(size=24),
)

# The configuration PlotlyPlot actually uses.
_cfg = png_cfg


class PlotlyPlot(BasePlot):
    """Build Plotly figures for the plots defined by :class:`BasePlot`.

    Every ``plot_*`` method returns a new :class:`plotly.graph_objects.Figure`
    styled with the module-level ``_cfg`` and the theme chosen at construction.
    """

    # Theme name -> layout settings. Any unknown name (including None) falls
    # back to the active config's template; the lambda reads ``_cfg`` lazily.
    __themes = defaultdict(
        lambda: {"template": _cfg.template},
        {"dark": {"template": "plotly_dark"}},
    )

    def __init__(self, theme=None):
        """Select the theme by name (e.g. ``"dark"``); unknown names use the default."""
        self.theme = PlotlyPlot.__themes[theme]

    def hex_to_rgb(self, hex: str, t: float | None = None) -> str:
        """Convert a ``#RRGGBB`` colour into a Plotly ``rgb(...)``/``rgba(...)`` string.

        Args:
            hex: Six-digit hex colour, with or without the leading ``#``.
            t: Optional alpha channel; when given, an ``rgba`` string is returned.

        Returns:
            e.g. ``"rgb(51, 102, 204)"`` or ``"rgba(51, 102, 204, 0.2)"``.
        """
        digits = hex.lstrip("#")
        channels = [int(digits[i : i + 2], 16) for i in (0, 2, 4)]
        if t is None:
            return f"rgb{tuple(channels)}"
        return f"rgba{tuple(channels + [t])}"

    def get_colors(self, num):
        """Return an endless iterator over a qualitative colour palette.

        Args:
            num: Number of series that will be coloured. More than 10 selects
                Plotly's ``Light24`` palette, otherwise ``G10`` is used.
                Colours repeat once the palette is exhausted.
        """
        palette = (
            plotly.colors.qualitative.Light24
            if num > 10
            else plotly.colors.qualitative.G10
        )
        return itertools.cycle(palette)

    def update_layout(self, fig, title, x_label, y_label, *, show_legend=True):
        """Apply axis titles, theme template, font and legend styling to *fig*.

        ``title`` is accepted for interface compatibility but is not rendered.
        When ``show_legend`` is false the legend is hidden; otherwise Plotly's
        default legend visibility is left untouched.
        """
        fig.update_layout(
            xaxis_title=x_label,
            yaxis_title=y_label,
            template=self.theme["template"],
            font=_cfg.font,
            legend=_cfg.legend,
        )
        if not show_legend:
            fig.update_layout(showlegend=False)

    def save_fig(self, fig, base_path, title) -> Path | None:
        """Saving is not implemented for Plotly figures; always returns None."""
        return None

    def plot_delta(
        self,
        base_prevs,
        columns,
        data,
        *,
        stdevs=None,
        pos_class=1,
        title="default",
        x_label="prevs.",
        y_label="error",
        legend=True,
    ) -> go.Figure:
        """Plot one error curve per column against the base prevalences.

        Args:
            base_prevs: x values. If they are scalar floats, each ``p`` is
                expanded to the pair ``(1 - p, p)`` rounded to 4 decimals.
            columns: Series names, aligned with the rows of ``data``.
            data: One sequence of y values per column.
            stdevs: Optional per-column standard deviations, aligned with
                ``columns``; each is drawn as a shaded band of +/- one stdev.
            pos_class: Unused; kept for interface compatibility.
            title, x_label, y_label: Passed to :meth:`update_layout`.
            legend: When false, the legend is hidden.
        """
        fig = go.Figure()
        if len(base_prevs) > 0 and isinstance(base_prevs[0], float):
            base_prevs = np.around([(1 - bp, bp) for bp in base_prevs], decimals=4)
        # Categorical x axis: each prevalence vector is shown as its tuple string.
        x = [str(tuple(bp)) for bp in base_prevs]
        line_colors = self.get_colors(len(columns))
        # Use the positional index rather than searching ``columns`` by name:
        # works for plain lists as well as NumPy arrays, and avoids O(n^2).
        for col_idx, (name, delta) in enumerate(zip(columns, data)):
            color = next(line_colors)
            traces = [
                go.Scatter(
                    x=x,
                    y=delta,
                    mode=_cfg.mode,
                    name=name,
                    line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth),
                    hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
                )
            ]
            if stdevs is not None:
                stdev = np.ravel(stdevs[col_idx])
                values = np.asarray(delta)
                # Closed polygon: lower edge left-to-right, then upper edge
                # right-to-left, filled with a translucent line colour.
                traces.append(
                    go.Scatter(
                        x=np.concatenate([x, x[::-1]]),
                        y=np.concatenate([values - stdev, (values + stdev)[::-1]]),
                        name=col_idx,
                        fill="toself",
                        fillcolor=self.hex_to_rgb(color, t=0.2),
                        line=dict(color="rgba(255, 255, 255, 0)"),
                        hoverinfo="skip",
                        showlegend=False,
                    )
                )
            fig.add_traces(traces)

        self.update_layout(fig, title, x_label, y_label, show_legend=legend)
        return fig

    def plot_diagonal(
        self,
        reference,
        columns,
        data,
        *,
        pos_class=1,
        title="default",
        x_label="true",
        y_label="estim.",
        legend=True,
    ) -> go.Figure:
        """Scatter estimated vs. true values, one series per column.

        Each series also gets a least-squares fit line drawn over the shared
        value range. A dashed ``y = x`` reference line is added, and the axes
        are locked to a 1:1 aspect ratio.

        Args:
            reference: True values (x axis), shared by all series.
            columns: Series names, aligned with the rows of ``data``.
            data: One sequence of estimated values per column.
            pos_class: Unused; kept for interface compatibility.
            title, x_label, y_label: Passed to :meth:`update_layout`.
            legend: When false, the legend is hidden.
        """
        fig = go.Figure()
        x = np.asarray(reference)
        line_colors = self.get_colors(len(columns))
        # Common [min, max] range over true and estimated values; used for
        # both the fit lines and the diagonal reference.
        lims = np.array(
            [
                np.min([np.min(x), np.min(data)]),
                np.max([np.max(x), np.max(data)]),
            ]
        )
        for name, val in zip(columns, data):
            color = next(line_colors)
            val = np.asarray(val)
            slope, interc = np.polyfit(x, val, 1)
            fig.add_traces(
                [
                    go.Scatter(
                        x=x,
                        y=val,
                        # Per-point error (estimate - truth), shown on hover.
                        customdata=np.stack((val - x,), axis=-1),
                        mode="markers",
                        name=name,
                        line=dict(color=self.hex_to_rgb(color, t=0.5)),
                        hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
                    ),
                    go.Scatter(
                        x=lims,
                        y=slope * lims + interc,
                        mode="lines",
                        name=name,
                        line=dict(color=self.hex_to_rgb(color), width=3),
                        showlegend=False,
                    ),
                ]
            )

        fig.add_trace(
            go.Scatter(
                x=lims,
                y=lims,
                mode="lines",
                name="reference",
                showlegend=False,
                line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
            )
        )

        self.update_layout(fig, title, x_label, y_label, show_legend=legend)
        fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
        return fig

    def plot_shift(
        self,
        shift_prevs,
        columns,
        data,
        *,
        counts=None,
        pos_class=1,
        title="default",
        x_label="true",
        y_label="estim.",
        legend=True,
    ) -> go.Figure:
        """Plot one error curve per column against the shift values.

        Args:
            shift_prevs: x values, shared by all series.
            columns: Series names, aligned with the rows of ``data``.
            data: One sequence of y values per column.
            counts: Optional per-column sequences, aligned with ``columns``.
                When given, each point's count is shown in its hover label.
            pos_class: Unused; kept for interface compatibility.
            title, x_label, y_label: Passed to :meth:`update_layout`.
            legend: When false, the legend is hidden.
        """
        fig = go.Figure()
        x = shift_prevs
        line_colors = self.get_colors(len(columns))
        # Only the count suffix depends on ``counts``; the base hover text is
        # always present.
        hovertemplate = "shift: %{x}<br>error: %{y}"
        if counts is not None:
            hovertemplate += "<br>count: %{customdata[0]}"
        for col_idx, (name, delta) in enumerate(zip(columns, data)):
            color = next(line_colors)
            customdata = (
                None if counts is None else np.stack((counts[col_idx],), axis=-1)
            )
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=delta,
                    customdata=customdata,
                    mode=_cfg.mode,
                    name=name,
                    line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth),
                    hovertemplate=hovertemplate,
                )
            )

        self.update_layout(fig, title, x_label, y_label, show_legend=legend)
        return fig

    def plot_fit_scores(
        self,
        train_prevs,
        scores,
        *,
        pos_class=1,
        title="default",
        x_label="prev.",
        y_label="position",
        legend=True,
    ) -> go.Figure:
        """Plot ``scores`` against the training prevalences as a single line.

        Args:
            train_prevs: Prevalence vectors; each is shown as its tuple string
                on a categorical x axis.
            scores: One y value per entry of ``train_prevs``.
            pos_class: Unused; kept for interface compatibility.
            title, x_label, y_label: Passed to :meth:`update_layout`.
            legend: When false, the legend is hidden (the single trace is
                already excluded from it).
        """
        fig = go.Figure()
        x = [str(tuple(bp)) for bp in train_prevs]
        fig.add_trace(
            go.Scatter(
                x=x,
                y=scores,
                mode="lines+markers",
                showlegend=False,
            ),
        )

        self.update_layout(fig, title, x_label, y_label, show_legend=legend)
        return fig