matplotlib added, shift missing, plotly updated

This commit is contained in:
Lorenzo Volpi 2024-04-08 17:58:56 +02:00
parent 8a087e3e2f
commit ad9fdef786
3 changed files with 184 additions and 40 deletions

149
quacc/plot/matplotlib.py Normal file
View File

@ -0,0 +1,149 @@
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from matplotlib.figure import Figure
from quacc.plot.utils import _get_ref_limits
from quacc.utils.commons import get_plots_path
def _get_markers(num: int):
ls = "ovx+sDph*^1234X><.Pd"
if num > len(ls):
ls = ls * (num / len(ls) + 1)
return list(ls)[:num]
def _get_cycler(num):
cm = plt.get_cmap("tab20") if num > 10 else plt.get_cmap("tab10")
return cycler(color=[cm(i) for i in range(num)])
def _save_or_return(
fig: Figure, basedir, cls_name, acc_name, dataset_name, plot_type
) -> Figure | None:
if basedir is None:
return fig
plotsubdir = "all" if dataset_name == "*" else dataset_name
file = get_plots_path(basedir, cls_name, acc_name, plotsubdir, plot_type)
os.makedirs(Path(file).parent, exist_ok=True)
fig.savefig(file)
def plot_diagonal(
method_names: list[str],
true_accs: np.ndarray,
estim_accs: np.ndarray,
cls_name,
acc_name,
dataset_name,
*,
basedir=None,
):
fig, ax = plt.subplots()
ax.grid()
ax.set_aspect("equal")
cy = _get_cycler(len(method_names))
for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy):
ax.plot(
x,
estim,
label=name,
color=_cy["color"],
linestyle="None",
marker="o",
markersize=3,
zorder=2,
alpha=0.25,
)
# ensure limits are equal for both axes
_lims = _get_ref_limits(true_accs, estim_accs)
ax.set(xlim=_lims[0], ylim=_lims[1])
# draw polyfit line per method
# for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy):
# slope, interc = np.polyfit(x, estim, 1)
# y_lr = np.array([slope * x + interc for x in _lims])
# ax.plot(
# _lims,
# y_lr,
# label=name,
# color=_cy["color"],
# linestyle="-",
# markersize="0",
# zorder=1,
# )
# plot reference line
ax.plot(
_lims,
_lims,
color="black",
linestyle="--",
markersize=0,
zorder=1,
)
ax.set(xlabel=f"True {acc_name}", ylabel=f"Estimated {acc_name}")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
def plot_delta(
method_names: list[str],
prevs: np.ndarray,
acc_errs: np.ndarray,
cls_name,
acc_name,
dataset_name,
prev_name,
*,
stdevs: np.ndarray | None = None,
basedir=None,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
cy = _get_cycler(len(method_names))
x = [str(bp) for bp in prevs]
if stdevs is None:
stdevs = [None] * len(method_names)
for name, delta, stdev, _cy in zip(method_names, acc_errs, stdevs, cy):
ax.plot(
x,
delta,
label=name,
color=_cy["color"],
linestyle="-",
marker="",
markersize=3,
zorder=2,
)
if stdev is not None:
ax.fill_between(
prevs,
delta - stdev,
delta + stdev,
color=_cy["color"],
alpha=0.25,
)
ax.set(
xlabel=f"{prev_name} Prevalence",
ylabel=f"Prediction Error for {acc_name}",
)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig

View File

@ -1,10 +1,8 @@
from pathlib import Path
import numpy as np
import plotly
import plotly.graph_objects as go
from quacc.utils.commons import get_plots_path
from quacc.plot.utils import _get_ref_limits
MODE = "lines"
L_WIDTH = 5
@ -18,17 +16,7 @@ FONT = {"size": 24}
TEMPLATE = "ggplot2"
def _save_or_return(
fig: go.Figure, basedir, cls_name, acc_name, dataset_name, plot_type
) -> go.Figure | None:
if basedir is None:
return fig
path = get_plots_path(basedir, cls_name, acc_name, dataset_name, plot_type)
fig.write_image(path)
def _update_layout(fig, title, x_label, y_label, **kwargs):
def _update_layout(fig, x_label, y_label, **kwargs):
fig.update_layout(
xaxis_title=x_label,
yaxis_title=y_label,
@ -39,7 +27,7 @@ def _update_layout(fig, title, x_label, y_label, **kwargs):
)
def _hex_to_rgb(self, hex: str, t: float | None = None):
def _hex_to_rgb(hex: str, t: float | None = None):
hex = hex.lstrip("#")
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
if t is not None:
@ -47,7 +35,7 @@ def _hex_to_rgb(self, hex: str, t: float | None = None):
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
def _get_colors(self, num):
def _get_colors(num):
match num:
case v if v > 10:
__colors = plotly.colors.qualitative.Light24
@ -62,16 +50,6 @@ def _get_colors(self, num):
return __generator(__colors)
def _get_ref_limits(true_accs: np.ndarray, estim_accs: dict[str, np.ndarray]):
"""get lmits of reference line"""
_edges = (
np.min([np.min(true_accs), np.min(estim_accs)]),
np.max([np.max(true_accs), np.max(estim_accs)]),
)
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
def plot_diagonal(
method_names,
true_accs,
@ -83,11 +61,10 @@ def plot_diagonal(
basedir=None,
) -> go.Figure:
fig = go.Figure()
x = true_accs
line_colors = _get_colors(len(method_names))
_lims = _get_ref_limits(true_accs, estim_accs)
for name, estim in zip(method_names, estim_accs):
for name, x, estim in zip(method_names, true_accs, estim_accs):
color = next(line_colors)
slope, interc = np.polyfit(x, estim, 1)
fig.add_traces(
@ -125,7 +102,8 @@ def plot_diagonal(
yaxis_scaleratio=1.0,
yaxis_range=[-0.1, 1.1],
)
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
# return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
return fig
def plot_delta(
@ -133,7 +111,7 @@ def plot_delta(
prevs: np.ndarray,
acc_errs: np.ndarray,
cls_name,
acc_mame,
acc_name,
dataset_name,
prev_name,
*,
@ -176,16 +154,17 @@ def plot_delta(
_update_layout(
fig,
x_label=f"{prev_name} Prevalence",
y_label=f"Prediction Error for {acc_mame}",
)
return _save_or_return(
fig,
basedir,
cls_name,
acc_mame,
dataset_name,
"delta" if stdevs is None else "stdev",
y_label=f"Prediction Error for {acc_name}",
)
# return _save_or_return(
# fig,
# basedir,
# cls_name,
# acc_mame,
# dataset_name,
# "delta" if stdevs is None else "stdev",
# )
return fig
def plot_shift(
@ -226,4 +205,5 @@ def plot_shift(
x_label="Amount of Prior Probability Shift",
y_label=f"Prediction Error for {acc_name}",
)
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift")
# return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift")
return fig

15
quacc/plot/utils.py Normal file
View File

@ -0,0 +1,15 @@
import numpy as np
import plotly.graph_objects as go
from quacc.utils.commons import get_plots_path
def _get_ref_limits(true_accs: np.ndarray, estim_accs: np.ndarray):
"""get lmits of reference line"""
_edges = (
np.min([np.min(true_accs), np.min(estim_accs)]),
np.max([np.max(true_accs), np.max(estim_accs)]),
)
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
return _lims