plot refactored

This commit is contained in:
Lorenzo Volpi 2023-11-29 03:55:38 +01:00
parent deeb522ccb
commit f0bfb2e039
6 changed files with 460 additions and 22 deletions

View File

@ -52,7 +52,7 @@ def create_plots(
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
save_fig=False,
)
return (
pn.pane.Matplotlib(
@ -87,7 +87,7 @@ def create_plots(
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
save_fig=False,
)
return (
pn.pane.Matplotlib(

View File

@ -6,7 +6,7 @@ from typing import List, Tuple
import numpy as np
import pandas as pd
from quacc import plot
import quacc.plot as plot
from quacc.utils import fmt_line_md
@ -214,16 +214,17 @@ class CompReport:
def get_plots(
self,
mode="delta",
mode="delta_train",
metric="acc",
estimators=None,
conf="default",
return_fig=False,
save_fig=True,
base_path=None,
backend=None,
) -> List[Tuple[str, Path]]:
if mode == "delta_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
if avg_data.empty is True:
if avg_data.empty:
return None
return plot.plot_delta(
@ -233,8 +234,9 @@ class CompReport:
metric=metric,
name=conf,
train_prev=self.train_prev,
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "stdev_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
@ -250,8 +252,9 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
stdevs=st_data.T.to_numpy(),
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "diagonal":
f_data = self.data(metric=metric + "_score", estimators=estimators)
@ -267,8 +270,9 @@ class CompReport:
metric=metric,
name=conf,
train_prev=self.train_prev,
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "shift":
_shift_data = self.shift_data(metric=metric, estimators=estimators)
@ -289,8 +293,9 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
counts=shift_counts.T.to_numpy(),
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
def to_md(
@ -322,11 +327,12 @@ class CompReport:
plot_modes = [m for m in modes if not m.endswith("table")]
for mode in plot_modes:
res += f"### {mode}\n"
op = self.get_plots(
_, op = self.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf=conf,
save_fig=True,
base_path=plot_path,
)
res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n"
@ -423,8 +429,9 @@ class DatasetReport:
metric="acc",
estimators=None,
conf="default",
return_fig=False,
save_fig=True,
base_path=None,
backend=None,
):
if mode == "delta_train":
_data = self.data(metric, estimators) if data is None else data
@ -440,8 +447,9 @@ class DatasetReport:
name=conf,
train_prev=None,
avg="train",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "stdev_train":
_data = self.data(metric, estimators) if data is None else data
@ -459,8 +467,9 @@ class DatasetReport:
train_prev=None,
stdevs=stdev_on_train.T.to_numpy(),
avg="train",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "delta_test":
_data = self.data(metric, estimators) if data is None else data
@ -474,8 +483,9 @@ class DatasetReport:
name=conf,
train_prev=None,
avg="test",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "stdev_test":
_data = self.data(metric, estimators) if data is None else data
@ -491,8 +501,9 @@ class DatasetReport:
train_prev=None,
stdevs=stdev_on_test.T.to_numpy(),
avg="test",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "shift":
_shift_data = self.shift_data(metric, estimators) if data is None else data
@ -507,8 +518,9 @@ class DatasetReport:
name=conf,
train_prev=None,
counts=count_shift.T.to_numpy(),
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
def to_md(
@ -544,24 +556,26 @@ class DatasetReport:
res += avg_on_train_tbl.to_html() + "\n\n"
if "delta_train" in dr_modes:
delta_op = self.get_plots(
_, delta_op = self.get_plots(
data=_data,
mode="delta_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
if "stdev_train" in dr_modes:
delta_stdev_op = self.get_plots(
_, delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@ -574,24 +588,26 @@ class DatasetReport:
res += avg_on_test_tbl.to_html() + "\n\n"
if "delta_test" in dr_modes:
delta_op = self.get_plots(
_, delta_op = self.get_plots(
data=_data,
mode="delta_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
if "stdev_test" in dr_modes:
delta_stdev_op = self.get_plots(
_, delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@ -604,13 +620,14 @@ class DatasetReport:
res += shift_on_train_tbl.to_html() + "\n\n"
if "shift" in dr_modes:
shift_op = self.get_plots(
_, shift_op = self.get_plots(
data=_shift_data,
mode="shift",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n"

1
quacc/plot/__init__.py Normal file
View File

@ -0,0 +1 @@
from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift

54
quacc/plot/base.py Normal file
View File

@ -0,0 +1,54 @@
from pathlib import Path
class BasePlot:
@classmethod
def save_fig(cls, fig, base_path, title) -> Path:
...
@classmethod
def plot_diagonal(
cls,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...
@classmethod
def plot_delta(
cls,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
...
@classmethod
def plot_shift(
cls,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...

222
quacc/plot/mpl.py Normal file
View File

@ -0,0 +1,222 @@
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from quacc import utils
from quacc.plot.base import BasePlot
matplotlib.use("agg")
class MplPlot(BasePlot):
def _get_markers(self, n: int):
ls = "ovx+sDph*^1234X><.Pd"
if n > len(ls):
ls = ls * (n / len(ls) + 1)
return list(ls)[:n]
def save_fig(self, fig, base_path, title) -> Path:
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_delta(
self,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class]
for method, deltas, _cy in zip(columns, data, cy):
ax.plot(
base_prevs,
deltas,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if stdevs is not None:
_col_idx = np.where(columns == method)[0]
stdev = stdevs[_col_idx].flatten()
nn_idx = np.intersect1d(
np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0],
)
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between(
_bps,
_ds - _st,
_ds + _st,
color=_cy["color"],
alpha=0.25,
)
ax.set(
xlabel=f"{x_label} prevalence",
ylabel=y_label,
title=title,
)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_diagonal(
self,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
ax.set_aspect("equal")
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(
color=[cm(i) for i in range(NUM_COLORS)],
marker=self._get_markers(NUM_COLORS),
)
reference = np.array(reference)
x_ticks = np.unique(reference)
x_ticks.sort()
for deltas, _cy in zip(data, cy):
ax.plot(
reference,
deltas,
color=_cy["color"],
linestyle="None",
marker=_cy["marker"],
markersize=3,
zorder=2,
alpha=0.25,
)
# ensure limits are equal for both axes
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
for method, deltas, _cy in zip(columns, data, cy):
slope, interc = np.polyfit(reference, deltas, 1)
y_lr = np.array([slope * x + interc for x in _lims])
ax.plot(
_lims,
y_lr,
label=method,
color=_cy["color"],
linestyle="-",
markersize="0",
zorder=1,
)
# plot reference line
ax.plot(
_lims,
_lims,
color="black",
linestyle="--",
markersize=0,
zorder=1,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_shift(
self,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class]
for method, shifts, _cy in zip(columns, data, cy):
ax.plot(
shift_prevs,
shifts,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if counts is not None:
_col_idx = np.where(columns == method)[0]
count = counts[_col_idx].flatten()
for prev, shift, cnt in zip(shift_prevs, shifts, count):
label = f"{cnt}"
plt.annotate(
label,
(prev, shift),
textcoords="offset points",
xytext=(0, 10),
ha="center",
color=_cy["color"],
fontsize=12.0,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig

144
quacc/plot/plot.py Normal file
View File

@ -0,0 +1,144 @@
from quacc.plot.base import BasePlot
from quacc.plot.mpl import MplPlot
from quacc.plot.plotly import PlotlyPlot
__backend: BasePlot = MplPlot()
def get_backend(be, theme=None):
match be:
case "matplotlib" | "mpl":
return MplPlot()
case "plotly":
return PlotlyPlot(theme=theme)
case _:
return MplPlot()
def plot_delta(
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
avg=None,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
else:
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
y_label = f"{metric} error"
fig = backend.plot_delta(
base_prevs,
columns,
data,
stdevs=stdevs,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_diagonal(
reference,
columns,
data,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
else:
title = f"diagonal_{name}_{metric}"
x_label = f"true {metric}"
y_label = f"estim. {metric}"
fig = backend.plot_diagonal(
reference,
columns,
data,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_shift(
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"shift_{name}_{t_prev_pos}_{metric}"
else:
title = f"shift_{name}_avg_{metric}"
x_label = "dataset shift"
y_label = f"{metric} error"
fig = backend.plot_shift(
shift_prevs,
columns,
data,
counts=counts,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig