plots refactoring started
This commit is contained in:
parent
4a06c83256
commit
b8e43c02f2
|
@ -1,7 +1,7 @@
|
|||
from quacc.plot.plot import (
|
||||
from quacc.legacy.plot.plot import (
|
||||
get_backend,
|
||||
plot_delta,
|
||||
plot_diagonal,
|
||||
plot_shift,
|
||||
plot_fit_scores,
|
||||
plot_shift,
|
||||
)
|
||||
|
|
|
@ -1,78 +1,49 @@
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import plotly
|
||||
import plotly.graph_objects as go
|
||||
|
||||
from quacc.evaluation.estimators import CE, _renames
|
||||
from quacc.plot.base import BasePlot
|
||||
|
||||
|
||||
class PlotCfg:
|
||||
def __init__(self, mode, lwidth, font=None, legend=None, template="seaborn"):
|
||||
self.mode = mode
|
||||
self.lwidth = lwidth
|
||||
self.legend = {} if legend is None else legend
|
||||
self.font = {} if font is None else font
|
||||
self.template = template
|
||||
|
||||
|
||||
web_cfg = PlotCfg("lines+markers", 2)
|
||||
png_cfg_old = PlotCfg(
|
||||
"lines",
|
||||
5,
|
||||
legend=dict(
|
||||
orientation="h",
|
||||
yanchor="bottom",
|
||||
xanchor="right",
|
||||
y=1.02,
|
||||
x=1,
|
||||
font=dict(size=24),
|
||||
),
|
||||
font=dict(size=24),
|
||||
# template="ggplot2",
|
||||
)
|
||||
png_cfg = PlotCfg(
|
||||
"lines",
|
||||
5,
|
||||
legend=dict(
|
||||
font=dict(
|
||||
family="DejaVu Sans",
|
||||
size=24,
|
||||
),
|
||||
),
|
||||
font=dict(size=24),
|
||||
# template="ggplot2",
|
||||
)
|
||||
|
||||
_cfg = png_cfg
|
||||
|
||||
|
||||
class PlotlyPlot(BasePlot):
|
||||
__themes = defaultdict(
|
||||
lambda: {
|
||||
"template": _cfg.template,
|
||||
MODE = "lines"
|
||||
L_WIDTH = 5
|
||||
LEGEND = {
|
||||
"font": {
|
||||
"family": "DejaVu Sans",
|
||||
"size": 24,
|
||||
}
|
||||
}
|
||||
FONT = {"size": 24}
|
||||
TEMPLATE = "ggplot2"
|
||||
|
||||
|
||||
def _save_or_return(fig, basedir, dataset_name, measure_name, plot_type):
|
||||
if basedir is not None:
|
||||
plotsubdir = dataset_name
|
||||
os.path.join(basedir, "plots", measure_name, plotsubdir, plot_type + ".svg")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _update_layout(fig, title, x_label, y_label, **kwargs):
|
||||
fig.update_layout(
|
||||
xaxis_title=x_label,
|
||||
yaxis_title=y_label,
|
||||
template=TEMPLATE,
|
||||
font=FONT,
|
||||
legend=LEGEND,
|
||||
**kwargs,
|
||||
)
|
||||
__themes = __themes | {
|
||||
"dark": {
|
||||
"template": "plotly_dark",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, theme=None):
|
||||
self.theme = PlotlyPlot.__themes[theme]
|
||||
self.rename = True
|
||||
|
||||
def hex_to_rgb(self, hex: str, t: float | None = None):
|
||||
def _hex_to_rgb(self, hex: str, t: float | None = None):
|
||||
hex = hex.lstrip("#")
|
||||
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
|
||||
if t is not None:
|
||||
rgb.append(t)
|
||||
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
|
||||
|
||||
def get_colors(self, num):
|
||||
|
||||
def _get_colors(self, num):
|
||||
match num:
|
||||
case v if v > 10:
|
||||
__colors = plotly.colors.qualitative.Light24
|
||||
|
@ -86,156 +57,45 @@ class PlotlyPlot(BasePlot):
|
|||
|
||||
return __generator(__colors)
|
||||
|
||||
def update_layout(self, fig, title, x_label, y_label):
|
||||
fig.update_layout(
|
||||
# title=title,
|
||||
xaxis_title=x_label,
|
||||
yaxis_title=y_label,
|
||||
template=self.theme["template"],
|
||||
font=_cfg.font,
|
||||
legend=_cfg.legend,
|
||||
)
|
||||
|
||||
def save_fig(self, fig, base_path, title) -> Path:
|
||||
return None
|
||||
def _get_ref_limits(true_accs: np.ndarray, estim_accs: dict[str, np.ndarray]):
|
||||
"""get lmits of reference line"""
|
||||
|
||||
def rename_plots(
|
||||
self,
|
||||
columns,
|
||||
):
|
||||
if not self.rename:
|
||||
return columns
|
||||
|
||||
new_columns = []
|
||||
for c in columns:
|
||||
nc = c
|
||||
for old, new in _renames.items():
|
||||
if c.startswith(old):
|
||||
nc = new + c[len(old) :]
|
||||
|
||||
new_columns.append(nc)
|
||||
|
||||
return np.array(new_columns)
|
||||
|
||||
def plot_delta(
|
||||
self,
|
||||
base_prevs,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
stdevs=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="prevs.",
|
||||
y_label="error",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
if isinstance(base_prevs[0], float):
|
||||
base_prevs = np.around([(1 - bp, bp) for bp in base_prevs], decimals=4)
|
||||
x = [str(tuple(bp)) for bp in base_prevs]
|
||||
named_data = {c: d for c, d in zip(columns, data)}
|
||||
r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
|
||||
line_colors = self.get_colors(len(columns))
|
||||
# for name, delta in zip(columns, data):
|
||||
columns = np.array(CE.name.sort(columns))
|
||||
for name in columns:
|
||||
delta = named_data[name]
|
||||
r_name = r_columns[name]
|
||||
color = next(line_colors)
|
||||
_line = [
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=delta,
|
||||
mode=_cfg.mode,
|
||||
name=r_name,
|
||||
line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth),
|
||||
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
|
||||
)
|
||||
]
|
||||
_error = []
|
||||
if stdevs is not None:
|
||||
_col_idx = np.where(columns == name)[0]
|
||||
stdev = stdevs[_col_idx].flatten()
|
||||
_error = [
|
||||
go.Scatter(
|
||||
x=np.concatenate([x, x[::-1]]),
|
||||
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
|
||||
name=int(_col_idx[0]),
|
||||
fill="toself",
|
||||
fillcolor=self.hex_to_rgb(color, t=0.2),
|
||||
line=dict(color="rgba(255, 255, 255, 0)"),
|
||||
hoverinfo="skip",
|
||||
showlegend=False,
|
||||
)
|
||||
]
|
||||
fig.add_traces(_line + _error)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
return fig
|
||||
|
||||
def plot_diagonal(
|
||||
self,
|
||||
reference,
|
||||
columns,
|
||||
data,
|
||||
*,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
fixed_lim=False,
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = reference
|
||||
line_colors = self.get_colors(len(columns))
|
||||
|
||||
if fixed_lim:
|
||||
_lims = np.array([[0.0, 1.0], [0.0, 1.0]])
|
||||
else:
|
||||
_edges = (
|
||||
np.min([np.min(x), np.min(data)]),
|
||||
np.max([np.max(x), np.max(data)]),
|
||||
np.min([np.min(true_accs), np.min(estim_accs)]),
|
||||
np.max([np.max(true_accs), np.max(estim_accs)]),
|
||||
)
|
||||
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
|
||||
|
||||
named_data = {c: d for c, d in zip(columns, data)}
|
||||
r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
|
||||
columns = np.array(CE.name.sort(columns))
|
||||
for name in columns:
|
||||
val = named_data[name]
|
||||
r_name = r_columns[name]
|
||||
|
||||
def plot_diagonal(
|
||||
method_names,
|
||||
true_accs,
|
||||
estim_accs,
|
||||
*,
|
||||
measure_name="vanilla_accuracy",
|
||||
dataset_name=None,
|
||||
basedir=None,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = true_accs
|
||||
line_colors = _get_colors(len(method_names))
|
||||
_lims = _get_ref_limits(true_accs, estim_accs)
|
||||
|
||||
for name, estim in zip(method_names, estim_accs):
|
||||
color = next(line_colors)
|
||||
slope, interc = np.polyfit(x, val, 1)
|
||||
# y_lr = np.array([slope * _x + interc for _x in _lims[0]])
|
||||
slope, interc = np.polyfit(x, estim, 1)
|
||||
fig.add_traces(
|
||||
[
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=val,
|
||||
customdata=np.stack((val - x,), axis=-1),
|
||||
y=estim,
|
||||
customdata=np.stack((estim - x,), axis=-1),
|
||||
mode="markers",
|
||||
name=r_name,
|
||||
marker=dict(color=self.hex_to_rgb(color, t=0.5)),
|
||||
name=name,
|
||||
marker=dict(color=_hex_to_rgb(color, t=0.5)),
|
||||
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
|
||||
# showlegend=False,
|
||||
),
|
||||
# go.Scatter(
|
||||
# x=[x[-1]],
|
||||
# y=[val[-1]],
|
||||
# mode="markers",
|
||||
# marker=dict(color=self.hex_to_rgb(color), size=8),
|
||||
# name=r_name,
|
||||
# ),
|
||||
# go.Scatter(
|
||||
# x=_lims[0],
|
||||
# y=y_lr,
|
||||
# mode="lines",
|
||||
# name=name,
|
||||
# line=dict(color=self.hex_to_rgb(color), width=3),
|
||||
# showlegend=False,
|
||||
# ),
|
||||
]
|
||||
)
|
||||
fig.add_trace(
|
||||
|
@ -245,12 +105,14 @@ class PlotlyPlot(BasePlot):
|
|||
mode="lines",
|
||||
name="reference",
|
||||
showlegend=False,
|
||||
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
|
||||
line=dict(color=_hex_to_rgb("#000000"), dash="dash"),
|
||||
)
|
||||
)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
fig.update_layout(
|
||||
_update_layout(
|
||||
fig,
|
||||
x_label=f"True {measure_name}",
|
||||
y_label=f"Estimated {measure_name}",
|
||||
autosize=False,
|
||||
width=1300,
|
||||
height=1000,
|
||||
|
@ -258,73 +120,98 @@ class PlotlyPlot(BasePlot):
|
|||
yaxis_scaleratio=1.0,
|
||||
yaxis_range=[-0.1, 1.1],
|
||||
)
|
||||
return fig
|
||||
return _save_or_return(fig, basedir, dataset_name, measure_name, "diagonal")
|
||||
|
||||
def plot_shift(
|
||||
self,
|
||||
shift_prevs,
|
||||
columns,
|
||||
data,
|
||||
|
||||
def plot_delta(
|
||||
method_names: list[str],
|
||||
prevs: np.ndarray,
|
||||
acc_errs: np.ndarray,
|
||||
*,
|
||||
counts=None,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="true",
|
||||
y_label="estim.",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
stdevs: np.ndarray | None = None,
|
||||
prev_name="Test",
|
||||
measure_name="Vanilla Accuracy",
|
||||
dataset_name=None,
|
||||
basedir=None,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
# x = shift_prevs[:, pos_class]
|
||||
x = shift_prevs
|
||||
line_colors = self.get_colors(len(columns))
|
||||
named_data = {c: d for c, d in zip(columns, data)}
|
||||
r_columns = {c: r for c, r in zip(columns, self.rename_plots(columns))}
|
||||
columns = np.array(CE.name.sort(columns))
|
||||
for name in columns:
|
||||
delta = named_data[name]
|
||||
r_name = r_columns[name]
|
||||
col_idx = (columns == name).nonzero()[0][0]
|
||||
x = [str(bp) for bp in prevs]
|
||||
line_colors = _get_colors(len(method_names))
|
||||
if stdevs is None:
|
||||
stdevs = [None] * len(method_names)
|
||||
for name, delta, stdev in zip(method_names, acc_errs, stdevs):
|
||||
color = next(line_colors)
|
||||
_line = [
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=delta,
|
||||
mode=MODE,
|
||||
name=name,
|
||||
line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
|
||||
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
|
||||
)
|
||||
]
|
||||
_error = []
|
||||
if stdev is not None:
|
||||
_error = [
|
||||
go.Scatter(
|
||||
x=np.concatenate([x, x[::-1]]),
|
||||
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
|
||||
name=name,
|
||||
fill="toself",
|
||||
fillcolor=_hex_to_rgb(color, t=0.2),
|
||||
line=dict(color="rgba(255, 255, 255, 0)"),
|
||||
hoverinfo="skip",
|
||||
showlegend=False,
|
||||
)
|
||||
]
|
||||
fig.add_traces(_line + _error)
|
||||
|
||||
_update_layout(
|
||||
fig,
|
||||
x_label=f"{prev_name} Prevalence",
|
||||
y_label=f"Prediction Error for {measure_name}",
|
||||
)
|
||||
return _save_or_return(
|
||||
fig, basedir, dataset_name, measure_name, "delta" if stdevs is None else "stdev"
|
||||
)
|
||||
|
||||
|
||||
def plot_shift(
|
||||
method_names: list[str],
|
||||
prevs: np.ndarray,
|
||||
acc_errs: np.ndarray,
|
||||
*,
|
||||
counts: np.ndarray | None = None,
|
||||
measure_name="Vanilla Accuracy",
|
||||
dataset_name=None,
|
||||
basedir=None,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
x = prevs
|
||||
line_colors = _get_colors(len(method_names))
|
||||
if counts is None:
|
||||
counts = [None] * len(method_names)
|
||||
for name, delta, count in zip(method_names, acc_errs, counts):
|
||||
color = next(line_colors)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=delta,
|
||||
customdata=np.stack((counts[col_idx],), axis=-1),
|
||||
mode=_cfg.mode,
|
||||
name=r_name,
|
||||
line=dict(color=self.hex_to_rgb(color), width=_cfg.lwidth),
|
||||
customdata=np.stack((count,), axis=-1),
|
||||
mode=MODE,
|
||||
name=name,
|
||||
line=dict(color=_hex_to_rgb(color), width=L_WIDTH),
|
||||
hovertemplate="shift: %{x}<br>error: %{y}"
|
||||
+ "<br>count: %{customdata[0]}"
|
||||
if counts is not None
|
||||
if count is not None
|
||||
else "",
|
||||
)
|
||||
)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
return fig
|
||||
|
||||
def plot_fit_scores(
|
||||
self,
|
||||
train_prevs,
|
||||
scores,
|
||||
*,
|
||||
pos_class=1,
|
||||
title="default",
|
||||
x_label="prev.",
|
||||
y_label="position",
|
||||
legend=True,
|
||||
) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
# x = train_prevs
|
||||
x = [str(tuple(bp)) for bp in train_prevs]
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=scores,
|
||||
mode="lines+markers",
|
||||
showlegend=False,
|
||||
),
|
||||
_update_layout(
|
||||
fig,
|
||||
x_label="Amount of Prior Probability Shift",
|
||||
y_label=f"Prediction Error for {measure_name}",
|
||||
)
|
||||
|
||||
self.update_layout(fig, title, x_label, y_label)
|
||||
return fig
|
||||
return _save_or_return(fig, basedir, dataset_name, measure_name, "shift")
|
||||
|
|
Loading…
Reference in New Issue