diff --git a/quacc/plot.py b/quacc/plot.py
deleted file mode 100644
index e0bbefa..0000000
--- a/quacc/plot.py
+++ /dev/null
@@ -1,265 +0,0 @@
-from pathlib import Path
-
-import matplotlib
-import matplotlib.pyplot as plt
-import numpy as np
-from cycler import cycler
-
-from quacc import utils
-
-matplotlib.use("agg")
-
-
-def _get_markers(n: int):
- ls = "ovx+sDph*^1234X><.Pd"
- if n > len(ls):
- ls = ls * (n / len(ls) + 1)
- return list(ls)[:n]
-
-
-def plot_delta(
- base_prevs,
- columns,
- data,
- *,
- stdevs=None,
- pos_class=1,
- metric="acc",
- name="default",
- train_prev=None,
- legend=True,
- avg=None,
- return_fig=False,
- base_path=None,
-) -> Path:
- _base_title = "delta_stdev" if stdevs is not None else "delta"
- if train_prev is not None:
- t_prev_pos = int(round(train_prev[pos_class] * 100))
- title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
- else:
- title = f"{_base_title}_{name}_avg_{avg}_{metric}"
-
- if base_path is None:
- base_path = utils.get_quacc_home() / "plots"
-
- fig, ax = plt.subplots()
- ax.set_aspect("auto")
- ax.grid()
-
- NUM_COLORS = len(data)
- cm = plt.get_cmap("tab10")
- if NUM_COLORS > 10:
- cm = plt.get_cmap("tab20")
- cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
-
- base_prevs = base_prevs[:, pos_class]
- for method, deltas, _cy in zip(columns, data, cy):
- ax.plot(
- base_prevs,
- deltas,
- label=method,
- color=_cy["color"],
- linestyle="-",
- marker="o",
- markersize=3,
- zorder=2,
- )
- if stdevs is not None:
- _col_idx = np.where(columns == method)[0]
- stdev = stdevs[_col_idx].flatten()
- nn_idx = np.intersect1d(
- np.where(deltas != np.nan)[0],
- np.where(stdev != np.nan)[0],
- )
- _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
- ax.fill_between(
- _bps,
- _ds - _st,
- _ds + _st,
- color=_cy["color"],
- alpha=0.25,
- )
-
- x_label = "test" if avg is None or avg == "train" else "train"
- ax.set(
- xlabel=f"{x_label} prevalence",
- ylabel=metric,
- title=title,
- )
-
- if legend:
- ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
-
- if return_fig:
- return fig
-
- output_path = base_path / f"{title}.png"
- fig.savefig(output_path, bbox_inches="tight")
- return output_path
-
-
-def plot_diagonal(
- reference,
- columns,
- data,
- *,
- pos_class=1,
- metric="acc",
- name="default",
- train_prev=None,
- legend=True,
- return_fig=False,
- base_path=None,
-):
- if train_prev is not None:
- t_prev_pos = int(round(train_prev[pos_class] * 100))
- title = f"diagonal_{name}_{t_prev_pos}_{metric}"
- else:
- title = f"diagonal_{name}_{metric}"
-
- if base_path is None:
- base_path = utils.get_quacc_home() / "plots"
-
- fig, ax = plt.subplots()
- ax.set_aspect("auto")
- ax.grid()
- ax.set_aspect("equal")
-
- NUM_COLORS = len(data)
- cm = plt.get_cmap("tab10")
- if NUM_COLORS > 10:
- cm = plt.get_cmap("tab20")
- cy = cycler(
- color=[cm(i) for i in range(NUM_COLORS)],
- marker=_get_markers(NUM_COLORS),
- )
-
- reference = np.array(reference)
- x_ticks = np.unique(reference)
- x_ticks.sort()
-
- for deltas, _cy in zip(data, cy):
- ax.plot(
- reference,
- deltas,
- color=_cy["color"],
- linestyle="None",
- marker=_cy["marker"],
- markersize=3,
- zorder=2,
- alpha=0.25,
- )
-
- # ensure limits are equal for both axes
- _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
- _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
- ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
-
- for method, deltas, _cy in zip(columns, data, cy):
- slope, interc = np.polyfit(reference, deltas, 1)
- y_lr = np.array([slope * x + interc for x in _lims])
- ax.plot(
- _lims,
- y_lr,
- label=method,
- color=_cy["color"],
- linestyle="-",
- markersize="0",
- zorder=1,
- )
-
- # plot reference line
- ax.plot(
- _lims,
- _lims,
- color="black",
- linestyle="--",
- markersize=0,
- zorder=1,
- )
-
- ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
-
- if legend:
- ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
-
- if return_fig:
- return fig
-
- output_path = base_path / f"{title}.png"
- fig.savefig(output_path, bbox_inches="tight")
- return output_path
-
-
-def plot_shift(
- shift_prevs,
- columns,
- data,
- *,
- counts=None,
- pos_class=1,
- metric="acc",
- name="default",
- train_prev=None,
- legend=True,
- return_fig=False,
- base_path=None,
-) -> Path:
- if train_prev is not None:
- t_prev_pos = int(round(train_prev[pos_class] * 100))
- title = f"shift_{name}_{t_prev_pos}_{metric}"
- else:
- title = f"shift_{name}_avg_{metric}"
-
- if base_path is None:
- base_path = utils.get_quacc_home() / "plots"
-
- fig, ax = plt.subplots()
- ax.set_aspect("auto")
- ax.grid()
-
- NUM_COLORS = len(data)
- cm = plt.get_cmap("tab10")
- if NUM_COLORS > 10:
- cm = plt.get_cmap("tab20")
- cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
-
- shift_prevs = shift_prevs[:, pos_class]
- for method, shifts, _cy in zip(columns, data, cy):
- ax.plot(
- shift_prevs,
- shifts,
- label=method,
- color=_cy["color"],
- linestyle="-",
- marker="o",
- markersize=3,
- zorder=2,
- )
- if counts is not None:
- _col_idx = np.where(columns == method)[0]
- count = counts[_col_idx].flatten()
- for prev, shift, cnt in zip(shift_prevs, shifts, count):
- label = f"{cnt}"
- plt.annotate(
- label,
- (prev, shift),
- textcoords="offset points",
- xytext=(0, 10),
- ha="center",
- color=_cy["color"],
- fontsize=12.0,
- )
-
- ax.set(xlabel="dataset shift", ylabel=metric, title=title)
-
- if legend:
- ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
-
- if return_fig:
- return fig
-
- output_path = base_path / f"{title}.png"
- fig.savefig(output_path, bbox_inches="tight")
-
- return output_path
diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py
new file mode 100644
index 0000000..074c277
--- /dev/null
+++ b/quacc/plot/plotly.py
@@ -0,0 +1,201 @@
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import plotly
+import plotly.graph_objects as go
+
+from quacc.plot.base import BasePlot
+
+
+class PlotlyPlot(BasePlot):
+ __themes = defaultdict(
+ lambda: {
+ "template": "seaborn",
+ }
+ )
+ __themes = __themes | {
+ "dark": {
+ "template": "plotly_dark",
+ },
+ }
+
+ def __init__(self, theme=None):
+ self.theme = PlotlyPlot.__themes[theme]
+
+ def hex_to_rgb(self, hex: str, t: float | None = None):
+ hex = hex.lstrip("#")
+ rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
+ if t is not None:
+ rgb.append(t)
+ return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
+
+ def get_colors(self, num):
+ match num:
+ case v if v > 10:
+ __colors = plotly.colors.qualitative.Light24
+ case _:
+ __colors = plotly.colors.qualitative.Plotly
+
+ def __generator(cs):
+ while True:
+ for c in cs:
+ yield c
+
+ return __generator(__colors)
+
+ def update_layout(self, fig, title, x_label, y_label):
+ fig.update_layout(
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ template=self.theme["template"],
+ )
+
+ def save_fig(self, fig, base_path, title) -> Path:
+ return None
+
+ def plot_delta(
+ self,
+ base_prevs,
+ columns,
+ data,
+ *,
+ stdevs=None,
+ pos_class=1,
+ title="default",
+ x_label="prevs.",
+ y_label="error",
+ legend=True,
+ ) -> go.Figure:
+ fig = go.Figure()
+ x = base_prevs[:, pos_class]
+ line_colors = self.get_colors(len(columns))
+ for name, delta in zip(columns, data):
+ color = next(line_colors)
+ _line = [
+ go.Scatter(
+ x=x,
+ y=delta,
+ mode="lines+markers",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color)),
+ hovertemplate="prev.: %{x}
error: %{y:,.4f}",
+ )
+ ]
+ _error = []
+ if stdevs is not None:
+ _col_idx = np.where(columns == name)[0]
+ stdev = stdevs[_col_idx].flatten()
+ _error = [
+ go.Scatter(
+ x=np.concatenate([x, x[::-1]]),
+ y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
+ name=int(_col_idx[0]),
+ fill="toself",
+ fillcolor=self.hex_to_rgb(color, t=0.2),
+ line=dict(color="rgba(255, 255, 255, 0)"),
+ hoverinfo="skip",
+ showlegend=False,
+ )
+ ]
+ fig.add_traces(_line + _error)
+
+ self.update_layout(fig, title, x_label, y_label)
+ return fig
+
+ def plot_diagonal(
+ self,
+ reference,
+ columns,
+ data,
+ *,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ) -> go.Figure:
+ fig = go.Figure()
+ x = reference
+ line_colors = self.get_colors(len(columns))
+
+ _edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
+ _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
+
+ for name, val in zip(columns, data):
+ color = next(line_colors)
+ slope, interc = np.polyfit(x, val, 1)
+ y_lr = np.array([slope * _x + interc for _x in _lims[0]])
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=x,
+ y=val,
+ customdata=np.stack((val - x,), axis=-1),
+ mode="markers",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color, t=0.5)),
+ hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}",
+ ),
+ go.Scatter(
+ x=_lims[0],
+ y=y_lr,
+ mode="lines",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color), width=3),
+ showlegend=False,
+ ),
+ ]
+ )
+ fig.add_trace(
+ go.Scatter(
+ x=_lims[0],
+ y=_lims[1],
+ mode="lines",
+ name="reference",
+ showlegend=False,
+ line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
+ )
+ )
+
+ self.update_layout(fig, title, x_label, y_label)
+ fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
+ return fig
+
+ def plot_shift(
+ self,
+ shift_prevs,
+ columns,
+ data,
+ *,
+ counts=None,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ) -> go.Figure:
+ fig = go.Figure()
+ x = shift_prevs[:, pos_class]
+ line_colors = self.get_colors(len(columns))
+ for name, delta in zip(columns, data):
+ col_idx = (columns == name).nonzero()[0][0]
+ color = next(line_colors)
+ fig.add_trace(
+ go.Scatter(
+ x=x,
+ y=delta,
+ customdata=np.stack((counts[col_idx],), axis=-1),
+ mode="lines+markers",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color)),
+ hovertemplate="shift: %{x}
error: %{y}"
+ + "
count: %{customdata[0]}"
+ if counts is not None
+ else "",
+ )
+ )
+
+ self.update_layout(fig, title, x_label, y_label)
+ return fig