From 64e546b8fd84f086e023295c111632e982b6305f Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Sat, 11 Nov 2023 16:44:12 +0100 Subject: [PATCH] panel implemented --- qcpanel/__init__.py | 0 qcpanel/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 143 bytes qcpanel/__pycache__/run.cpython-310.pyc | Bin 0 -> 4954 bytes qcpanel/run.py | 253 +++++++++++++++++++ 4 files changed, 253 insertions(+) create mode 100644 qcpanel/__init__.py create mode 100644 qcpanel/__pycache__/__init__.cpython-310.pyc create mode 100644 qcpanel/__pycache__/run.cpython-310.pyc create mode 100644 qcpanel/run.py diff --git a/qcpanel/__init__.py b/qcpanel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/qcpanel/__pycache__/__init__.cpython-310.pyc b/qcpanel/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87941758dc4270e636b18169ae24f5a41c825e1b GIT binary patch literal 143 zcmd1j<>g`kf?b(@sUZ3>h(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o6vAKR2&LKUqJt zIJKx)KPSH^HBY}dzqBYhRlg*)I8(ncxgaqwHAg=_J~J<~BtBlRpz;=nO>TZlX-=vg M$gE-}Ai=@_02=5YivR!s literal 0 HcmV?d00001 diff --git a/qcpanel/__pycache__/run.cpython-310.pyc b/qcpanel/__pycache__/run.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d2cb465162be06b380d30f9b51976262c24d558 GIT binary patch literal 4954 zcmd5=Ns}B$6|Stks(VJQ(JtARjIePVGy?`~!Xbog$u^tEMv_?yOjA>t(^IXUu2xpI zmZ3VrffT{TeISCH!x1yw=9dt`kpqXk5u9|k18mtW-^=P*C9~Ye0B&*Cx03pwB@A`DZ?8Iu!|FCS@zCb<{na1| zOq<&eSED3imhgHp_qab{%!-02j9A3@<2(}4vdv@OxaTE}huWKb2JOvhuUUTJo}0`} zYb|~dwOKwV9L#WtA4cr}KO!8_df7tDQGN_92l*5H1V4F?C3E}~f9#%>9OCS%)joY6 z=dl*sOuKKUY6buHrB{G6ms6FNLR}HVLaHQMDhgF9nGV}5@wq7T?m&B7^i`^zG|RNj zW#-XT(@%(h`dz}WyaM2i+xM)>>RFeqyKJ3R?gkrsRIj`f*4S5e@g9PW$*<@UUu zU-`YDVwGJv+%+=ufB_GAG;y)_crjy`dFx}NKwG$q-v>_sx>)Y*#+WIqVm(~a32fpp z7XOWX!@g>5+V^SneGb^PwI_?wfG07|m=8KbDQ=?Cp8vA8Gr4KWgXs0(+`{TWEo2LC zl|q&a{X&YtLRpMtCKeP#vT!3CrUTJmkfXu;@Q#kp_w!O^#p>|FVGRF(Hu!|Kb~ckD zRicy0&ae-u&Y$lWS=uizps>~d{t$h&XQsQh>1lVh;6nSWLdiVSu_#r(nyNyUiJKLJ zt3unufp%#RI(|J>wAp^Xq@%7-_0HR&4A5J`U@Uop;G`K$(o3A2-MxshNEbJXd;J}Jzgu{;Af8pJGSF7xrTx4-Tuc0e`%uksmm+M{c9>D#(CkBYK- ztkxycs-6oi^QL+PE#%I`1J)WvdK4_mqZoBf9s|l#1dkCs4ru#EAo6q6aF*cn1Wyq> z4cPP^B?9>cdQLjC>%^_;iPM|iT71;0B#FhrXCC%hJl><_BqyROV_37$;L2LJHmort z9bdOOgIU?rvW+s#iMud7o6K(PDu|;;X5?2cnbO1^dpt5d9Mi)$J$xQlFhAblP4t@K zpqxoK4!Ra^{f>=8(AXiK!MK5^ zk%}v;iYl)Pss=wg@yAW&tkaVwKgN$w7)mphM{QN3iZ<*jz`jmQ=sZpA?&O43%?-ND zQ;U+!T9Sr_Cp&Qehx-%5=d3xR`WR4T3 zbUb+)Ep1nl0VfUPO)!pwJVKtQ-C}^%ymBR9R)%7>7Nd2fS@C5mk>t(X@-mfZ&WGeh zE0DY-@q~;Bngp{1d+pTRW=8Eq+G{Q25Y7;#flvpTXb1I?|NSo77<->4fChQBsF>O0 zCaX_9t6dHktAniQk5+Nh{%DZj7zsl%I=~RnELyLMzQ~krU0M;@^_PuoUxTX}>eiK% z=f$oTc?QdAcL_9ZI8r-Tid&dqm}Ww1|B_j!J!AMvdmXcS*r{DgsT-Xgr=n+Y)|_uY zo2ff>>P_iZ{d8CVgFEGYoychzFH`E)c6r(-Y19*uY|Er+{Jje(4{! zi}J^D1pH8@*u);SYoKA9sQ-k@|07X@K0_64Q7=C<@x~zQbt3ErKg^F9;%$PKBSW}H z4dFIGxW_=au_4^!{|({FFG2wFO9bZ#76=|;ckfdBR{(2I*4IqV@;{QOe2#{F6|m<9 zi1D|GQocZ4zCmD?eQ~>FsuYqK!~ZQ6&J&P}kl!RAC6nJGxJd94pp7WK<|{)FNyE@q zr>?})Coqa6o<2ZXzCxo-tn?~P;XX2Qc#S%Jo8Y?y#LTeS?^F48f<=O}1aA<$N$>-L zD+Gq)o~6raO`8onSxim%a-l2&28mO%MPU(l~dVW1clt4*kiYH*ZmEPSrx?+-r`zwhyKg? z#DO}#+P2}ImwI63QV0AFBIP`TQ=p}ja5Na^+4a6Kd|-B56S{jXT`H3hiNnowsV~}| zG%hBfK~c?oQ1c|1>N}FE;qWtLKNZ05F|;Hn7PF?^WOH^1Vv;Lp(5^(bwnsJ&nrST~ z8;JsPB1^Hmwu~SdiVP7&AX9SI5iD=8%huaZBZ@@Sc*HFlVObDPP}%QU?>h*zeR>Z7F+_{nM|3qDDed;fvU(MJClg9$ zpa5luwqRl-fwuhJUm-7fz+MF3#OVkLpXe7UPh69fyX6WTe2{W3MOh}TTVkm!vg@!7 z#8};Y0ERh>N#rttIiN@SC$Qj6YU%G zV0vcgwEdxrq|nxg>p2w24w0cyJ^-On0!GQ##D$ev>kj2E%0Z>+0o!Y(*Q`sG+vJ}beOR_s!MacOkRb|q| ztQ|ZxZAol8QN1)LsmI%iZ`!4H&69m3=pe!W5c?=exybCc$jkpU^jl2+lJ;1KwUQ*U zi1#bfZ2YjigSzYkl7JWt3ANFDlYyG+N{-wZ;f=~X#hW9e3iI>|!K9(dv4@^X^Mb~( zkB-D`B?cu$*3Zx~#t~3pQWFtzgS<{a%1|Gdwz})?cZfbbk9Py3zPLb^P*Ug~8kP@j hVw literal 0 HcmV?d00001 diff --git a/qcpanel/run.py b/qcpanel/run.py new file mode 100644 index 0000000..87d5130 --- /dev/null +++ b/qcpanel/run.py @@ -0,0 +1,253 @@ +import argparse +import os +from pathlib import Path + +import panel as pn + +from quacc.evaluation.comp import CE +from quacc.evaluation.report import DatasetReport + +pn.extension(design="bootstrap") + + +def create_cr_plots( + dr: DatasetReport, + mode="delta", + metric="acc", + estimators=None, + prev=None, +): + idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev) + cr = dr.crs[idx] + estimators = CE.name[estimators] + _dpi = 112 + return pn.pane.Matplotlib( + cr.get_plots( + mode=mode, + metric=metric, + estimators=estimators, + conf="panel", + return_fig=True, + ), + tight=True, + format="png", + sizing_mode="scale_height", + # sizing_mode="scale_both", + ) + + +def create_avg_plots( + dr: DatasetReport, + mode="delta", + metric="acc", + estimators=None, + prev=None, +): + estimators = CE.name[estimators] + return pn.pane.Matplotlib( + dr.get_plots( + mode=mode, + metric=metric, + estimators=estimators, + conf="panel", + return_fig=True, + ), + tight=True, + format="png", + sizing_mode="scale_height", + # sizing_mode="scale_both", + ) + + +def build_cr_tab(dr: DatasetReport): + _data = dr.data() + _metrics = _data.columns.unique(0) + _estimators = _data.columns.unique(1) + + valid_metrics = [m for m in _metrics if not m.endswith("_score")] + metric_widget = pn.widgets.Select( + name="metric", + value="acc", + options=valid_metrics, + align="center", + ) + + valid_estimators = [e for e in _estimators if e != "ref"] + estimators_widget = pn.widgets.CheckButtonGroup( + name="estimators", + options=valid_estimators, + value=valid_estimators, + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + sizing_mode="scale_width", + ) + + valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"] + plot_mode_widget = pn.widgets.RadioButtonGroup( + name="mode", + value=valid_plot_modes[0], + options=valid_plot_modes, + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + sizing_mode="scale_width", + ) + + valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] + prevs_widget = pn.widgets.RadioButtonGroup( + name="train prevalence", + value=valid_prevs[0], + options=valid_prevs, + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + ) + + plot_pane = pn.bind( + create_cr_plots, + dr=dr, + mode=plot_mode_widget, + metric=metric_widget, + estimators=estimators_widget, + prev=prevs_widget, + ) + + return pn.Row( + pn.Spacer(width=20), + pn.Column( + metric_widget, + pn.Row( + prevs_widget, + plot_mode_widget, + ), + estimators_widget, + align="center", + ), + pn.Spacer(sizing_mode="scale_width"), + plot_pane, + ) + + +def build_avg_tab(dr: DatasetReport): + _data = dr.data() + _metrics = _data.columns.unique(0) + _estimators = _data.columns.unique(1) + + valid_metrics = [m for m in _metrics if not m.endswith("_score")] + metric_widget = pn.widgets.Select( + name="metric", + value="acc", + options=valid_metrics, + align="center", + ) + + valid_estimators = [e for e in _estimators if e != "ref"] + estimators_widget = pn.widgets.CheckButtonGroup( + name="estimators", + options=valid_estimators, + value=valid_estimators, + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + sizing_mode="scale_width", + ) + + valid_plot_modes = [ + "delta_train", + "stdev_train", + "delta_test", + "stdev_test", + "shift", + ] + plot_mode_widget = pn.widgets.RadioButtonGroup( + name="mode", + value=valid_plot_modes[0], + options=valid_plot_modes, + button_style="outline", + button_type="primary", + align="center", + orientation="vertical", + sizing_mode="scale_width", + ) + + plot_pane = pn.bind( + create_avg_plots, + dr=dr, + mode=plot_mode_widget, + metric=metric_widget, + estimators=estimators_widget, + ) + + return pn.Row( + pn.Spacer(width=20), + pn.Column( + metric_widget, + plot_mode_widget, + estimators_widget, + align="center", + ), + pn.Spacer(sizing_mode="scale_width"), + plot_pane, + ) + + +def build_dataset(dataset_path: Path): + dr: DatasetReport = DatasetReport.unpickle(dataset_path) + + prevs_tab = ("train prevs.", build_cr_tab(dr)) + avg_tab = ("avg", build_avg_tab(dr)) + + app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False) + app.servable() + return app + + +def explore_datasets(root: Path | str): + if isinstance(root, str): + root = Path(root) + + drs = [] + for f in os.listdir(root): + if (root / f).is_dir(): + drs += explore_datasets(root / f) + elif f == f"{root.name}.pickle": + drs.append((str(root), build_dataset(root / f))) + # drs.append((str(root),)) + + return drs + + +def serve(address="localhost"): + # app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle")) + app = pn.Tabs( + objects=explore_datasets("output"), + tabs_location="left", + dynamic=False, + ) + + __port = 33420 + pn.serve( + app, + autoreload=True, + port=__port, + show=False, + address=address, + websocket_origin=f"{address}:{__port}", + ) + + +def run(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--address", + action="store", + dest="address", + default="localhost", + ) + args = parser.parse_args() + serve(address=args.address)