QuAcc/quacc/plot/matplotlib.py

150 lines
3.5 KiB
Python

import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from matplotlib.figure import Figure
from quacc.plot.utils import _get_ref_limits
from quacc.utils.commons import get_plots_path
def _get_markers(num: int):
ls = "ovx+sDph*^1234X><.Pd"
if num > len(ls):
ls = ls * (num / len(ls) + 1)
return list(ls)[:num]
def _get_cycler(num):
    """Build a color cycler with ``num`` entries from a qualitative colormap.

    ``tab10`` is used for up to 10 entries and ``tab20`` for more.

    Args:
        num: number of colors to generate.

    Returns:
        A ``cycler`` over the ``color`` property with ``num`` RGBA entries.
    """
    cm = plt.get_cmap("tab20") if num > 10 else plt.get_cmap("tab10")
    # Wrap indices around the colormap size: an integer index >= cm.N maps to
    # the colormap's single "over" color, which would give every entry past
    # the palette length the same color. Wrapping repeats the palette instead.
    return cycler(color=[cm(i % cm.N) for i in range(num)])
def _save_or_return(
    fig: Figure, basedir, cls_name, acc_name, dataset_name, plot_type
) -> Figure | None:
    """Return ``fig`` unchanged, or save it to disk and close it.

    Args:
        fig: the figure to return or save.
        basedir: output root. When ``None`` nothing is written and ``fig`` is
            returned.
        cls_name, acc_name, plot_type: forwarded to ``get_plots_path`` to
            build the output file path.
        dataset_name: forwarded as the plot subdirectory; the wildcard ``"*"``
            is mapped to ``"all"``.

    Returns:
        ``fig`` when ``basedir`` is None, otherwise ``None`` after saving.
    """
    if basedir is None:
        return fig

    plotsubdir = "all" if dataset_name == "*" else dataset_name
    file = get_plots_path(basedir, cls_name, acc_name, plotsubdir, plot_type)
    os.makedirs(Path(file).parent, exist_ok=True)
    fig.savefig(file)
    # pyplot keeps every figure it created alive until it is explicitly
    # closed; without this, batch plotting accumulates open figures.
    plt.close(fig)
    return None
def plot_diagonal(
    method_names: list[str],
    true_accs: np.ndarray,
    estim_accs: np.ndarray,
    cls_name,
    acc_name,
    dataset_name,
    *,
    basedir=None,
):
    """Scatter estimated vs. true accuracy per method, with a y = x reference.

    Args:
        method_names: one legend label per method.
        true_accs: true accuracy values, iterated along the first axis in
            lockstep with ``method_names`` (one entry per method).
        estim_accs: estimated accuracy values, aligned with ``true_accs``.
        cls_name: forwarded to the output path when saving.
        acc_name: forwarded to the output path; also used in the axis labels.
        dataset_name: forwarded to the output path when saving.
        basedir: when ``None`` the figure is returned; otherwise it is saved
            under ``basedir`` and ``None`` is returned.

    Returns:
        The matplotlib Figure when ``basedir`` is None, else None.
    """
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_aspect("equal")

    cy = _get_cycler(len(method_names))
    for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy):
        # translucent, small markers so dense clouds remain readable
        ax.plot(
            x,
            estim,
            label=name,
            color=_cy["color"],
            linestyle="None",
            marker="o",
            markersize=3,
            zorder=2,
            alpha=0.25,
        )

    # ensure limits are equal for both axes; _lims[0] / _lims[1] are used as
    # the (low, high) x- and y-limit pairs
    _lims = _get_ref_limits(true_accs, estim_accs)
    ax.set(xlim=_lims[0], ylim=_lims[1])

    # Plot the y = x reference line across the x range. Passing the whole
    # 2-level `_lims` as both x and y would make matplotlib draw one segment
    # per column, i.e. (low, low)->(low, low) and (high, high)->(high, high),
    # which are zero-length, so no diagonal would be visible.
    ref = _lims[0]
    ax.plot(
        ref,
        ref,
        color="black",
        linestyle="--",
        markersize=0,
        zorder=1,
    )

    ax.set(xlabel=f"True {acc_name}", ylabel=f"Estimated {acc_name}")
    # legend outside the axes so it never covers data points
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
def plot_delta(
    method_names: list[str],
    prevs: np.ndarray,
    acc_errs: np.ndarray,
    cls_name,
    acc_name,
    dataset_name,
    prev_name,
    *,
    stdevs: np.ndarray | None = None,
    basedir=None,
):
    """Plot each method's accuracy-prediction error as a line over prevalence.

    Args:
        method_names: one legend label per method.
        prevs: prevalence values; each is rendered with ``str`` and used as a
            categorical x-axis tick.
        acc_errs: per-method error series, iterated in lockstep with
            ``method_names``; each series is aligned with ``prevs``.
        cls_name: forwarded to the output path when saving.
        acc_name: forwarded to the output path; also used in the y label.
        dataset_name: forwarded to the output path when saving.
        prev_name: prefix for the x-axis label ("<prev_name> Prevalence").
        stdevs: optional per-method spreads; when given, a translucent band
            ``delta +/- stdev`` is drawn around each line.
        basedir: when ``None`` the figure is returned; otherwise it is saved
            under ``basedir`` and ``None`` is returned.

    Returns:
        The matplotlib Figure when ``basedir`` is None, else None.
    """
    fig, ax = plt.subplots()
    ax.set_aspect("auto")
    ax.grid()

    cy = _get_cycler(len(method_names))
    # prevalences are shown as categorical (string) ticks
    x = [str(bp) for bp in prevs]
    if stdevs is None:
        stdevs = [None] * len(method_names)

    for name, delta, stdev, _cy in zip(method_names, acc_errs, stdevs, cy):
        ax.plot(
            x,
            delta,
            label=name,
            color=_cy["color"],
            linestyle="-",
            marker="",
            markersize=3,
            zorder=2,
        )
        if stdev is not None:
            # Use the same categorical x as the line. Numeric `prevs` would be
            # placed at their raw values instead of the category positions
            # 0..n-1, so the band would not line up with the line.
            ax.fill_between(
                x,
                delta - stdev,
                delta + stdev,
                color=_cy["color"],
                alpha=0.25,
            )

    ax.set(
        xlabel=f"{prev_name} Prevalence",
        ylabel=f"Prediction Error for {acc_name}",
    )
    # legend outside the axes so it never covers the lines
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    # Honour `basedir` like plot_diagonal does (it was previously ignored).
    # NOTE(review): "delta" is assumed as the plot-type name passed to
    # get_plots_path — confirm it matches the expected output layout.
    return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "delta")