plot refactored

This commit is contained in:
Lorenzo Volpi 2023-11-29 03:55:38 +01:00
parent deeb522ccb
commit f0bfb2e039
6 changed files with 460 additions and 22 deletions

View File

@ -52,7 +52,7 @@ def create_plots(
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf="panel", conf="panel",
return_fig=True, save_fig=False,
) )
return ( return (
pn.pane.Matplotlib( pn.pane.Matplotlib(
@ -87,7 +87,7 @@ def create_plots(
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf="panel", conf="panel",
return_fig=True, save_fig=False,
) )
return ( return (
pn.pane.Matplotlib( pn.pane.Matplotlib(

View File

@ -6,7 +6,7 @@ from typing import List, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from quacc import plot import quacc.plot as plot
from quacc.utils import fmt_line_md from quacc.utils import fmt_line_md
@ -214,16 +214,17 @@ class CompReport:
def get_plots( def get_plots(
self, self,
mode="delta", mode="delta_train",
metric="acc", metric="acc",
estimators=None, estimators=None,
conf="default", conf="default",
return_fig=False, save_fig=True,
base_path=None, base_path=None,
backend=None,
) -> List[Tuple[str, Path]]: ) -> List[Tuple[str, Path]]:
if mode == "delta_train": if mode == "delta_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
if avg_data.empty is True: if avg_data.empty:
return None return None
return plot.plot_delta( return plot.plot_delta(
@ -233,8 +234,9 @@ class CompReport:
metric=metric, metric=metric,
name=conf, name=conf,
train_prev=self.train_prev, train_prev=self.train_prev,
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
elif mode == "stdev_train": elif mode == "stdev_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
@ -250,8 +252,9 @@ class CompReport:
name=conf, name=conf,
train_prev=self.train_prev, train_prev=self.train_prev,
stdevs=st_data.T.to_numpy(), stdevs=st_data.T.to_numpy(),
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
elif mode == "diagonal": elif mode == "diagonal":
f_data = self.data(metric=metric + "_score", estimators=estimators) f_data = self.data(metric=metric + "_score", estimators=estimators)
@ -267,8 +270,9 @@ class CompReport:
metric=metric, metric=metric,
name=conf, name=conf,
train_prev=self.train_prev, train_prev=self.train_prev,
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
elif mode == "shift": elif mode == "shift":
_shift_data = self.shift_data(metric=metric, estimators=estimators) _shift_data = self.shift_data(metric=metric, estimators=estimators)
@ -289,8 +293,9 @@ class CompReport:
name=conf, name=conf,
train_prev=self.train_prev, train_prev=self.train_prev,
counts=shift_counts.T.to_numpy(), counts=shift_counts.T.to_numpy(),
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
def to_md( def to_md(
@ -322,11 +327,12 @@ class CompReport:
plot_modes = [m for m in modes if not m.endswith("table")] plot_modes = [m for m in modes if not m.endswith("table")]
for mode in plot_modes: for mode in plot_modes:
res += f"### {mode}\n" res += f"### {mode}\n"
op = self.get_plots( _, op = self.get_plots(
mode=mode, mode=mode,
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf=conf, conf=conf,
save_fig=True,
base_path=plot_path, base_path=plot_path,
) )
res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n" res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n"
@ -423,8 +429,9 @@ class DatasetReport:
metric="acc", metric="acc",
estimators=None, estimators=None,
conf="default", conf="default",
return_fig=False, save_fig=True,
base_path=None, base_path=None,
backend=None,
): ):
if mode == "delta_train": if mode == "delta_train":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
@ -440,8 +447,9 @@ class DatasetReport:
name=conf, name=conf,
train_prev=None, train_prev=None,
avg="train", avg="train",
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
elif mode == "stdev_train": elif mode == "stdev_train":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
@ -459,8 +467,9 @@ class DatasetReport:
train_prev=None, train_prev=None,
stdevs=stdev_on_train.T.to_numpy(), stdevs=stdev_on_train.T.to_numpy(),
avg="train", avg="train",
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
elif mode == "delta_test": elif mode == "delta_test":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
@ -474,8 +483,9 @@ class DatasetReport:
name=conf, name=conf,
train_prev=None, train_prev=None,
avg="test", avg="test",
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
elif mode == "stdev_test": elif mode == "stdev_test":
_data = self.data(metric, estimators) if data is None else data _data = self.data(metric, estimators) if data is None else data
@ -491,8 +501,9 @@ class DatasetReport:
train_prev=None, train_prev=None,
stdevs=stdev_on_test.T.to_numpy(), stdevs=stdev_on_test.T.to_numpy(),
avg="test", avg="test",
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
elif mode == "shift": elif mode == "shift":
_shift_data = self.shift_data(metric, estimators) if data is None else data _shift_data = self.shift_data(metric, estimators) if data is None else data
@ -507,8 +518,9 @@ class DatasetReport:
name=conf, name=conf,
train_prev=None, train_prev=None,
counts=count_shift.T.to_numpy(), counts=count_shift.T.to_numpy(),
return_fig=return_fig, save_fig=save_fig,
base_path=base_path, base_path=base_path,
backend=backend,
) )
def to_md( def to_md(
@ -544,24 +556,26 @@ class DatasetReport:
res += avg_on_train_tbl.to_html() + "\n\n" res += avg_on_train_tbl.to_html() + "\n\n"
if "delta_train" in dr_modes: if "delta_train" in dr_modes:
delta_op = self.get_plots( _, delta_op = self.get_plots(
data=_data, data=_data,
mode="delta_train", mode="delta_train",
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf=conf, conf=conf,
base_path=plot_path, base_path=plot_path,
save_fig=True,
) )
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
if "stdev_train" in dr_modes: if "stdev_train" in dr_modes:
delta_stdev_op = self.get_plots( _, delta_stdev_op = self.get_plots(
data=_data, data=_data,
mode="stdev_train", mode="stdev_train",
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf=conf, conf=conf,
base_path=plot_path, base_path=plot_path,
save_fig=True,
) )
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@ -574,24 +588,26 @@ class DatasetReport:
res += avg_on_test_tbl.to_html() + "\n\n" res += avg_on_test_tbl.to_html() + "\n\n"
if "delta_test" in dr_modes: if "delta_test" in dr_modes:
delta_op = self.get_plots( _, delta_op = self.get_plots(
data=_data, data=_data,
mode="delta_test", mode="delta_test",
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf=conf, conf=conf,
base_path=plot_path, base_path=plot_path,
save_fig=True,
) )
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
if "stdev_test" in dr_modes: if "stdev_test" in dr_modes:
delta_stdev_op = self.get_plots( _, delta_stdev_op = self.get_plots(
data=_data, data=_data,
mode="stdev_test", mode="stdev_test",
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf=conf, conf=conf,
base_path=plot_path, base_path=plot_path,
save_fig=True,
) )
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@ -604,13 +620,14 @@ class DatasetReport:
res += shift_on_train_tbl.to_html() + "\n\n" res += shift_on_train_tbl.to_html() + "\n\n"
if "shift" in dr_modes: if "shift" in dr_modes:
shift_op = self.get_plots( _, shift_op = self.get_plots(
data=_shift_data, data=_shift_data,
mode="shift", mode="shift",
metric=metric, metric=metric,
estimators=estimators, estimators=estimators,
conf=conf, conf=conf,
base_path=plot_path, base_path=plot_path,
save_fig=True,
) )
res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n" res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n"

1
quacc/plot/__init__.py Normal file
View File

@ -0,0 +1 @@
from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift

54
quacc/plot/base.py Normal file
View File

@ -0,0 +1,54 @@
from pathlib import Path
class BasePlot:
@classmethod
def save_fig(cls, fig, base_path, title) -> Path:
...
@classmethod
def plot_diagonal(
cls,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...
@classmethod
def plot_delta(
cls,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
...
@classmethod
def plot_shift(
cls,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...

222
quacc/plot/mpl.py Normal file
View File

@ -0,0 +1,222 @@
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from quacc import utils
from quacc.plot.base import BasePlot
matplotlib.use("agg")
class MplPlot(BasePlot):
def _get_markers(self, n: int):
ls = "ovx+sDph*^1234X><.Pd"
if n > len(ls):
ls = ls * (n / len(ls) + 1)
return list(ls)[:n]
def save_fig(self, fig, base_path, title) -> Path:
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_delta(
self,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class]
for method, deltas, _cy in zip(columns, data, cy):
ax.plot(
base_prevs,
deltas,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if stdevs is not None:
_col_idx = np.where(columns == method)[0]
stdev = stdevs[_col_idx].flatten()
nn_idx = np.intersect1d(
np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0],
)
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between(
_bps,
_ds - _st,
_ds + _st,
color=_cy["color"],
alpha=0.25,
)
ax.set(
xlabel=f"{x_label} prevalence",
ylabel=y_label,
title=title,
)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_diagonal(
self,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
ax.set_aspect("equal")
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(
color=[cm(i) for i in range(NUM_COLORS)],
marker=self._get_markers(NUM_COLORS),
)
reference = np.array(reference)
x_ticks = np.unique(reference)
x_ticks.sort()
for deltas, _cy in zip(data, cy):
ax.plot(
reference,
deltas,
color=_cy["color"],
linestyle="None",
marker=_cy["marker"],
markersize=3,
zorder=2,
alpha=0.25,
)
# ensure limits are equal for both axes
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
for method, deltas, _cy in zip(columns, data, cy):
slope, interc = np.polyfit(reference, deltas, 1)
y_lr = np.array([slope * x + interc for x in _lims])
ax.plot(
_lims,
y_lr,
label=method,
color=_cy["color"],
linestyle="-",
markersize="0",
zorder=1,
)
# plot reference line
ax.plot(
_lims,
_lims,
color="black",
linestyle="--",
markersize=0,
zorder=1,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_shift(
self,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class]
for method, shifts, _cy in zip(columns, data, cy):
ax.plot(
shift_prevs,
shifts,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if counts is not None:
_col_idx = np.where(columns == method)[0]
count = counts[_col_idx].flatten()
for prev, shift, cnt in zip(shift_prevs, shifts, count):
label = f"{cnt}"
plt.annotate(
label,
(prev, shift),
textcoords="offset points",
xytext=(0, 10),
ha="center",
color=_cy["color"],
fontsize=12.0,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig

144
quacc/plot/plot.py Normal file
View File

@ -0,0 +1,144 @@
from quacc.plot.base import BasePlot
from quacc.plot.mpl import MplPlot
from quacc.plot.plotly import PlotlyPlot
__backend: BasePlot = MplPlot()
def get_backend(be, theme=None):
match be:
case "matplotlib" | "mpl":
return MplPlot()
case "plotly":
return PlotlyPlot(theme=theme)
case _:
return MplPlot()
def plot_delta(
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
avg=None,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
else:
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
y_label = f"{metric} error"
fig = backend.plot_delta(
base_prevs,
columns,
data,
stdevs=stdevs,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_diagonal(
reference,
columns,
data,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
else:
title = f"diagonal_{name}_{metric}"
x_label = f"true {metric}"
y_label = f"estim. {metric}"
fig = backend.plot_diagonal(
reference,
columns,
data,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_shift(
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"shift_{name}_{t_prev_pos}_{metric}"
else:
title = f"shift_{name}_avg_{metric}"
x_label = "dataset shift"
y_label = f"{metric} error"
fig = backend.plot_shift(
shift_prevs,
columns,
data,
counts=counts,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig