matplotlib added, shift missing, plotly updated
This commit is contained in:
parent
8a087e3e2f
commit
ad9fdef786
|
@ -0,0 +1,149 @@
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from cycler import cycler
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
from quacc.plot.utils import _get_ref_limits
|
||||||
|
from quacc.utils.commons import get_plots_path
|
||||||
|
|
||||||
|
|
||||||
|
def _get_markers(num: int):
|
||||||
|
ls = "ovx+sDph*^1234X><.Pd"
|
||||||
|
if num > len(ls):
|
||||||
|
ls = ls * (num / len(ls) + 1)
|
||||||
|
return list(ls)[:num]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cycler(num):
|
||||||
|
cm = plt.get_cmap("tab20") if num > 10 else plt.get_cmap("tab10")
|
||||||
|
return cycler(color=[cm(i) for i in range(num)])
|
||||||
|
|
||||||
|
|
||||||
|
def _save_or_return(
|
||||||
|
fig: Figure, basedir, cls_name, acc_name, dataset_name, plot_type
|
||||||
|
) -> Figure | None:
|
||||||
|
if basedir is None:
|
||||||
|
return fig
|
||||||
|
|
||||||
|
plotsubdir = "all" if dataset_name == "*" else dataset_name
|
||||||
|
file = get_plots_path(basedir, cls_name, acc_name, plotsubdir, plot_type)
|
||||||
|
os.makedirs(Path(file).parent, exist_ok=True)
|
||||||
|
fig.savefig(file)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_diagonal(
|
||||||
|
method_names: list[str],
|
||||||
|
true_accs: np.ndarray,
|
||||||
|
estim_accs: np.ndarray,
|
||||||
|
cls_name,
|
||||||
|
acc_name,
|
||||||
|
dataset_name,
|
||||||
|
*,
|
||||||
|
basedir=None,
|
||||||
|
):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.grid()
|
||||||
|
ax.set_aspect("equal")
|
||||||
|
|
||||||
|
cy = _get_cycler(len(method_names))
|
||||||
|
|
||||||
|
for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy):
|
||||||
|
ax.plot(
|
||||||
|
x,
|
||||||
|
estim,
|
||||||
|
label=name,
|
||||||
|
color=_cy["color"],
|
||||||
|
linestyle="None",
|
||||||
|
marker="o",
|
||||||
|
markersize=3,
|
||||||
|
zorder=2,
|
||||||
|
alpha=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure limits are equal for both axes
|
||||||
|
_lims = _get_ref_limits(true_accs, estim_accs)
|
||||||
|
ax.set(xlim=_lims[0], ylim=_lims[1])
|
||||||
|
|
||||||
|
# draw polyfit line per method
|
||||||
|
# for name, x, estim, _cy in zip(method_names, true_accs, estim_accs, cy):
|
||||||
|
# slope, interc = np.polyfit(x, estim, 1)
|
||||||
|
# y_lr = np.array([slope * x + interc for x in _lims])
|
||||||
|
# ax.plot(
|
||||||
|
# _lims,
|
||||||
|
# y_lr,
|
||||||
|
# label=name,
|
||||||
|
# color=_cy["color"],
|
||||||
|
# linestyle="-",
|
||||||
|
# markersize="0",
|
||||||
|
# zorder=1,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# plot reference line
|
||||||
|
ax.plot(
|
||||||
|
_lims,
|
||||||
|
_lims,
|
||||||
|
color="black",
|
||||||
|
linestyle="--",
|
||||||
|
markersize=0,
|
||||||
|
zorder=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set(xlabel=f"True {acc_name}", ylabel=f"Estimated {acc_name}")
|
||||||
|
|
||||||
|
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||||
|
|
||||||
|
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_delta(
|
||||||
|
method_names: list[str],
|
||||||
|
prevs: np.ndarray,
|
||||||
|
acc_errs: np.ndarray,
|
||||||
|
cls_name,
|
||||||
|
acc_name,
|
||||||
|
dataset_name,
|
||||||
|
prev_name,
|
||||||
|
*,
|
||||||
|
stdevs: np.ndarray | None = None,
|
||||||
|
basedir=None,
|
||||||
|
):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.set_aspect("auto")
|
||||||
|
ax.grid()
|
||||||
|
|
||||||
|
cy = _get_cycler(len(method_names))
|
||||||
|
|
||||||
|
x = [str(bp) for bp in prevs]
|
||||||
|
if stdevs is None:
|
||||||
|
stdevs = [None] * len(method_names)
|
||||||
|
for name, delta, stdev, _cy in zip(method_names, acc_errs, stdevs, cy):
|
||||||
|
ax.plot(
|
||||||
|
x,
|
||||||
|
delta,
|
||||||
|
label=name,
|
||||||
|
color=_cy["color"],
|
||||||
|
linestyle="-",
|
||||||
|
marker="",
|
||||||
|
markersize=3,
|
||||||
|
zorder=2,
|
||||||
|
)
|
||||||
|
if stdev is not None:
|
||||||
|
ax.fill_between(
|
||||||
|
prevs,
|
||||||
|
delta - stdev,
|
||||||
|
delta + stdev,
|
||||||
|
color=_cy["color"],
|
||||||
|
alpha=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set(
|
||||||
|
xlabel=f"{prev_name} Prevalence",
|
||||||
|
ylabel=f"Prediction Error for {acc_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||||
|
|
||||||
|
return fig
|
|
@ -1,10 +1,8 @@
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import plotly
|
import plotly
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
from quacc.utils.commons import get_plots_path
|
from quacc.plot.utils import _get_ref_limits
|
||||||
|
|
||||||
MODE = "lines"
|
MODE = "lines"
|
||||||
L_WIDTH = 5
|
L_WIDTH = 5
|
||||||
|
@ -18,17 +16,7 @@ FONT = {"size": 24}
|
||||||
TEMPLATE = "ggplot2"
|
TEMPLATE = "ggplot2"
|
||||||
|
|
||||||
|
|
||||||
def _save_or_return(
|
def _update_layout(fig, x_label, y_label, **kwargs):
|
||||||
fig: go.Figure, basedir, cls_name, acc_name, dataset_name, plot_type
|
|
||||||
) -> go.Figure | None:
|
|
||||||
if basedir is None:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
path = get_plots_path(basedir, cls_name, acc_name, dataset_name, plot_type)
|
|
||||||
fig.write_image(path)
|
|
||||||
|
|
||||||
|
|
||||||
def _update_layout(fig, title, x_label, y_label, **kwargs):
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
xaxis_title=x_label,
|
xaxis_title=x_label,
|
||||||
yaxis_title=y_label,
|
yaxis_title=y_label,
|
||||||
|
@ -39,7 +27,7 @@ def _update_layout(fig, title, x_label, y_label, **kwargs):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _hex_to_rgb(self, hex: str, t: float | None = None):
|
def _hex_to_rgb(hex: str, t: float | None = None):
|
||||||
hex = hex.lstrip("#")
|
hex = hex.lstrip("#")
|
||||||
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
|
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
|
||||||
if t is not None:
|
if t is not None:
|
||||||
|
@ -47,7 +35,7 @@ def _hex_to_rgb(self, hex: str, t: float | None = None):
|
||||||
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
|
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
|
||||||
|
|
||||||
|
|
||||||
def _get_colors(self, num):
|
def _get_colors(num):
|
||||||
match num:
|
match num:
|
||||||
case v if v > 10:
|
case v if v > 10:
|
||||||
__colors = plotly.colors.qualitative.Light24
|
__colors = plotly.colors.qualitative.Light24
|
||||||
|
@ -62,16 +50,6 @@ def _get_colors(self, num):
|
||||||
return __generator(__colors)
|
return __generator(__colors)
|
||||||
|
|
||||||
|
|
||||||
def _get_ref_limits(true_accs: np.ndarray, estim_accs: dict[str, np.ndarray]):
|
|
||||||
"""get lmits of reference line"""
|
|
||||||
|
|
||||||
_edges = (
|
|
||||||
np.min([np.min(true_accs), np.min(estim_accs)]),
|
|
||||||
np.max([np.max(true_accs), np.max(estim_accs)]),
|
|
||||||
)
|
|
||||||
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
|
|
||||||
|
|
||||||
|
|
||||||
def plot_diagonal(
|
def plot_diagonal(
|
||||||
method_names,
|
method_names,
|
||||||
true_accs,
|
true_accs,
|
||||||
|
@ -83,11 +61,10 @@ def plot_diagonal(
|
||||||
basedir=None,
|
basedir=None,
|
||||||
) -> go.Figure:
|
) -> go.Figure:
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
x = true_accs
|
|
||||||
line_colors = _get_colors(len(method_names))
|
line_colors = _get_colors(len(method_names))
|
||||||
_lims = _get_ref_limits(true_accs, estim_accs)
|
_lims = _get_ref_limits(true_accs, estim_accs)
|
||||||
|
|
||||||
for name, estim in zip(method_names, estim_accs):
|
for name, x, estim in zip(method_names, true_accs, estim_accs):
|
||||||
color = next(line_colors)
|
color = next(line_colors)
|
||||||
slope, interc = np.polyfit(x, estim, 1)
|
slope, interc = np.polyfit(x, estim, 1)
|
||||||
fig.add_traces(
|
fig.add_traces(
|
||||||
|
@ -125,7 +102,8 @@ def plot_diagonal(
|
||||||
yaxis_scaleratio=1.0,
|
yaxis_scaleratio=1.0,
|
||||||
yaxis_range=[-0.1, 1.1],
|
yaxis_range=[-0.1, 1.1],
|
||||||
)
|
)
|
||||||
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
|
# return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal")
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def plot_delta(
|
def plot_delta(
|
||||||
|
@ -133,7 +111,7 @@ def plot_delta(
|
||||||
prevs: np.ndarray,
|
prevs: np.ndarray,
|
||||||
acc_errs: np.ndarray,
|
acc_errs: np.ndarray,
|
||||||
cls_name,
|
cls_name,
|
||||||
acc_mame,
|
acc_name,
|
||||||
dataset_name,
|
dataset_name,
|
||||||
prev_name,
|
prev_name,
|
||||||
*,
|
*,
|
||||||
|
@ -176,16 +154,17 @@ def plot_delta(
|
||||||
_update_layout(
|
_update_layout(
|
||||||
fig,
|
fig,
|
||||||
x_label=f"{prev_name} Prevalence",
|
x_label=f"{prev_name} Prevalence",
|
||||||
y_label=f"Prediction Error for {acc_mame}",
|
y_label=f"Prediction Error for {acc_name}",
|
||||||
)
|
|
||||||
return _save_or_return(
|
|
||||||
fig,
|
|
||||||
basedir,
|
|
||||||
cls_name,
|
|
||||||
acc_mame,
|
|
||||||
dataset_name,
|
|
||||||
"delta" if stdevs is None else "stdev",
|
|
||||||
)
|
)
|
||||||
|
# return _save_or_return(
|
||||||
|
# fig,
|
||||||
|
# basedir,
|
||||||
|
# cls_name,
|
||||||
|
# acc_mame,
|
||||||
|
# dataset_name,
|
||||||
|
# "delta" if stdevs is None else "stdev",
|
||||||
|
# )
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def plot_shift(
|
def plot_shift(
|
||||||
|
@ -226,4 +205,5 @@ def plot_shift(
|
||||||
x_label="Amount of Prior Probability Shift",
|
x_label="Amount of Prior Probability Shift",
|
||||||
y_label=f"Prediction Error for {acc_name}",
|
y_label=f"Prediction Error for {acc_name}",
|
||||||
)
|
)
|
||||||
return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift")
|
# return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift")
|
||||||
|
return fig
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
import numpy as np
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
|
from quacc.utils.commons import get_plots_path
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ref_limits(true_accs: np.ndarray, estim_accs: np.ndarray):
|
||||||
|
"""get lmits of reference line"""
|
||||||
|
|
||||||
|
_edges = (
|
||||||
|
np.min([np.min(true_accs), np.min(estim_accs)]),
|
||||||
|
np.max([np.max(true_accs), np.max(estim_accs)]),
|
||||||
|
)
|
||||||
|
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
|
||||||
|
return _lims
|
Loading…
Reference in New Issue