import os from pathlib import Path import matplotlib.pyplot as plt import numpy as np from cycler import cycler from matplotlib.figure import Figure from quacc.plot.utils import _get_ref_limits from quacc.utils.commons import get_plots_path def _get_markers(num: int): ls = "ovx+sDph*^1234X><.Pd" if num > len(ls): ls = ls * (num / len(ls) + 1) return list(ls)[:num] def _get_cycler(num): cm = plt.get_cmap("tab20") if num > 10 else plt.get_cmap("tab10") return cycler(color=[cm(i) for i in range(num)]) def _save_or_return( fig: Figure, basedir, cls_name, acc_name, dataset_name, plot_type ) -> Figure | None: if basedir is None: return fig plotsubdir = "all" if dataset_name == "*" else dataset_name file = get_plots_path(basedir, cls_name, acc_name, plotsubdir, plot_type) os.makedirs(Path(file).parent, exist_ok=True) fig.savefig(file) def plot_diagonal( method_names: list[str], true_accs: np.ndarray, estim_accs: np.ndarray, cls_name, acc_name, dataset_name, *, basedir=None, ): fig, ax = plt.subplots() ax.grid() ax.set_aspect("equal") cy = _get_cycler(len(method_names)) for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy): ax.plot( x, estim, label=name, color=_cy["color"], linestyle="None", marker="o", markersize=3, zorder=2, alpha=0.25, ) # ensure limits are equal for both axes _lims = _get_ref_limits(true_accs, estim_accs) ax.set(xlim=_lims[0], ylim=_lims[1]) # draw polyfit line per method # for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy): # slope, interc = np.polyfit(x, estim, 1) # y_lr = np.array([slope * x + interc for x in _lims]) # ax.plot( # _lims, # y_lr, # label=name, # color=_cy["color"], # linestyle="-", # markersize="0", # zorder=1, # ) # plot reference line ax.plot( _lims, _lims, color="black", linestyle="--", markersize=0, zorder=1, ) ax.set(xlabel=f"True {acc_name}", ylabel=f"Estimated {acc_name}") ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal") def plot_delta( method_names: list[str], prevs: np.ndarray, acc_errs: np.ndarray, cls_name, acc_name, dataset_name, prev_name, *, stdevs: np.ndarray | None = None, basedir=None, ): fig, ax = plt.subplots() ax.set_aspect("auto") ax.grid() cy = _get_cycler(len(method_names)) x = [str(bp) for bp in prevs] if stdevs is None: stdevs = [None] * len(method_names) for name, delta, stdev, _cy in zip(method_names, acc_errs, stdevs, cy): ax.plot( x, delta, label=name, color=_cy["color"], linestyle="-", marker="", markersize=3, zorder=2, ) if stdev is not None: ax.fill_between( prevs, delta - stdev, delta + stdev, color=_cy["color"], alpha=0.25, ) ax.set( xlabel=f"{prev_name} Prevalence", ylabel=f"Prediction Error for {acc_name}", ) ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) return fig