plot refactored
This commit is contained in:
parent
deeb522ccb
commit
f0bfb2e039
|
@ -52,7 +52,7 @@ def create_plots(
|
|||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="panel",
|
||||
return_fig=True,
|
||||
save_fig=False,
|
||||
)
|
||||
return (
|
||||
pn.pane.Matplotlib(
|
||||
|
@ -87,7 +87,7 @@ def create_plots(
|
|||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf="panel",
|
||||
return_fig=True,
|
||||
save_fig=False,
|
||||
)
|
||||
return (
|
||||
pn.pane.Matplotlib(
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import List, Tuple
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from quacc import plot
|
||||
import quacc.plot as plot
|
||||
from quacc.utils import fmt_line_md
|
||||
|
||||
|
||||
|
@ -214,16 +214,17 @@ class CompReport:
|
|||
|
||||
def get_plots(
|
||||
self,
|
||||
mode="delta",
|
||||
mode="delta_train",
|
||||
metric="acc",
|
||||
estimators=None,
|
||||
conf="default",
|
||||
return_fig=False,
|
||||
save_fig=True,
|
||||
base_path=None,
|
||||
backend=None,
|
||||
) -> List[Tuple[str, Path]]:
|
||||
if mode == "delta_train":
|
||||
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
||||
if avg_data.empty is True:
|
||||
if avg_data.empty:
|
||||
return None
|
||||
|
||||
return plot.plot_delta(
|
||||
|
@ -233,8 +234,9 @@ class CompReport:
|
|||
metric=metric,
|
||||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "stdev_train":
|
||||
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
|
||||
|
@ -250,8 +252,9 @@ class CompReport:
|
|||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
stdevs=st_data.T.to_numpy(),
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "diagonal":
|
||||
f_data = self.data(metric=metric + "_score", estimators=estimators)
|
||||
|
@ -267,8 +270,9 @@ class CompReport:
|
|||
metric=metric,
|
||||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "shift":
|
||||
_shift_data = self.shift_data(metric=metric, estimators=estimators)
|
||||
|
@ -289,8 +293,9 @@ class CompReport:
|
|||
name=conf,
|
||||
train_prev=self.train_prev,
|
||||
counts=shift_counts.T.to_numpy(),
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
def to_md(
|
||||
|
@ -322,11 +327,12 @@ class CompReport:
|
|||
plot_modes = [m for m in modes if not m.endswith("table")]
|
||||
for mode in plot_modes:
|
||||
res += f"### {mode}\n"
|
||||
op = self.get_plots(
|
||||
_, op = self.get_plots(
|
||||
mode=mode,
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
save_fig=True,
|
||||
base_path=plot_path,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
@ -423,8 +429,9 @@ class DatasetReport:
|
|||
metric="acc",
|
||||
estimators=None,
|
||||
conf="default",
|
||||
return_fig=False,
|
||||
save_fig=True,
|
||||
base_path=None,
|
||||
backend=None,
|
||||
):
|
||||
if mode == "delta_train":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
|
@ -440,8 +447,9 @@ class DatasetReport:
|
|||
name=conf,
|
||||
train_prev=None,
|
||||
avg="train",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "stdev_train":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
|
@ -459,8 +467,9 @@ class DatasetReport:
|
|||
train_prev=None,
|
||||
stdevs=stdev_on_train.T.to_numpy(),
|
||||
avg="train",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "delta_test":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
|
@ -474,8 +483,9 @@ class DatasetReport:
|
|||
name=conf,
|
||||
train_prev=None,
|
||||
avg="test",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "stdev_test":
|
||||
_data = self.data(metric, estimators) if data is None else data
|
||||
|
@ -491,8 +501,9 @@ class DatasetReport:
|
|||
train_prev=None,
|
||||
stdevs=stdev_on_test.T.to_numpy(),
|
||||
avg="test",
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
elif mode == "shift":
|
||||
_shift_data = self.shift_data(metric, estimators) if data is None else data
|
||||
|
@ -507,8 +518,9 @@ class DatasetReport:
|
|||
name=conf,
|
||||
train_prev=None,
|
||||
counts=count_shift.T.to_numpy(),
|
||||
return_fig=return_fig,
|
||||
save_fig=save_fig,
|
||||
base_path=base_path,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
def to_md(
|
||||
|
@ -544,24 +556,26 @@ class DatasetReport:
|
|||
res += avg_on_train_tbl.to_html() + "\n\n"
|
||||
|
||||
if "delta_train" in dr_modes:
|
||||
delta_op = self.get_plots(
|
||||
_, delta_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="delta_train",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
if "stdev_train" in dr_modes:
|
||||
delta_stdev_op = self.get_plots(
|
||||
_, delta_stdev_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="stdev_train",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
|
@ -574,24 +588,26 @@ class DatasetReport:
|
|||
res += avg_on_test_tbl.to_html() + "\n\n"
|
||||
|
||||
if "delta_test" in dr_modes:
|
||||
delta_op = self.get_plots(
|
||||
_, delta_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="delta_test",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
if "stdev_test" in dr_modes:
|
||||
delta_stdev_op = self.get_plots(
|
||||
_, delta_stdev_op = self.get_plots(
|
||||
data=_data,
|
||||
mode="stdev_test",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
|
@ -604,13 +620,14 @@ class DatasetReport:
|
|||
res += shift_on_train_tbl.to_html() + "\n\n"
|
||||
|
||||
if "shift" in dr_modes:
|
||||
shift_op = self.get_plots(
|
||||
_, shift_op = self.get_plots(
|
||||
data=_shift_data,
|
||||
mode="shift",
|
||||
metric=metric,
|
||||
estimators=estimators,
|
||||
conf=conf,
|
||||
base_path=plot_path,
|
||||
save_fig=True,
|
||||
)
|
||||
res += f".as_posix()})\n"
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift
|
|
@ -0,0 +1,54 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
class BasePlot:
    """Common interface of the plotting backends.

    Every method defined here is an empty placeholder that returns ``None``;
    subclasses are expected to override them with real drawing logic.
    """

    @classmethod
    def save_fig(cls, fig, base_path, title) -> Path:
        """Store ``fig`` under ``base_path`` using ``title`` (placeholder)."""

    @classmethod
    def plot_diagonal(
        cls,
        reference,
        columns,
        data,
        *,
        pos_class=1,
        title="default",
        x_label="true",
        y_label="estim.",
        legend=True,
    ):
        """Draw ``data`` against ``reference`` (placeholder)."""

    @classmethod
    def plot_delta(
        cls,
        base_prevs,
        columns,
        data,
        *,
        stdevs=None,
        pos_class=1,
        title="default",
        x_label="prevs.",
        y_label="error",
        legend=True,
    ):
        """Draw ``data`` against ``base_prevs``, optionally with ``stdevs`` (placeholder)."""

    @classmethod
    def plot_shift(
        cls,
        shift_prevs,
        columns,
        data,
        *,
        counts=None,
        pos_class=1,
        title="default",
        x_label="true",
        y_label="estim.",
        legend=True,
    ):
        """Draw ``data`` against ``shift_prevs``, optionally with ``counts`` (placeholder)."""
|
|
@ -0,0 +1,222 @@
|
|||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from cycler import cycler
|
||||
|
||||
from quacc import utils
|
||||
from quacc.plot.base import BasePlot
|
||||
|
||||
matplotlib.use("agg")
|
||||
|
||||
|
||||
class MplPlot(BasePlot):
    """Matplotlib implementation of the plotting backend.

    Each ``plot_*`` method creates a new figure with ``plt.subplots()`` and
    returns it; ``save_fig`` writes a figure to disk as a PNG file.
    """

    def _get_markers(self, n: int):
        """Return ``n`` marker codes, repeating the base set when it is too short."""
        ls = "ovx+sDph*^1234X><.Pd"
        if n > len(ls):
            # Floor division: a str can only be repeated an integer number of
            # times (true division yields a float and raises TypeError).
            ls = ls * (n // len(ls) + 1)
        return list(ls)[:n]

    def _get_colors(self, n: int):
        """Return ``n`` colours taken from ``tab10`` (``tab20`` when ``n`` > 10)."""
        cm = plt.get_cmap("tab20" if n > 10 else "tab10")
        return [cm(i) for i in range(n)]

    def save_fig(self, fig, base_path, title) -> Path:
        """Save ``fig`` as ``<base_path>/<title>.png`` and return that path.

        When ``base_path`` is ``None`` the ``plots`` directory under
        ``utils.get_quacc_home()`` is used. The target directory is created if
        it does not exist, and ``base_path`` may be given as a ``str``.
        """
        if base_path is None:
            base_path = utils.get_quacc_home() / "plots"
        base_path = Path(base_path)
        base_path.mkdir(parents=True, exist_ok=True)

        output_path = base_path / f"{title}.png"
        fig.savefig(output_path, bbox_inches="tight")
        return output_path

    def plot_delta(
        self,
        base_prevs,
        columns,
        data,
        *,
        stdevs=None,
        pos_class=1,
        title="default",
        x_label="prevs.",
        y_label="error",
        legend=True,
    ):
        """Plot one line per method over the prevalence of ``pos_class``.

        Args:
            base_prevs: 2-D array; column ``pos_class`` supplies the x values.
            columns: method names, zipped with the rows of ``data``.
            data: one sequence of y values per method.
            stdevs: optional array with one row per entry of ``columns``; when
                given, a ``y +/- stdev`` band is drawn around each line.
            pos_class: column of ``base_prevs`` used for the x axis.
            title: figure title.
            x_label: x axis label, used verbatim.
            y_label: y axis label.
            legend: whether to draw the legend to the right of the axes.

        Returns:
            The matplotlib figure.
        """
        fig, ax = plt.subplots()
        ax.set_aspect("auto")
        ax.grid()

        cy = cycler(color=self._get_colors(len(data)))

        base_prevs = base_prevs[:, pos_class]
        for method, deltas, _cy in zip(columns, data, cy):
            ax.plot(
                base_prevs,
                deltas,
                label=method,
                color=_cy["color"],
                linestyle="-",
                marker="o",
                markersize=3,
                zorder=2,
            )
            if stdevs is not None:
                _col_idx = np.where(columns == method)[0]
                stdev = stdevs[_col_idx].flatten()
                # NaN compares unequal to everything (itself included), so a
                # `!= np.nan` test keeps every point; np.isnan is required to
                # actually select the non-NaN positions.
                nn_idx = np.intersect1d(
                    np.where(~np.isnan(np.asarray(deltas, dtype=float)))[0],
                    np.where(~np.isnan(np.asarray(stdev, dtype=float)))[0],
                )
                _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
                ax.fill_between(
                    _bps,
                    _ds - _st,
                    _ds + _st,
                    color=_cy["color"],
                    alpha=0.25,
                )

        # The label is used as given, like in plot_diagonal and plot_shift:
        # appending " prevalence" here duplicated the word for callers that
        # already pass a complete label such as "test prevalence".
        ax.set(
            xlabel=x_label,
            ylabel=y_label,
            title=title,
        )

        if legend:
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

        return fig

    def plot_diagonal(
        self,
        reference,
        columns,
        data,
        *,
        pos_class=1,
        title="default",
        x_label="true",
        y_label="estim.",
        legend=True,
    ):
        """Scatter each method against ``reference`` with a linear fit and the diagonal.

        Args:
            reference: x values shared by every method.
            columns: method names, zipped with the rows of ``data``.
            data: one sequence of y values per method.
            pos_class: accepted for interface compatibility; not used here.
            title: figure title.
            x_label: x axis label.
            y_label: y axis label.
            legend: whether to draw the legend to the right of the axes.

        Returns:
            The matplotlib figure.
        """
        fig, ax = plt.subplots()
        ax.grid()
        ax.set_aspect("equal")

        n_series = len(data)
        cy = cycler(
            color=self._get_colors(n_series),
            marker=self._get_markers(n_series),
        )

        reference = np.array(reference)

        for deltas, _cy in zip(data, cy):
            ax.plot(
                reference,
                deltas,
                color=_cy["color"],
                linestyle="None",
                marker=_cy["marker"],
                markersize=3,
                zorder=2,
                alpha=0.25,
            )

        # ensure limits are equal for both axes
        (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
        _lims = np.array([min(x_lo, y_lo), max(x_hi, y_hi)])
        ax.set(xlim=tuple(_lims), ylim=tuple(_lims))

        # one least-squares regression line per method, spanning the limits
        for method, deltas, _cy in zip(columns, data, cy):
            slope, interc = np.polyfit(reference, deltas, 1)
            y_lr = slope * _lims + interc
            ax.plot(
                _lims,
                y_lr,
                label=method,
                color=_cy["color"],
                linestyle="-",
                markersize=0,
                zorder=1,
            )

        # plot reference line
        ax.plot(
            _lims,
            _lims,
            color="black",
            linestyle="--",
            markersize=0,
            zorder=1,
        )

        ax.set(xlabel=x_label, ylabel=y_label, title=title)

        if legend:
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

        return fig

    def plot_shift(
        self,
        shift_prevs,
        columns,
        data,
        *,
        counts=None,
        pos_class=1,
        title="default",
        x_label="true",
        y_label="estim.",
        legend=True,
    ):
        """Plot one line per method over the shift values of ``pos_class``.

        Args:
            shift_prevs: 2-D array; column ``pos_class`` supplies the x values.
            columns: method names, zipped with the rows of ``data``.
            data: one sequence of y values per method.
            counts: optional array with one row per entry of ``columns``; when
                given, each point is annotated with its count.
            pos_class: column of ``shift_prevs`` used for the x axis.
            title: figure title.
            x_label: x axis label.
            y_label: y axis label.
            legend: whether to draw the legend to the right of the axes.

        Returns:
            The matplotlib figure.
        """
        fig, ax = plt.subplots()
        ax.set_aspect("auto")
        ax.grid()

        cy = cycler(color=self._get_colors(len(data)))

        shift_prevs = shift_prevs[:, pos_class]
        for method, shifts, _cy in zip(columns, data, cy):
            ax.plot(
                shift_prevs,
                shifts,
                label=method,
                color=_cy["color"],
                linestyle="-",
                marker="o",
                markersize=3,
                zorder=2,
            )
            if counts is not None:
                _col_idx = np.where(columns == method)[0]
                count = counts[_col_idx].flatten()
                for prev, shift, cnt in zip(shift_prevs, shifts, count):
                    label = f"{cnt}"
                    # Annotate through the axes of this figure instead of the
                    # pyplot "current axes" global state.
                    ax.annotate(
                        label,
                        (prev, shift),
                        textcoords="offset points",
                        xytext=(0, 10),
                        ha="center",
                        color=_cy["color"],
                        fontsize=12.0,
                    )

        ax.set(xlabel=x_label, ylabel=y_label, title=title)

        if legend:
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

        return fig
|
|
@ -0,0 +1,144 @@
|
|||
from quacc.plot.base import BasePlot
|
||||
from quacc.plot.mpl import MplPlot
|
||||
from quacc.plot.plotly import PlotlyPlot
|
||||
|
||||
__backend: BasePlot = MplPlot()
|
||||
|
||||
|
||||
def get_backend(be, theme=None):
    """Instantiate the plotting backend selected by ``be``.

    ``"plotly"`` yields a ``PlotlyPlot`` built with ``theme``; ``"matplotlib"``,
    ``"mpl"`` and any other value yield an ``MplPlot``.
    """
    if be == "plotly":
        return PlotlyPlot(theme=theme)
    # matplotlib is both the explicit choice and the fallback
    return MplPlot()
|
||||
|
||||
|
||||
def plot_delta(
    base_prevs,
    columns,
    data,
    *,
    stdevs=None,
    pos_class=1,
    metric="acc",
    name="default",
    train_prev=None,
    legend=True,
    avg=None,
    save_fig=False,
    base_path=None,
    backend=None,
):
    """Build the title and axis labels of a delta plot and delegate drawing.

    The figure is produced by ``backend.plot_delta`` (the module default
    backend when ``backend`` is ``None``).

    Returns:
        The figure, or ``(figure, output_path)`` when ``save_fig`` is true.
    """
    if backend is None:
        backend = __backend

    prefix = "delta" if stdevs is None else "delta_stdev"
    if train_prev is None:
        title = f"{prefix}_{name}_avg_{avg}_{metric}"
    else:
        # training prevalence of the positive class as an integer percentage
        pct = int(round(train_prev[pos_class] * 100))
        title = f"{prefix}_{name}_{pct}_{metric}"

    axis_subject = "test" if avg is None or avg == "train" else "train"
    labels = {
        "x_label": f"{axis_subject} prevalence",
        "y_label": f"{metric} error",
    }

    fig = backend.plot_delta(
        base_prevs,
        columns,
        data,
        stdevs=stdevs,
        pos_class=pos_class,
        title=title,
        legend=legend,
        **labels,
    )

    if not save_fig:
        return fig

    return fig, backend.save_fig(fig, base_path, title)
|
||||
|
||||
|
||||
def plot_diagonal(
    reference,
    columns,
    data,
    *,
    pos_class=1,
    metric="acc",
    name="default",
    train_prev=None,
    legend=True,
    save_fig=False,
    base_path=None,
    backend=None,
):
    """Build the title and axis labels of a diagonal plot and delegate drawing.

    The figure is produced by ``backend.plot_diagonal`` (the module default
    backend when ``backend`` is ``None``).

    Returns:
        The figure, or ``(figure, output_path)`` when ``save_fig`` is true.
    """
    if backend is None:
        backend = __backend

    if train_prev is None:
        title = f"diagonal_{name}_{metric}"
    else:
        # training prevalence of the positive class as an integer percentage
        pct = int(round(train_prev[pos_class] * 100))
        title = f"diagonal_{name}_{pct}_{metric}"

    labels = {
        "x_label": f"true {metric}",
        "y_label": f"estim. {metric}",
    }

    fig = backend.plot_diagonal(
        reference,
        columns,
        data,
        pos_class=pos_class,
        title=title,
        legend=legend,
        **labels,
    )

    if not save_fig:
        return fig

    return fig, backend.save_fig(fig, base_path, title)
|
||||
|
||||
|
||||
def plot_shift(
    shift_prevs,
    columns,
    data,
    *,
    counts=None,
    pos_class=1,
    metric="acc",
    name="default",
    train_prev=None,
    legend=True,
    save_fig=False,
    base_path=None,
    backend=None,
):
    """Build the title and axis labels of a shift plot and delegate drawing.

    The figure is produced by ``backend.plot_shift`` (the module default
    backend when ``backend`` is ``None``).

    Returns:
        The figure, or ``(figure, output_path)`` when ``save_fig`` is true.
    """
    if backend is None:
        backend = __backend

    if train_prev is None:
        title = f"shift_{name}_avg_{metric}"
    else:
        # training prevalence of the positive class as an integer percentage
        pct = int(round(train_prev[pos_class] * 100))
        title = f"shift_{name}_{pct}_{metric}"

    labels = {
        "x_label": "dataset shift",
        "y_label": f"{metric} error",
    }

    fig = backend.plot_shift(
        shift_prevs,
        columns,
        data,
        counts=counts,
        pos_class=pos_class,
        title=title,
        legend=legend,
        **labels,
    )

    if not save_fig:
        return fig

    return fig, backend.save_fig(fig, base_path, title)
|
Loading…
Reference in New Issue