"""Panel dashboard for browsing pickled ``DatasetReport`` results.

``serve`` walks the ``output`` directory, builds one tabbed view per
``<dir>/<dir>.pickle`` report found, and serves them with ``pn.serve``.
"""
import argparse
import os
from pathlib import Path

import panel as pn
import param

from quacc.evaluation.comp import CE
from quacc.evaluation.report import DatasetReport

pn.extension(design="bootstrap")

_DEFAULT_PORT = 33420


def _train_prevs(dr: DatasetReport) -> list[int]:
    """Return the rounded percentage of ``train_prev[1]`` for each report in ``dr.crs``."""
    # NOTE(review): index 1 is presumably the positive-class prevalence — confirm.
    return [round(cr.train_prev[1] * 100) for cr in dr.crs]


def _matplotlib_pane(fig) -> pn.pane.Matplotlib:
    """Wrap a figure in the Matplotlib pane configuration shared by both tabs."""
    return pn.pane.Matplotlib(
        fig,
        tight=True,
        format="png",
        sizing_mode="scale_height",
    )


def _metric_and_estimator_widgets(dr: DatasetReport):
    """Build the metric selector and estimator toggle widgets for ``dr``.

    Metric options come from level 0 of ``dr.data()``'s column index
    (excluding names ending in ``_score``); estimator options come from
    level 1 (excluding ``"ref"``), all initially selected.

    Returns:
        ``(metric_widget, estimators_widget)``.
    """
    data = dr.data()
    valid_metrics = [m for m in data.columns.unique(0) if not m.endswith("_score")]
    metric_widget = pn.widgets.Select(
        name="metric",
        value="acc",
        options=valid_metrics,
        align="center",
    )
    valid_estimators = [e for e in data.columns.unique(1) if e != "ref"]
    estimators_widget = pn.widgets.CheckButtonGroup(
        name="estimators",
        options=valid_estimators,
        value=valid_estimators,
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
        sizing_mode="scale_width",
    )
    return metric_widget, estimators_widget


def _radio_group(name: str, options: list, **kwargs) -> pn.widgets.RadioButtonGroup:
    """Build a vertical outline radio group preselecting ``options[0]``.

    Extra keyword arguments (e.g. ``sizing_mode``) are forwarded to the widget.
    Raises ``IndexError`` if ``options`` is empty.
    """
    return pn.widgets.RadioButtonGroup(
        name=name,
        value=options[0],
        options=options,
        button_style="outline",
        button_type="primary",
        align="center",
        orientation="vertical",
        **kwargs,
    )


def _tab_layout(controls, plot_pane) -> pn.Row:
    """Lay out a controls column on the left and the bound plot pane on the right."""
    return pn.Row(
        pn.Spacer(width=20),
        controls,
        pn.Spacer(sizing_mode="scale_width"),
        plot_pane,
    )


def create_cr_plots(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    prev=None,
):
    """Render the plots of the report in ``dr.crs`` whose train prevalence is ``prev``.

    Args:
        dr: The dataset report holding the per-train-prevalence reports.
        mode: Plot mode forwarded to ``get_plots``.
        metric: Metric name forwarded to ``get_plots``.
        estimators: Selected estimator names, translated through ``CE.name``.
        prev: Rounded train prevalence percentage (see ``_train_prevs``).

    Returns:
        A ``pn.pane.Matplotlib`` wrapping the result of ``cr.get_plots``.

    Raises:
        ValueError: If ``prev`` matches no report's train prevalence.
    """
    cr = dr.crs[_train_prevs(dr).index(prev)]
    # NOTE(review): CE.name maps UI selections to estimator identifiers;
    # exact semantics live in quacc.evaluation.comp.
    estimators = CE.name[estimators]
    return _matplotlib_pane(
        cr.get_plots(
            mode=mode,
            metric=metric,
            estimators=estimators,
            conf="panel",
            return_fig=True,
        )
    )


def create_avg_plots(
    dr: DatasetReport,
    mode="delta",
    metric="acc",
    estimators=None,
    prev=None,
):
    """Render the dataset-level (averaged) plots of ``dr``.

    ``prev`` is accepted for signature parity with ``create_cr_plots`` but
    is not used.

    Returns:
        A ``pn.pane.Matplotlib`` wrapping the result of ``dr.get_plots``.
    """
    estimators = CE.name[estimators]
    return _matplotlib_pane(
        dr.get_plots(
            mode=mode,
            metric=metric,
            estimators=estimators,
            conf="panel",
            return_fig=True,
        )
    )


def build_cr_tab(dr: DatasetReport):
    """Build the "train prevs." tab: widgets bound to ``create_cr_plots``."""
    metric_widget, estimators_widget = _metric_and_estimator_widgets(dr)
    plot_mode_widget = _radio_group(
        "mode",
        ["delta", "delta_stdev", "diagonal", "shift"],
        sizing_mode="scale_width",
    )
    prevs_widget = _radio_group("train prevalence", _train_prevs(dr))

    # pn.bind re-invokes create_cr_plots whenever a bound widget changes.
    plot_pane = pn.bind(
        create_cr_plots,
        dr=dr,
        mode=plot_mode_widget,
        metric=metric_widget,
        estimators=estimators_widget,
        prev=prevs_widget,
    )
    controls = pn.Column(
        metric_widget,
        pn.Row(
            prevs_widget,
            plot_mode_widget,
        ),
        estimators_widget,
        align="center",
    )
    return _tab_layout(controls, plot_pane)


def build_avg_tab(dr: DatasetReport):
    """Build the "avg" tab: widgets bound to ``create_avg_plots``."""
    metric_widget, estimators_widget = _metric_and_estimator_widgets(dr)
    plot_mode_widget = _radio_group(
        "mode",
        [
            "delta_train",
            "stdev_train",
            "delta_test",
            "stdev_test",
            "shift",
        ],
        sizing_mode="scale_width",
    )

    plot_pane = pn.bind(
        create_avg_plots,
        dr=dr,
        mode=plot_mode_widget,
        metric=metric_widget,
        estimators=estimators_widget,
    )
    controls = pn.Column(
        metric_widget,
        plot_mode_widget,
        estimators_widget,
        align="center",
    )
    return _tab_layout(controls, plot_pane)


def build_dataset(dataset_path: Path):
    """Unpickle the report at ``dataset_path`` and return its "avg"/"train prevs." tabs.

    The returned ``pn.Tabs`` is also marked ``servable()``.
    """
    # NOTE(review): DatasetReport.unpickle presumably uses pickle; only load
    # trusted files from the output directory.
    dr: DatasetReport = DatasetReport.unpickle(dataset_path)
    prevs_tab = ("train prevs.", build_cr_tab(dr))
    avg_tab = ("avg", build_avg_tab(dr))
    app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False)
    app.servable()
    return app


def explore_datasets(root: Path | str):
    """Recursively collect ``(directory, app)`` pairs under ``root``.

    A directory ``d`` contributes an entry when it contains ``d.name + ".pickle"``.
    Directories named ``plot`` are skipped, and a ``root`` that is missing or
    not a directory yields ``[]``. Entries are visited in sorted name order.
    """
    root = Path(root)
    if root.name == "plot":
        return []
    # is_dir() also covers a root that exists but is a regular file, which
    # would otherwise make os.listdir raise NotADirectoryError.
    if not root.is_dir():
        return []

    drs = []
    # os.listdir order is arbitrary; sort for a reproducible traversal.
    for name in sorted(os.listdir(root)):
        path = root / name
        if path.is_dir():
            drs += explore_datasets(path)
        elif name == f"{root.name}.pickle":
            drs.append((root, build_dataset(path)))
    return drs


class PlotSelector(param.Parameterized):
    """Parameter set for choosing a metric and a view. Not referenced elsewhere in this module."""

    metric = param.Selector(objects=["acc", "f1"])
    view = param.Selector(objects=["train prevs", "avg"])


def plot_selector_widget():
    """Return a ``pn.Param`` rendering ``PlotSelector``'s parameters as Select widgets."""
    return pn.Param(
        PlotSelector.param,
        widgets={
            "metric": pn.widgets.Select,
            "view": pn.widgets.Select,
        },
    )


def serve(address="localhost", port=_DEFAULT_PORT):
    """Serve every dataset found under ``output`` as a left-side tab.

    Args:
        address: Interface to bind; ``"localhost"`` additionally allows
            websocket connections from ``127.0.0.1``.
        port: TCP port to listen on (default 33420).
    """
    base_path = "output"
    datasets = sorted(
        explore_datasets(base_path), key=lambda t: (len(t[0].parts), t[0])
    )
    tabs = [(str(p.relative_to(Path(base_path))), d) for (p, d) in datasets]

    if tabs:
        app = pn.Tabs(
            objects=tabs,
            tabs_location="left",
            dynamic=False,
        )
    else:
        app = pn.Column(
            pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}),
            align="center",
        )

    allowed = [address]
    if address == "localhost":
        allowed.append("127.0.0.1")

    pn.serve(
        app,
        autoreload=True,
        port=port,
        show=False,
        address=address,
        websocket_origin=[f"{host}:{port}" for host in allowed],
    )


def run():
    """Command-line entry point: parse ``--address``/``--port`` and call ``serve``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address",
        action="store",
        dest="address",
        default="localhost",
    )
    parser.add_argument(
        "--port",
        action="store",
        dest="port",
        type=int,
        default=_DEFAULT_PORT,
    )
    args = parser.parse_args()
    serve(address=args.address, port=args.port)