diff --git a/qcpanel/util.py b/qcpanel/util.py index 3820d71..f92a57a 100644 --- a/qcpanel/util.py +++ b/qcpanel/util.py @@ -52,7 +52,7 @@ def create_plots( metric=metric, estimators=estimators, conf="panel", - return_fig=True, + save_fig=False, ) return ( pn.pane.Matplotlib( @@ -87,7 +87,7 @@ def create_plots( metric=metric, estimators=estimators, conf="panel", - return_fig=True, + save_fig=False, ) return ( pn.pane.Matplotlib( diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index e636980..70a43d6 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -6,7 +6,7 @@ from typing import List, Tuple import numpy as np import pandas as pd -from quacc import plot +import quacc.plot as plot from quacc.utils import fmt_line_md @@ -214,16 +214,17 @@ class CompReport: def get_plots( self, - mode="delta", + mode="delta_train", metric="acc", estimators=None, conf="default", - return_fig=False, + save_fig=True, base_path=None, + backend=None, ) -> List[Tuple[str, Path]]: if mode == "delta_train": avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) - if avg_data.empty is True: + if avg_data.empty: return None return plot.plot_delta( @@ -233,8 +234,9 @@ class CompReport: metric=metric, name=conf, train_prev=self.train_prev, - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "stdev_train": avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) @@ -250,8 +252,9 @@ class CompReport: name=conf, train_prev=self.train_prev, stdevs=st_data.T.to_numpy(), - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "diagonal": f_data = self.data(metric=metric + "_score", estimators=estimators) @@ -267,8 +270,9 @@ class CompReport: metric=metric, name=conf, train_prev=self.train_prev, - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "shift": _shift_data = 
self.shift_data(metric=metric, estimators=estimators) @@ -289,8 +293,9 @@ class CompReport: name=conf, train_prev=self.train_prev, counts=shift_counts.T.to_numpy(), - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) def to_md( @@ -322,11 +327,12 @@ class CompReport: plot_modes = [m for m in modes if not m.endswith("table")] for mode in plot_modes: res += f"### {mode}\n" - op = self.get_plots( + _, op = self.get_plots( mode=mode, metric=metric, estimators=estimators, conf=conf, + save_fig=True, base_path=plot_path, ) res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n" @@ -423,8 +429,9 @@ class DatasetReport: metric="acc", estimators=None, conf="default", - return_fig=False, + save_fig=True, base_path=None, + backend=None, ): if mode == "delta_train": _data = self.data(metric, estimators) if data is None else data @@ -440,8 +447,9 @@ class DatasetReport: name=conf, train_prev=None, avg="train", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "stdev_train": _data = self.data(metric, estimators) if data is None else data @@ -459,8 +467,9 @@ class DatasetReport: train_prev=None, stdevs=stdev_on_train.T.to_numpy(), avg="train", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "delta_test": _data = self.data(metric, estimators) if data is None else data @@ -474,8 +483,9 @@ class DatasetReport: name=conf, train_prev=None, avg="test", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "stdev_test": _data = self.data(metric, estimators) if data is None else data @@ -491,8 +501,9 @@ class DatasetReport: train_prev=None, stdevs=stdev_on_test.T.to_numpy(), avg="test", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "shift": _shift_data = self.shift_data(metric, estimators) if data is None else data @@ -507,8 +518,9 @@ 
class DatasetReport: name=conf, train_prev=None, counts=count_shift.T.to_numpy(), - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) def to_md( @@ -544,24 +556,26 @@ class DatasetReport: res += avg_on_train_tbl.to_html() + "\n\n" if "delta_train" in dr_modes: - delta_op = self.get_plots( + _, delta_op = self.get_plots( data=_data, mode="delta_train", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" if "stdev_train" in dr_modes: - delta_stdev_op = self.get_plots( + _, delta_stdev_op = self.get_plots( data=_data, mode="stdev_train", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" @@ -574,24 +588,26 @@ class DatasetReport: res += avg_on_test_tbl.to_html() + "\n\n" if "delta_test" in dr_modes: - delta_op = self.get_plots( + _, delta_op = self.get_plots( data=_data, mode="delta_test", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" if "stdev_test" in dr_modes: - delta_stdev_op = self.get_plots( + _, delta_stdev_op = self.get_plots( data=_data, mode="stdev_test", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" @@ -604,13 +620,14 @@ class DatasetReport: res += shift_on_train_tbl.to_html() + "\n\n" if "shift" in dr_modes: - shift_op = self.get_plots( + _, shift_op = self.get_plots( data=_shift_data, mode="shift", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n" diff --git a/quacc/plot/__init__.py 
b/quacc/plot/__init__.py
new file mode 100644
index 0000000..6a182c5
--- /dev/null
+++ b/quacc/plot/__init__.py
@@ -0,0 +1 @@
+from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift
diff --git a/quacc/plot/base.py b/quacc/plot/base.py
new file mode 100644
index 0000000..36a58b2
--- /dev/null
+++ b/quacc/plot/base.py
@@ -0,0 +1,61 @@
+from pathlib import Path
+
+
+class BasePlot:
+    """Interface implemented by the plotting backends.
+
+    Backends are used through instances and override these hooks as instance
+    methods, so the hooks are declared as instance methods here as well.
+    The default bodies do nothing and return ``None``.
+    """
+
+    def save_fig(self, fig, base_path, title) -> Path:
+        """Hook: save ``fig`` and return the path of the written file."""
+        ...
+
+    def plot_diagonal(
+        self,
+        reference,
+        columns,
+        data,
+        *,
+        pos_class=1,
+        title="default",
+        x_label="true",
+        y_label="estim.",
+        legend=True,
+    ):
+        """Hook: build the diagonal plot and return the figure."""
+        ...
+
+    def plot_delta(
+        self,
+        base_prevs,
+        columns,
+        data,
+        *,
+        stdevs=None,
+        pos_class=1,
+        title="default",
+        x_label="prevs.",
+        y_label="error",
+        legend=True,
+    ):
+        """Hook: build the delta plot and return the figure."""
+        ...
+
+    def plot_shift(
+        self,
+        shift_prevs,
+        columns,
+        data,
+        *,
+        counts=None,
+        pos_class=1,
+        title="default",
+        x_label="true",
+        y_label="estim.",
+        legend=True,
+    ):
+        """Hook: build the shift plot and return the figure."""
+        ...
diff --git a/quacc/plot/mpl.py b/quacc/plot/mpl.py
new file mode 100644
index 0000000..dd84b7a
--- /dev/null
+++ b/quacc/plot/mpl.py
@@ -0,0 +1,250 @@
+from pathlib import Path
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+from cycler import cycler
+
+from quacc import utils
+from quacc.plot.base import BasePlot
+
+matplotlib.use("agg")
+
+
+class MplPlot(BasePlot):
+    """Matplotlib implementation of the plotting backend."""
+
+    def _get_markers(self, n: int):
+        """Return ``n`` marker codes, repeating the base set if needed."""
+        ls = "ovx+sDph*^1234X><.Pd"
+        if n > len(ls):
+            # floor division: a str cannot be repeated a float number of times
+            ls = ls * (n // len(ls) + 1)
+        return list(ls)[:n]
+
+    def save_fig(self, fig, base_path, title) -> Path:
+        """Write ``fig`` to ``<base_path>/<title>.png`` and return that path.
+
+        A ``None`` base_path means the ``plots`` folder of the quacc home.
+        """
+        if base_path is None:
+            base_path = utils.get_quacc_home() / "plots"
+        base_path = Path(base_path)
+        # savefig does not create missing directories
+        base_path.mkdir(parents=True, exist_ok=True)
+        output_path = base_path / f"{title}.png"
+        fig.savefig(output_path, bbox_inches="tight")
+        return output_path
+
+    def plot_delta(
+        self,
+        base_prevs,
+        columns,
+        data,
+        *,
+        stdevs=None,
+        pos_class=1,
+        title="default",
+        x_label="prevs.",
+        y_label="error",
+        legend=True,
+    ):
+        """Plot one line per method against column ``pos_class`` of base_prevs.
+
+        When ``stdevs`` is given, a band of one stdev is shaded around each
+        line, skipping the points where the value or its stdev is NaN.
+        Returns the matplotlib figure.
+        """
+        fig, ax = plt.subplots()
+        ax.set_aspect("auto")
+        ax.grid()
+
+        NUM_COLORS = len(data)
+        cm = plt.get_cmap("tab10")
+        if NUM_COLORS > 10:
+            cm = plt.get_cmap("tab20")
+        cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
+
+        base_prevs = base_prevs[:, pos_class]
+        for method, deltas, _cy in zip(columns, data, cy):
+            ax.plot(
+                base_prevs,
+                deltas,
+                label=method,
+                color=_cy["color"],
+                linestyle="-",
+                marker="o",
+                markersize=3,
+                zorder=2,
+            )
+            if stdevs is not None:
+                _col_idx = np.where(columns == method)[0]
+                stdev = stdevs[_col_idx].flatten()
+                # NaN never compares equal, so ``!= np.nan`` is always true;
+                # np.isnan is required to actually drop the missing points
+                nn_idx = np.intersect1d(
+                    np.where(~np.isnan(deltas))[0],
+                    np.where(~np.isnan(stdev))[0],
+                )
+                _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
+                ax.fill_between(
+                    _bps,
+                    _ds - _st,
+                    _ds + _st,
+                    color=_cy["color"],
+                    alpha=0.25,
+                )
+
+        # x_label already holds the full axis label; do not append to it
+        ax.set(
+            xlabel=x_label,
+            ylabel=y_label,
+            title=title,
+        )
+
+        if legend:
+            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+        return fig
+
+    def plot_diagonal(
+        self,
+        reference,
+        columns,
+        data,
+        *,
+        pos_class=1,
+        title="default",
+        x_label="true",
+        y_label="estim.",
+        legend=True,
+    ):
+        """Scatter each method against ``reference`` with its fitted line.
+
+        Both axes get the same limits and a dashed identity line is drawn
+        for comparison. Returns the matplotlib figure.
+        """
+        fig, ax = plt.subplots()
+        ax.grid()
+        ax.set_aspect("equal")
+
+        NUM_COLORS = len(data)
+        cm = plt.get_cmap("tab10")
+        if NUM_COLORS > 10:
+            cm = plt.get_cmap("tab20")
+        cy = cycler(
+            color=[cm(i) for i in range(NUM_COLORS)],
+            marker=self._get_markers(NUM_COLORS),
+        )
+
+        reference = np.array(reference)
+
+        for deltas, _cy in zip(data, cy):
+            ax.plot(
+                reference,
+                deltas,
+                color=_cy["color"],
+                linestyle="None",
+                marker=_cy["marker"],
+                markersize=3,
+                zorder=2,
+                alpha=0.25,
+            )
+
+        # ensure limits are equal for both axes
+        _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
+        _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
+        ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
+
+        for method, deltas, _cy in zip(columns, data, cy):
+            slope, interc = np.polyfit(reference, deltas, 1)
+            y_lr = np.array([slope * x + interc for x in _lims])
+            ax.plot(
+                _lims,
+                y_lr,
+                label=method,
+                color=_cy["color"],
+                linestyle="-",
+                markersize=0,
+                zorder=1,
+            )
+
+        # plot reference line
+        ax.plot(
+            _lims,
+            _lims,
+            color="black",
+            linestyle="--",
+            markersize=0,
+            zorder=1,
+        )
+
+        ax.set(xlabel=x_label, ylabel=y_label, title=title)
+
+        if legend:
+            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+        return fig
+
+    def plot_shift(
+        self,
+        shift_prevs,
+        columns,
+        data,
+        *,
+        counts=None,
+        pos_class=1,
+        title="default",
+        x_label="true",
+        y_label="estim.",
+        legend=True,
+    ):
+        """Plot one line per method against column ``pos_class`` of shift_prevs.
+
+        When ``counts`` is given, every point is annotated with its count.
+        Returns the matplotlib figure.
+        """
+        fig, ax = plt.subplots()
+        ax.set_aspect("auto")
+        ax.grid()
+
+        NUM_COLORS = len(data)
+        cm = plt.get_cmap("tab10")
+        if NUM_COLORS > 10:
+            cm = plt.get_cmap("tab20")
+        cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
+
+        shift_prevs = shift_prevs[:, pos_class]
+        for method, shifts, _cy in zip(columns, data, cy):
+            ax.plot(
+                shift_prevs,
+                shifts,
+                label=method,
+                color=_cy["color"],
+                linestyle="-",
+                marker="o",
+                markersize=3,
+                zorder=2,
+            )
+            if counts is not None:
+                _col_idx = np.where(columns == method)[0]
+                count = counts[_col_idx].flatten()
+                for prev, shift, cnt in zip(shift_prevs, shifts, count):
+                    label = f"{cnt}"
+                    # annotate this figure's axes, not pyplot's current one
+                    ax.annotate(
+                        label,
+                        (prev, shift),
+                        textcoords="offset points",
+                        xytext=(0, 10),
+                        ha="center",
+                        color=_cy["color"],
+                        fontsize=12.0,
+                    )
+
+        ax.set(xlabel=x_label, ylabel=y_label, title=title)
+
+        if legend:
+            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+        return fig
diff --git a/quacc/plot/plot.py b/quacc/plot/plot.py
new file mode 100644
index 0000000..1bd2369
--- /dev/null
+++ b/quacc/plot/plot.py
@@ -0,0 +1,160 @@
+from quacc.plot.base import BasePlot
+from quacc.plot.mpl import MplPlot
+from quacc.plot.plotly import PlotlyPlot
+
+__backend: BasePlot = MplPlot()
+
+
+def get_backend(be, theme=None):
+    """Return a new backend for ``be``; unknown names fall back to matplotlib.
+
+    ``theme`` is only forwarded to the plotly backend.
+    """
+    match be:
+        case "matplotlib" | "mpl":
+            return MplPlot()
+        case "plotly":
+            return PlotlyPlot(theme=theme)
+        case _:
+            return MplPlot()
+
+
+def plot_delta(
+    base_prevs,
+    columns,
+    data,
+    *,
+    stdevs=None,
+    pos_class=1,
+    metric="acc",
+    name="default",
+    train_prev=None,
+    legend=True,
+    avg=None,
+    save_fig=False,
+    base_path=None,
+    backend=None,
+):
+    """Draw the delta plot with ``backend`` (module default when ``None``).
+
+    Returns the figure, or ``(figure, output_path)`` when ``save_fig`` is set.
+    """
+    backend = __backend if backend is None else backend
+    _base_title = "delta_stdev" if stdevs is not None else "delta"
+    if train_prev is not None:
+        t_prev_pos = int(round(train_prev[pos_class] * 100))
+        title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
+    else:
+        title = f"{_base_title}_{name}_avg_{avg}_{metric}"
+
+    x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
+    y_label = f"{metric} error"
+    fig = backend.plot_delta(
+        base_prevs,
+        columns,
+        data,
+        stdevs=stdevs,
+        pos_class=pos_class,
+        title=title,
+        x_label=x_label,
+        y_label=y_label,
+        legend=legend,
+    )
+
+    if save_fig:
+        output_path = backend.save_fig(fig, base_path, title)
+        return fig, output_path
+
+    return fig
+
+
+def plot_diagonal(
+    reference,
+    columns,
+    data,
+    *,
+    pos_class=1,
+    metric="acc",
+    name="default",
+    train_prev=None,
+    legend=True,
+    save_fig=False,
+    base_path=None,
+    backend=None,
+):
+    """Draw the diagonal plot with ``backend`` (module default when ``None``).
+
+    Returns the figure, or ``(figure, output_path)`` when ``save_fig`` is set.
+    """
+    backend = __backend if backend is None else backend
+    if train_prev is not None:
+        t_prev_pos = int(round(train_prev[pos_class] * 100))
+        title = f"diagonal_{name}_{t_prev_pos}_{metric}"
+    else:
+        title = f"diagonal_{name}_{metric}"
+
+    x_label = f"true {metric}"
+    y_label = f"estim. {metric}"
+    fig = backend.plot_diagonal(
+        reference,
+        columns,
+        data,
+        pos_class=pos_class,
+        title=title,
+        x_label=x_label,
+        y_label=y_label,
+        legend=legend,
+    )
+
+    if save_fig:
+        output_path = backend.save_fig(fig, base_path, title)
+        return fig, output_path
+
+    return fig
+
+
+def plot_shift(
+    shift_prevs,
+    columns,
+    data,
+    *,
+    counts=None,
+    pos_class=1,
+    metric="acc",
+    name="default",
+    train_prev=None,
+    legend=True,
+    save_fig=False,
+    base_path=None,
+    backend=None,
+):
+    """Draw the shift plot with ``backend`` (module default when ``None``).
+
+    Returns the figure, or ``(figure, output_path)`` when ``save_fig`` is set.
+    """
+    backend = __backend if backend is None else backend
+    if train_prev is not None:
+        t_prev_pos = int(round(train_prev[pos_class] * 100))
+        title = f"shift_{name}_{t_prev_pos}_{metric}"
+    else:
+        title = f"shift_{name}_avg_{metric}"
+
+    x_label = "dataset shift"
+    y_label = f"{metric} error"
+    fig = backend.plot_shift(
+        shift_prevs,
+        columns,
+        data,
+        counts=counts,
+        pos_class=pos_class,
+        title=title,
+        x_label=x_label,
+        y_label=y_label,
+        legend=legend,
+    )
+
+    if save_fig:
+        output_path = backend.save_fig(fig, base_path, title)
+        return fig, output_path
+
+    return fig