150 lines
3.5 KiB
Python
150 lines
3.5 KiB
Python
import os
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from cycler import cycler
|
|
from matplotlib.figure import Figure
|
|
|
|
from quacc.plot.utils import _get_ref_limits
|
|
from quacc.utils.commons import get_plots_path
|
|
|
|
|
|
def _get_markers(num: int):
|
|
ls = "ovx+sDph*^1234X><.Pd"
|
|
if num > len(ls):
|
|
ls = ls * (num / len(ls) + 1)
|
|
return list(ls)[:num]
|
|
|
|
|
|
def _get_cycler(num):
    """Build a color ``cycler`` with one color entry per plotted series.

    Colors come from the ``tab10`` colormap when at most 10 series are
    requested, and from ``tab20`` otherwise; entry ``i`` is the colormap
    evaluated at integer index ``i``.

    NOTE(review): for more than 20 series the integer lookups run past the
    palette; presumably matplotlib clamps them to its "over" color, so
    colors would repeat — confirm if many methods are ever plotted.
    """
    palette_name = "tab10" if num <= 10 else "tab20"
    colormap = plt.get_cmap(palette_name)
    colors = list(map(colormap, range(num)))
    return cycler(color=colors)
|
|
|
|
|
|
def _save_or_return(
    fig: Figure, basedir, cls_name, acc_name, dataset_name, plot_type
) -> Figure | None:
    """Return ``fig`` for interactive use, or save it when ``basedir`` is set.

    Args:
        fig: the figure to return or save.
        basedir: root directory for saved plots; ``None`` means "do not save,
            hand the figure back to the caller instead".
        cls_name, acc_name, dataset_name, plot_type: forwarded to
            ``get_plots_path`` to build the output file path. A
            ``dataset_name`` of ``"*"`` is passed on as ``"all"``.

    Returns:
        ``fig`` when ``basedir`` is ``None``; otherwise ``None`` after the
        figure has been written to disk and closed.
    """
    if basedir is None:
        return fig

    plotsubdir = "all" if dataset_name == "*" else dataset_name
    file = get_plots_path(basedir, cls_name, acc_name, plotsubdir, plot_type)
    Path(file).parent.mkdir(parents=True, exist_ok=True)
    # The plotting functions anchor their legend outside the axes
    # (bbox_to_anchor=(1, 0.5)); "tight" grows the saved canvas so that
    # legend is not clipped off the image.
    fig.savefig(file, bbox_inches="tight")
    # The caller never receives this figure, so release it from pyplot's
    # figure manager; otherwise every saved plot stays open in memory.
    plt.close(fig)
    return None
|
|
|
|
|
|
def plot_diagonal(
    method_names: list[str],
    true_accs: np.ndarray,
    estim_accs: np.ndarray,
    cls_name,
    acc_name,
    dataset_name,
    *,
    basedir=None,
):
    """Scatter estimated vs. true accuracy per method, with a y = x reference.

    Points on the dashed diagonal are perfect estimates; points above it
    over-estimate and points below it under-estimate the true value.

    Args:
        method_names: one legend label per method.
        true_accs: true accuracy values, iterated in lockstep with
            ``method_names`` (presumably one row per method — confirm with
            callers).
        estim_accs: estimated accuracy values aligned with ``true_accs``.
        cls_name: used only to build the save path.
        acc_name: accuracy-measure name shown in the axis labels and used
            in the save path.
        dataset_name: used only to build the save path.
        basedir: if given, the figure is saved under this directory and
            ``None`` is returned; otherwise the figure is returned.

    Returns:
        The ``Figure`` when ``basedir`` is ``None``, else ``None``.
    """
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_aspect("equal")

    cy = _get_cycler(len(method_names))

    for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy):
        ax.plot(
            x,
            estim,
            label=name,
            color=_cy["color"],
            linestyle="None",
            marker="o",
            markersize=3,
            zorder=2,
            alpha=0.25,
        )

    # ensure limits are equal for both axes
    _lims = _get_ref_limits(true_accs, estim_accs)
    ax.set(xlim=_lims[0], ylim=_lims[1])

    # Reference line y = x. Passing the 2x2 limits as ``ax.plot(_lims, _lims)``
    # makes matplotlib draw one data set per *column*; with equal x/y limits
    # each column collapses to a single point and no line is visible. Draw a
    # single explicit segment spanning both limit ranges instead.
    ref_lo = min(_lims[0][0], _lims[1][0])
    ref_hi = max(_lims[0][1], _lims[1][1])
    ax.plot(
        [ref_lo, ref_hi],
        [ref_lo, ref_hi],
        color="black",
        linestyle="--",
        markersize=0,
        zorder=1,
    )

    ax.set(xlabel=f"True {acc_name}", ylabel=f"Estimated {acc_name}")

    # legend sits to the right of the axes, vertically centred
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
|
|
|
|
|
|
def plot_delta(
    method_names: list[str],
    prevs: np.ndarray,
    acc_errs: np.ndarray,
    cls_name,
    acc_name,
    dataset_name,
    prev_name,
    *,
    stdevs: np.ndarray | None = None,
    basedir=None,
):
    """Plot each method's accuracy-prediction error as a function of prevalence.

    Args:
        method_names: one legend label per method.
        prevs: prevalence values; each is rendered via ``str()`` as a
            categorical x tick.
        acc_errs: per-method error values, iterated in lockstep with
            ``method_names`` (presumably one row per method, aligned with
            ``prevs`` — confirm with callers).
        cls_name: used only to build the save path.
        acc_name: accuracy-measure name shown in the y label and save path.
        dataset_name: used only to build the save path.
        prev_name: label prefix for the x axis; also distinguishes the saved
            file name.
        stdevs: optional per-method spreads aligned with ``acc_errs``; when
            present a shaded band of ``error ± stdev`` is drawn.
        basedir: if given, the figure is saved under this directory and
            ``None`` is returned; otherwise the figure is returned.

    Returns:
        The ``Figure`` when ``basedir`` is ``None``, else ``None``.
    """
    fig, ax = plt.subplots()
    ax.set_aspect("auto")
    ax.grid()

    cy = _get_cycler(len(method_names))

    # prevalences are drawn as categorical (string) x positions
    x = [str(bp) for bp in prevs]
    if stdevs is None:
        stdevs = [None] * len(method_names)
    for name, delta, stdev, _cy in zip(method_names, acc_errs, stdevs, cy):
        ax.plot(
            x,
            delta,
            label=name,
            color=_cy["color"],
            linestyle="-",
            marker="",
            markersize=3,
            zorder=2,
        )
        if stdev is not None:
            delta = np.asarray(delta)
            stdev = np.asarray(stdev)
            # Use the same categorical x values as the line: after the
            # string-valued plot above the x axis has a category converter,
            # so the numeric ``prevs`` would fail to convert or be placed
            # at positions unrelated to the categories.
            ax.fill_between(
                x,
                delta - stdev,
                delta + stdev,
                color=_cy["color"],
                alpha=0.25,
            )

    ax.set(
        xlabel=f"{prev_name} Prevalence",
        ylabel=f"Prediction Error for {acc_name}",
    )

    # legend sits to the right of the axes, vertically centred
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    # Honour ``basedir`` like plot_diagonal does (it was previously ignored).
    # prev_name is part of the plot type so plots for different prevalence
    # kinds of the same dataset do not overwrite each other.
    return _save_or_return(
        fig, basedir, cls_name, acc_name, dataset_name, f"delta_{prev_name}"
    )
|