import os from collections import defaultdict from pathlib import Path import panel as pn from quacc.evaluation.comp import CE from quacc.evaluation.report import DatasetReport _plot_sizing_mode = "stretch_both" valid_plot_modes = defaultdict( lambda: [ "delta_train", "stdev_train", "train_table", "shift", "shift_table", "diagonal", ] ) valid_plot_modes["avg"] = [ "delta_train", "stdev_train", "train_table", "shift", "shift_table", "delta_test", "stdev_test", "test_table", ] def create_plots( dr: DatasetReport, mode="delta", metric="acc", estimators=None, plot_view=None, ): _prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] estimators = CE.name[estimators] if mode is None: mode = valid_plot_modes[plot_view][0] _dpi = 112 match (plot_view, mode): case ("avg", "train_table"): _data = ( dr.data(metric=metric, estimators=estimators).groupby(level=1).mean() ) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case ("avg", "test_table"): _data = ( dr.data(metric=metric, estimators=estimators).groupby(level=0).mean() ) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case ("avg", "shift_table"): _data = ( dr.shift_data(metric=metric, estimators=estimators) .groupby(level=0) .mean() ) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case ("avg", _ as plot_mode): _plot = dr.get_plots( mode=mode, metric=metric, estimators=estimators, conf="panel", return_fig=True, ) return ( pn.pane.Matplotlib( _plot, tight=True, format="png", # sizing_mode="scale_height", sizing_mode=_plot_sizing_mode, # sizing_mode="scale_both", ) if _plot is not None else None ) case (_, "train_table"): cr = dr.crs[_prevs.index(int(plot_view))] _data = ( cr.data(metric=metric, estimators=estimators).groupby(level=0).mean() ) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case (_, "shift_table"): cr = dr.crs[_prevs.index(int(plot_view))] _data = ( cr.shift_data(metric=metric, estimators=estimators) .groupby(level=0) .mean() ) return pn.pane.DataFrame(_data, align="center") if not _data.empty else None case (_, _ as plot_mode): cr = dr.crs[_prevs.index(int(plot_view))] _plot = cr.get_plots( mode=plot_mode, metric=metric, estimators=estimators, conf="panel", return_fig=True, ) return ( pn.pane.Matplotlib( _plot, tight=True, format="png", sizing_mode=_plot_sizing_mode, # sizing_mode="scale_height", # sizing_mode="scale_both", ) if _plot is not None else None ) def explore_datasets(root: Path | str): if isinstance(root, str): root = Path(root) if root.name == "plot": return [] if not root.exists(): return [] drs = [] for f in os.listdir(root): if (root / f).is_dir(): drs += explore_datasets(root / f) elif f == f"{root.name}.pickle": drs.append(root / f) # drs.append((str(root),)) return drs