diff --git a/quacc/plot/matplotlib.py b/quacc/plot/matplotlib.py new file mode 100644 index 0000000..150324d --- /dev/null +++ b/quacc/plot/matplotlib.py @@ -0,0 +1,149 @@ +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +from cycler import cycler +from matplotlib.figure import Figure + +from quacc.plot.utils import _get_ref_limits +from quacc.utils.commons import get_plots_path + + +def _get_markers(num: int): + ls = "ovx+sDph*^1234X><.Pd" + if num > len(ls): + ls = ls * (num / len(ls) + 1) + return list(ls)[:num] + + +def _get_cycler(num): + cm = plt.get_cmap("tab20") if num > 10 else plt.get_cmap("tab10") + return cycler(color=[cm(i) for i in range(num)]) + + +def _save_or_return( + fig: Figure, basedir, cls_name, acc_name, dataset_name, plot_type +) -> Figure | None: + if basedir is None: + return fig + + plotsubdir = "all" if dataset_name == "*" else dataset_name + file = get_plots_path(basedir, cls_name, acc_name, plotsubdir, plot_type) + os.makedirs(Path(file).parent, exist_ok=True) + fig.savefig(file) + + +def plot_diagonal( + method_names: list[str], + true_accs: np.ndarray, + estim_accs: np.ndarray, + cls_name, + acc_name, + dataset_name, + *, + basedir=None, +): + fig, ax = plt.subplots() + ax.grid() + ax.set_aspect("equal") + + cy = _get_cycler(len(method_names)) + + for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy): + ax.plot( + x, + estim, + label=name, + color=_cy["color"], + linestyle="None", + marker="o", + markersize=3, + zorder=2, + alpha=0.25, + ) + + # ensure limits are equal for both axes + _lims = _get_ref_limits(true_accs, estim_accs) + ax.set(xlim=_lims[0], ylim=_lims[1]) + + # draw polyfit line per method + # for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy): + # slope, interc = np.polyfit(x, estim, 1) + # y_lr = np.array([slope * x + interc for x in _lims]) + # ax.plot( + # _lims, + # y_lr, + # label=name, + # color=_cy["color"], + # linestyle="-", + # markersize="0", + # zorder=1, + # ) + + # plot reference line + ax.plot( + _lims, + _lims, + color="black", + linestyle="--", + markersize=0, + zorder=1, + ) + + ax.set(xlabel=f"True {acc_name}", ylabel=f"Estimated {acc_name}") + + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal") + + +def plot_delta( + method_names: list[str], + prevs: np.ndarray, + acc_errs: np.ndarray, + cls_name, + acc_name, + dataset_name, + prev_name, + *, + stdevs: np.ndarray | None = None, + basedir=None, +): + fig, ax = plt.subplots() + ax.set_aspect("auto") + ax.grid() + + cy = _get_cycler(len(method_names)) + + x = [str(bp) for bp in prevs] + if stdevs is None: + stdevs = [None] * len(method_names) + for name, delta, stdev, _cy in zip(method_names, acc_errs, stdevs, cy): + ax.plot( + x, + delta, + label=name, + color=_cy["color"], + linestyle="-", + marker="", + markersize=3, + zorder=2, + ) + if stdev is not None: + ax.fill_between( + prevs, + delta - stdev, + delta + stdev, + color=_cy["color"], + alpha=0.25, + ) + + ax.set( + xlabel=f"{prev_name} Prevalence", + ylabel=f"Prediction Error for {acc_name}", + ) + + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + return fig diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py index 33d1365..7d9a58d 100644 --- a/quacc/plot/plotly.py +++ b/quacc/plot/plotly.py @@ -1,10 +1,8 @@ -from pathlib import Path - import numpy as np import plotly import plotly.graph_objects as go -from quacc.utils.commons import get_plots_path +from quacc.plot.utils import _get_ref_limits MODE = "lines" L_WIDTH = 5 @@ -18,17 +16,7 @@ FONT = {"size": 24} TEMPLATE = "ggplot2" -def _save_or_return( - fig: go.Figure, basedir, cls_name, acc_name, dataset_name, plot_type -) -> go.Figure | None: - if basedir is None: - return fig - - path = get_plots_path(basedir, cls_name, acc_name, dataset_name, plot_type) - fig.write_image(path) - - -def _update_layout(fig, title, x_label, y_label, **kwargs): +def _update_layout(fig, x_label, y_label, **kwargs): fig.update_layout( xaxis_title=x_label, yaxis_title=y_label, @@ -39,7 +27,7 @@ def _update_layout(fig, title, x_label, y_label, **kwargs): ) -def _hex_to_rgb(self, hex: str, t: float | None = None): +def _hex_to_rgb(hex: str, t: float | None = None): hex = hex.lstrip("#") rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]] if t is not None: @@ -47,7 +35,7 @@ def _hex_to_rgb(self, hex: str, t: float | None = None): return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}" -def _get_colors(self, num): +def _get_colors(num): match num: case v if v > 10: __colors = plotly.colors.qualitative.Light24 @@ -62,16 +50,6 @@ def _get_colors(self, num): return __generator(__colors) -def _get_ref_limits(true_accs: np.ndarray, estim_accs: dict[str, np.ndarray]): - """get lmits of reference line""" - - _edges = ( - np.min([np.min(true_accs), np.min(estim_accs)]), - np.max([np.max(true_accs), np.max(estim_accs)]), - ) - _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) - - def plot_diagonal( method_names, true_accs, @@ -83,11 +61,10 @@ def plot_diagonal( basedir=None, ) -> go.Figure: fig = go.Figure() - x = true_accs line_colors = _get_colors(len(method_names)) _lims = _get_ref_limits(true_accs, estim_accs) - for name, estim in zip(method_names, estim_accs): + for name, x, estim in zip(method_names, true_accs, estim_accs): color = next(line_colors) slope, interc = np.polyfit(x, estim, 1) fig.add_traces( @@ -125,7 +102,8 @@ def plot_diagonal( yaxis_scaleratio=1.0, yaxis_range=[-0.1, 1.1], ) - return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal") + # return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal") + return fig def plot_delta( @@ -133,7 +111,7 @@ def plot_delta( prevs: np.ndarray, acc_errs: np.ndarray, cls_name, - acc_mame, + acc_name, dataset_name, prev_name, *, @@ -176,16 +154,17 @@ def plot_delta( _update_layout( fig, x_label=f"{prev_name} Prevalence", - y_label=f"Prediction Error for {acc_mame}", - ) - return _save_or_return( - fig, - basedir, - cls_name, - acc_mame, - dataset_name, - "delta" if stdevs is None else "stdev", + y_label=f"Prediction Error for {acc_name}", ) + # return _save_or_return( + # fig, + # basedir, + # cls_name, + # acc_mame, + # dataset_name, + # "delta" if stdevs is None else "stdev", + # ) + return fig def plot_shift( @@ -226,4 +205,5 @@ def plot_shift( x_label="Amount of Prior Probability Shift", y_label=f"Prediction Error for {acc_name}", ) - return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift") + # return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift") + return fig diff --git a/quacc/plot/utils.py b/quacc/plot/utils.py new file mode 100644 index 0000000..292df1b --- /dev/null +++ b/quacc/plot/utils.py @@ -0,0 +1,15 @@ +import numpy as np +import plotly.graph_objects as go + +from quacc.utils.commons import get_plots_path + + +def _get_ref_limits(true_accs: np.ndarray, estim_accs: np.ndarray): + """get lmits of reference line""" + + _edges = ( + np.min([np.min(true_accs), np.min(estim_accs)]), + np.max([np.max(true_accs), np.max(estim_accs)]), + ) + _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) + return _lims