2023-11-16 17:10:19 +01:00
|
|
|
import os
|
|
|
|
from collections import defaultdict
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import panel as pn
|
|
|
|
|
|
|
|
from quacc.evaluation.comp import CE
|
|
|
|
from quacc.evaluation.report import DatasetReport
|
|
|
|
|
|
|
|
# Panel sizing mode given to the Matplotlib panes built by create_plots.
_plot_sizing_mode = "stretch_both"


def _default_plot_modes():
    """Return a fresh list of the modes offered for any view other than "avg"."""
    return [
        "delta_train",
        "stdev_train",
        "train_table",
        "shift",
        "shift_table",
        "diagonal",
    ]


# Plot/table modes available per view. Looking up a missing key stores and
# returns a new list from _default_plot_modes(); the aggregate "avg" view has
# its own list, which swaps "diagonal" for the test-side modes.
valid_plot_modes = defaultdict(_default_plot_modes)
valid_plot_modes.update(
    avg=[
        "delta_train",
        "stdev_train",
        "train_table",
        "shift",
        "shift_table",
        "delta_test",
        "stdev_test",
        "test_table",
    ]
)
|
|
|
|
|
|
|
|
|
2023-11-22 19:19:51 +01:00
|
|
|
def create_plots(
|
2023-11-16 17:10:19 +01:00
|
|
|
dr: DatasetReport,
|
|
|
|
mode="delta",
|
|
|
|
metric="acc",
|
|
|
|
estimators=None,
|
2023-11-22 19:19:51 +01:00
|
|
|
plot_view=None,
|
2023-11-16 17:10:19 +01:00
|
|
|
):
|
|
|
|
_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
|
|
|
|
estimators = CE.name[estimators]
|
|
|
|
if mode is None:
|
2023-11-22 19:19:51 +01:00
|
|
|
mode = valid_plot_modes[plot_view][0]
|
2023-11-16 17:10:19 +01:00
|
|
|
_dpi = 112
|
2023-11-22 19:19:51 +01:00
|
|
|
match (plot_view, mode):
|
|
|
|
case ("avg", "train_table"):
|
|
|
|
_data = (
|
|
|
|
dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
|
|
|
|
)
|
|
|
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
|
|
|
case ("avg", "test_table"):
|
|
|
|
_data = (
|
|
|
|
dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
|
|
|
)
|
|
|
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
|
|
|
case ("avg", "shift_table"):
|
|
|
|
_data = (
|
|
|
|
dr.shift_data(metric=metric, estimators=estimators)
|
|
|
|
.groupby(level=0)
|
|
|
|
.mean()
|
|
|
|
)
|
|
|
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
|
|
|
case ("avg", _ as plot_mode):
|
|
|
|
_plot = dr.get_plots(
|
2023-11-16 17:10:19 +01:00
|
|
|
mode=mode,
|
|
|
|
metric=metric,
|
|
|
|
estimators=estimators,
|
|
|
|
conf="panel",
|
|
|
|
return_fig=True,
|
2023-11-22 19:19:51 +01:00
|
|
|
)
|
|
|
|
return (
|
|
|
|
pn.pane.Matplotlib(
|
|
|
|
_plot,
|
|
|
|
tight=True,
|
|
|
|
format="png",
|
|
|
|
# sizing_mode="scale_height",
|
|
|
|
sizing_mode=_plot_sizing_mode,
|
|
|
|
# sizing_mode="scale_both",
|
|
|
|
)
|
|
|
|
if _plot is not None
|
|
|
|
else None
|
|
|
|
)
|
|
|
|
case (_, "train_table"):
|
|
|
|
cr = dr.crs[_prevs.index(int(plot_view))]
|
|
|
|
_data = (
|
|
|
|
cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
|
|
|
|
)
|
|
|
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
|
|
|
case (_, "shift_table"):
|
|
|
|
cr = dr.crs[_prevs.index(int(plot_view))]
|
|
|
|
_data = (
|
|
|
|
cr.shift_data(metric=metric, estimators=estimators)
|
|
|
|
.groupby(level=0)
|
|
|
|
.mean()
|
|
|
|
)
|
|
|
|
return pn.pane.DataFrame(_data, align="center") if not _data.empty else None
|
|
|
|
case (_, _ as plot_mode):
|
|
|
|
cr = dr.crs[_prevs.index(int(plot_view))]
|
|
|
|
_plot = cr.get_plots(
|
|
|
|
mode=plot_mode,
|
|
|
|
metric=metric,
|
|
|
|
estimators=estimators,
|
|
|
|
conf="panel",
|
|
|
|
return_fig=True,
|
|
|
|
)
|
|
|
|
return (
|
|
|
|
pn.pane.Matplotlib(
|
|
|
|
_plot,
|
|
|
|
tight=True,
|
|
|
|
format="png",
|
|
|
|
sizing_mode=_plot_sizing_mode,
|
|
|
|
# sizing_mode="scale_height",
|
|
|
|
# sizing_mode="scale_both",
|
|
|
|
)
|
|
|
|
if _plot is not None
|
|
|
|
else None
|
|
|
|
)
|
2023-11-16 17:10:19 +01:00
|
|
|
|
|
|
|
|
|
|
|
def explore_datasets(root: Path | str):
|
|
|
|
if isinstance(root, str):
|
|
|
|
root = Path(root)
|
|
|
|
|
|
|
|
if root.name == "plot":
|
|
|
|
return []
|
|
|
|
|
|
|
|
if not root.exists():
|
|
|
|
return []
|
|
|
|
|
|
|
|
drs = []
|
|
|
|
for f in os.listdir(root):
|
|
|
|
if (root / f).is_dir():
|
|
|
|
drs += explore_datasets(root / f)
|
|
|
|
elif f == f"{root.name}.pickle":
|
|
|
|
drs.append(root / f)
|
|
|
|
# drs.append((str(root),))
|
|
|
|
|
|
|
|
return drs
|