diff --git a/.gitignore b/.gitignore index 88fb8aa..6a4219b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,11 +6,13 @@ quavenv/* __pycache__/* baselines/__pycache__/* baselines/densratio/__pycache__/* +qcdash/__pycache__/* qcpanel/__pycache__/* quacc/__pycache__/* quacc/evaluation/__pycache__/* quacc/method/__pycache__/* quacc/quantification/__pycache__/* +quacc/plot/__pycache__/* tests/__pycache__/* tests/*/__pycache__/* tests/*/*/__pycache__/* diff --git a/poetry.lock b/poetry.lock index 41efb57..07d4bff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "abstention" version = "0.1.3.1" description = "Functions for abstention, calibration and label shift domain adaptation" +category = "main" optional = false python-versions = "*" files = [ @@ -15,10 +16,27 @@ numpy = ">=1.9" scikit-learn = ">=0.20.0" scipy = ">=1.1.0" +[[package]] +name = "ansi2html" +version = "1.8.0" +description = "" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "ansi2html-1.8.0-py3-none-any.whl", hash = "sha256:ef9cc9682539dbe524fbf8edad9c9462a308e04bce1170c32daa8fdfd0001785"}, + {file = "ansi2html-1.8.0.tar.gz", hash = "sha256:38b82a298482a1fa2613f0f9c9beb3db72a8f832eeac58eb2e47bf32cd37f6d5"}, +] + +[package.extras] +docs = ["Sphinx", "setuptools-scm", "sphinx-rtd-theme"] +test = ["pytest", "pytest-cov"] + [[package]] name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "dev" optional = false python-versions = "*" files = [ @@ -30,6 +48,7 @@ files = [ name = "asttokens" version = "2.4.1" description = "Annotate AST trees with source code positions" +category = "dev" optional = false python-versions = "*" files = [ @@ -48,6 +67,7 @@ test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] name = "bcrypt" version = "4.0.1" description = "Modern password hashing for your software and your servers" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -82,6 +102,7 @@ typecheck = ["mypy"] name = "bleach" version = "6.1.0" description = "An easy safelist-based HTML-sanitizing tool." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -96,10 +117,23 @@ webencodings = "*" [package.extras] css = ["tinycss2 (>=1.1.0,<1.3)"] +[[package]] +name = "blinker" +version = "1.7.0" +description = "Fast, simple object-to-object and broadcast signaling" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "blinker-1.7.0-py3-none-any.whl", hash = "sha256:c3f865d4d54db7abc53758a01601cf343fe55b84c1de4e3fa910e420b438d5b9"}, + {file = "blinker-1.7.0.tar.gz", hash = "sha256:e6820ff6fa4e4d1d8e2747c2283749c3f547e4fee112b98555cdcdae32996182"}, +] + [[package]] name = "bokeh" version = "3.3.1" description = "Interactive plots and applications in the browser from Python" +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -122,6 +156,7 @@ xyzservices = ">=2021.09.1" name = "certifi" version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -133,6 +168,7 @@ files = [ name = "cffi" version = "1.16.0" description = "Foreign Function Interface for Python calling C code." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -197,6 +233,7 @@ pycparser = "*" name = "charset-normalizer" version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -292,10 +329,26 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -307,6 +360,7 @@ files = [ name = "comm" version = "0.2.0" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -324,6 +378,7 @@ test = ["pytest"] name = "contourpy" version = "1.2.0" description = "Python library for calculating contours of 2D quadrilateral grids" +category = "main" optional = false python-versions = ">=3.9" files = [ @@ -387,6 +442,7 @@ test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] name = "coverage" version = "7.3.2" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -454,6 +510,7 @@ toml = ["tomli"] name = "cryptography" version = "41.0.5" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -499,6 +556,7 @@ test-randomorder = ["pytest-randomly"] name = "cycler" version = "0.12.1" description = "Composable style cycles" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -510,10 +568,100 @@ files = [ docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] tests = ["pytest", "pytest-cov", "pytest-xdist"] +[[package]] +name = "dash" +version = "2.14.1" +description = "A Python framework for building reactive web-apps. Developed by Plotly." +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "dash-2.14.1-py3-none-any.whl", hash = "sha256:ce440ef7416945c9daa8274948483a0aac928a4fec768c0384fd4e9a6196eaf2"}, + {file = "dash-2.14.1.tar.gz", hash = "sha256:93dc9d665ec5d3720647d4cef4520a1c7cd1bde57e893ffeb7e6cd59781d3294"}, +] + +[package.dependencies] +ansi2html = "*" +dash-core-components = "2.0.0" +dash-html-components = "2.0.0" +dash-table = "5.0.0" +Flask = ">=1.0.4,<3.1" +importlib-metadata = {version = "*", markers = "python_version >= \"3.7\""} +nest-asyncio = "*" +plotly = ">=5.0.0" +requests = "*" +retrying = "*" +setuptools = "*" +typing-extensions = ">=4.1.1" +Werkzeug = "<3.1" + +[package.extras] +celery = ["celery[redis] (>=5.1.2)", "importlib-metadata (<5)", "redis (>=3.5.3)"] +ci = ["black (==21.6b0)", "black (==22.3.0)", "dash-dangerously-set-inner-html", "dash-flow-example (==0.0.5)", "flake8 (==3.9.2)", "flaky (==3.7.0)", "flask-talisman (==1.0.0)", "isort (==4.3.21)", "jupyterlab (<4.0.0)", "mimesis", "mock (==4.0.3)", "numpy (<=1.25.2)", "openpyxl", "orjson (==3.5.4)", "orjson (==3.6.7)", "pandas (==1.1.5)", "pandas (>=1.4.0)", "preconditions", "pyarrow", "pyarrow (<3)", "pylint (==2.13.5)", "pytest-mock", "pytest-rerunfailures", "pytest-sugar (==0.9.6)", "xlrd (<2)", "xlrd (>=2.0.1)"] +compress = ["flask-compress"] +dev = ["PyYAML (>=5.4.1)", "coloredlogs (>=15.0.1)", "fire (>=0.4.0)"] +diskcache = ["diskcache (>=5.2.1)", "multiprocess (>=0.70.12)", "psutil (>=5.8.0)"] +testing = ["beautifulsoup4 (>=4.8.2)", "cryptography (<3.4)", "dash-testing-stub (>=0.0.2)", "lxml (>=4.6.2)", "multiprocess (>=0.70.12)", "percy (>=2.0.2)", "psutil (>=5.8.0)", "pytest (>=6.0.2)", "requests[security] (>=2.21.0)", "selenium (>=3.141.0,<=4.2.0)", "waitress (>=1.4.4)"] + +[[package]] +name = "dash-bootstrap-components" +version = "1.5.0" +description = "Bootstrap themed components for use in Plotly Dash" +category = "dev" +optional = false +python-versions = ">=3.7, <4" +files = [ + {file = "dash-bootstrap-components-1.5.0.tar.gz", hash = "sha256:083158c07434b9965e2d6c3e8ca72dbbe47dab23e676258cef9bf0ad47d2e250"}, + {file = "dash_bootstrap_components-1.5.0-py3-none-any.whl", hash = "sha256:b487fec1a85e3d6a8564fe04c0a9cd9e846f75ea9e563456ed3879592889c591"}, +] + +[package.dependencies] +dash = ">=2.0.0" + +[package.extras] +pandas = ["numpy", "pandas"] + +[[package]] +name = "dash-core-components" +version = "2.0.0" +description = "Core component suite for Dash" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "dash_core_components-2.0.0-py3-none-any.whl", hash = "sha256:52b8e8cce13b18d0802ee3acbc5e888cb1248a04968f962d63d070400af2e346"}, + {file = "dash_core_components-2.0.0.tar.gz", hash = "sha256:c6733874af975e552f95a1398a16c2ee7df14ce43fa60bb3718a3c6e0b63ffee"}, +] + +[[package]] +name = "dash-html-components" +version = "2.0.0" +description = "Vanilla HTML components for Dash" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "dash_html_components-2.0.0-py3-none-any.whl", hash = "sha256:b42cc903713c9706af03b3f2548bda4be7307a7cf89b7d6eae3da872717d1b63"}, + {file = "dash_html_components-2.0.0.tar.gz", hash = "sha256:8703a601080f02619a6390998e0b3da4a5daabe97a1fd7a9cebc09d015f26e50"}, +] + +[[package]] +name = "dash-table" +version = "5.0.0" +description = "Dash table" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "dash_table-5.0.0-py3-none-any.whl", hash = "sha256:19036fa352bb1c11baf38068ec62d172f0515f73ca3276c79dee49b95ddc16c9"}, + {file = "dash_table-5.0.0.tar.gz", hash = "sha256:18624d693d4c8ef2ddec99a6f167593437a7ea0bf153aa20f318c170c5bc7308"}, +] + [[package]] name = "debugpy" version = "1.8.0" description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -541,6 +689,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -552,6 +701,7 @@ files = [ name = "exceptiongroup" version = "1.1.3" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -566,6 +716,7 @@ test = ["pytest (>=6)"] name = "executing" version = "2.0.1" description = "Get the currently executing AST node of a frame, and other information" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -576,10 +727,34 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +[[package]] +name = "flask" +version = "3.0.0" +description = "A simple framework for building complex web applications." +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "flask-3.0.0-py3-none-any.whl", hash = "sha256:21128f47e4e3b9d597a3e8521a329bf56909b690fcc3fa3e477725aa81367638"}, + {file = "flask-3.0.0.tar.gz", hash = "sha256:cfadcdb638b609361d29ec22360d6070a77d7463dcb3ab08d2c2f2f168845f58"}, +] + +[package.dependencies] +blinker = ">=1.6.2" +click = ">=8.1.3" +itsdangerous = ">=2.1.2" +Jinja2 = ">=3.1.2" +Werkzeug = ">=3.0.0" + +[package.extras] +async = ["asgiref (>=3.2)"] +dotenv = ["python-dotenv"] + [[package]] name = "fonttools" version = "4.44.0" description = "Tools to manipulate font files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -641,10 +816,32 @@ ufo = ["fs (>=2.2.0,<3)"] unicode = ["unicodedata2 (>=15.1.0)"] woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] +[[package]] +name = "gunicorn" +version = "21.2.0" +description = "WSGI HTTP Server for UNIX" +category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "gunicorn-21.2.0-py3-none-any.whl", hash = "sha256:3213aa5e8c24949e792bcacfc176fef362e7aac80b76c56f6b5122bf350722f0"}, + {file = "gunicorn-21.2.0.tar.gz", hash = "sha256:88ec8bff1d634f98e61b9f65bc4bf3cd918a90806c6f5c48bc5603849ec81033"}, +] + +[package.dependencies] +packaging = "*" + +[package.extras] +eventlet = ["eventlet (>=0.24.1)"] +gevent = ["gevent (>=1.4.0)"] +setproctitle = ["setproctitle"] +tornado = ["tornado (>=0.2)"] + [[package]] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -652,10 +849,31 @@ files = [ {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] +[[package]] +name = "importlib-metadata" +version = "6.8.0" +description = "Read metadata from Python packages" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, + {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + [[package]] name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -667,6 +885,7 @@ files = [ name = "ipykernel" version = "6.26.0" description = "IPython Kernel for Jupyter" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -680,7 +899,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -700,6 +919,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipympl" version = "0.9.3" description = "Matplotlib Jupyter Extension" +category = "dev" optional = false python-versions = "*" files = [ @@ -723,6 +943,7 @@ docs = ["Sphinx (>=1.5)", "myst-nb", "sphinx-book-theme", "sphinx-copybutton", " name = "ipython" version = "8.17.2" description = "IPython: Productive Interactive Computing" +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -760,6 +981,7 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pa name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" +category = "dev" optional = false python-versions = "*" files = [ @@ -771,6 +993,7 @@ files = [ name = "ipywidgets" version = "8.1.1" description = "Jupyter interactive widgets" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -792,6 +1015,7 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] name = "ipywidgets-bokeh" version = "1.5.0" description = "Allows embedding of Jupyter widgets in Bokeh layouts." +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -800,17 +1024,30 @@ files = [ ] [package.dependencies] -bokeh = "==3.*" -ipykernel = ">=6.dev0,<6.18.0 || >6.18.0,<7.dev0" -ipywidgets = "==8.*" +bokeh = ">=3.0.0,<4.0.0" +ipykernel = ">=6.0.0,<6.18.0 || >6.18.0,<7.0.0" +ipywidgets = ">=8.0.0,<9.0.0" [package.extras] dev = ["anywidget (>=0.3.0)", "panel (>=1.0.4)", "pytest (>=7.3.1)", "pytest-cov (>=4.1.0)", "pytest-playwright (>=0.3.3)"] +[[package]] +name = "itsdangerous" +version = "2.1.2" +description = "Safely pass data to untrusted environments and back." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "itsdangerous-2.1.2-py3-none-any.whl", hash = "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44"}, + {file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"}, +] + [[package]] name = "jedi" version = "0.19.1" description = "An autocompletion tool for Python that can be used for text editors." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -830,6 +1067,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -847,6 +1085,7 @@ i18n = ["Babel (>=2.7)"] name = "joblib" version = "1.3.2" description = "Lightweight pipelining with Python functions" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -858,6 +1097,7 @@ files = [ name = "jupyter-client" version = "8.6.0" description = "Jupyter protocol implementation and client libraries" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -866,7 +1106,7 @@ files = [ ] [package.dependencies] -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -880,6 +1120,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-core" version = "5.5.0" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -900,6 +1141,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyterlab-widgets" version = "3.0.9" description = "Jupyter interactive widgets for JupyterLab" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -911,6 +1153,7 @@ files = [ name = "kiwisolver" version = "1.4.5" description = "A fast implementation of the Cassowary constraint solver" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1024,6 +1267,7 @@ files = [ name = "linkify-it-py" version = "2.0.2" description = "Links recognition library with FULL unicode support." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1044,6 +1288,7 @@ test = ["coverage", "pytest", "pytest-cov"] name = "logging" version = "0.4.9.6" description = "A logging module for Python" +category = "main" optional = false python-versions = "*" files = [ @@ -1054,6 +1299,7 @@ files = [ name = "markdown" version = "3.5.1" description = "Python implementation of John Gruber's Markdown." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1069,6 +1315,7 @@ testing = ["coverage", "pyyaml"] name = "markdown-it-py" version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1093,6 +1340,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1162,6 +1410,7 @@ files = [ name = "matplotlib" version = "3.8.1" description = "Python plotting package" +category = "main" optional = false python-versions = ">=3.9" files = [ @@ -1210,6 +1459,7 @@ python-dateutil = ">=2.7" name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1224,6 +1474,7 @@ traitlets = "*" name = "mdit-py-plugins" version = "0.4.0" description = "Collection of plugins for markdown-it-py" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1243,6 +1494,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1254,6 +1506,7 @@ files = [ name = "nest-asyncio" version = "1.5.8" description = "Patch asyncio to allow nested event loops" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1265,6 +1518,7 @@ files = [ name = "numpy" version = "1.26.2" description = "Fundamental package for array computing in Python" +category = "main" optional = false python-versions = ">=3.9" files = [ @@ -1310,6 +1564,7 @@ files = [ name = "packaging" version = "23.2" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1321,6 +1576,7 @@ files = [ name = "pandas" version = "2.1.2" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = false python-versions = ">=3.9" files = [ @@ -1389,6 +1645,7 @@ xml = ["lxml (>=4.8.0)"] name = "pandas-stubs" version = "2.1.1.230928" description = "Type annotations for pandas" +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -1404,6 +1661,7 @@ types-pytz = ">=2022.1.1" name = "panel" version = "1.3.1" description = "The powerful data exploration & web app framework for Python." +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -1441,6 +1699,7 @@ ui = ["jupyter-server", "playwright", "pytest-playwright"] name = "param" version = "2.0.1" description = "Make your Python code clearer and more reliable by declaring Parameters." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1462,6 +1721,7 @@ tests-full = ["cloudpickle", "gmpy", "ipython", "jsonschema", "nest-asyncio", "n name = "paramiko" version = "3.3.1" description = "SSH2 protocol library" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1483,6 +1743,7 @@ invoke = ["invoke (>=2.0)"] name = "parso" version = "0.8.3" description = "A Python Parser" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1498,6 +1759,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "dev" optional = false python-versions = "*" files = [ @@ -1512,6 +1774,7 @@ ptyprocess = ">=0.5" name = "pillow" version = "10.1.0" description = "Python Imaging Library (Fork)" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1579,6 +1842,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "platformdirs" version = "4.0.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1590,10 +1854,27 @@ files = [ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] +[[package]] +name = "plotly" +version = "5.18.0" +description = "An open-source, interactive data visualization library for Python" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "plotly-5.18.0-py3-none-any.whl", hash = "sha256:23aa8ea2f4fb364a20d34ad38235524bd9d691bf5299e800bca608c31e8db8de"}, + {file = "plotly-5.18.0.tar.gz", hash = "sha256:360a31e6fbb49d12b007036eb6929521343d6bee2236f8459915821baefa2cbb"}, +] + +[package.dependencies] +packaging = "*" +tenacity = ">=6.2.0" + [[package]] name = "pluggy" version = "1.3.0" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1609,6 +1890,7 @@ testing = ["pytest", "pytest-benchmark"] name = "prompt-toolkit" version = "3.0.40" description = "Library for building powerful interactive command lines in Python" +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -1623,6 +1905,7 @@ wcwidth = "*" name = "psutil" version = "5.9.6" description = "Cross-platform lib for process and system monitoring in Python." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -1651,6 +1934,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -1662,6 +1946,7 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" +category = "dev" optional = false python-versions = "*" files = [ @@ -1676,6 +1961,7 @@ tests = ["pytest"] name = "pyarrow" version = "14.0.1" description = "Python library for Apache Arrow" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1724,6 +2010,7 @@ numpy = ">=1.16.6" name = "pycparser" version = "2.21" description = "C parser in Python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1735,6 +2022,7 @@ files = [ name = "pygments" version = "2.16.1" description = "Pygments is a syntax highlighting package written in Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1749,6 +2037,7 @@ plugins = ["importlib-metadata"] name = "pylance" version = "0.5.10" description = "python wrapper for lance-rs" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1770,6 +2059,7 @@ tests = ["duckdb", "ml_dtypes", "pandas (>=1.4)", "polars[pandas,pyarrow]", "pyt name = "pynacl" version = "1.5.0" description = "Python binding to the Networking and Cryptography (NaCl) library" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1796,6 +2086,7 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] name = "pyparsing" version = "3.1.1" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -1810,6 +2101,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pytest" version = "7.4.3" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1832,6 +2124,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "4.1.0" description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1850,6 +2143,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "pytest-mock" version = "3.12.0" description = "Thin-wrapper around the mock package for easier use with pytest" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1867,6 +2161,7 @@ dev = ["pre-commit", "pytest-asyncio", "tox"] name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -1881,6 +2176,7 @@ six = ">=1.5" name = "pytz" version = "2023.3.post1" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -1892,6 +2188,7 @@ files = [ name = "pyviz-comms" version = "3.0.0" description = "A JupyterLab extension for rendering HoloViz content." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1911,6 +2208,7 @@ tests = ["flake8", "pytest"] name = "pywin32" version = "306" description = "Python for Window Extensions" +category = "dev" optional = false python-versions = "*" files = [ @@ -1934,6 +2232,7 @@ files = [ name = "pyyaml" version = "6.0.1" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1993,6 +2292,7 @@ files = [ name = "pyzmq" version = "25.1.1" description = "Python bindings for 0MQ" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2098,6 +2398,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "quapy" version = "0.1.7" description = "QuaPy: a framework for Quantification in Python" +category = "main" optional = false python-versions = ">=3.6, <4" files = [ @@ -2118,6 +2419,7 @@ xlrd = "*" name = "requests" version = "2.31.0" description = "Python HTTP for Humans." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2135,10 +2437,26 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "retrying" +version = "1.3.4" +description = "Retrying" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "retrying-1.3.4-py3-none-any.whl", hash = "sha256:8cc4d43cb8e1125e0ff3344e9de678fefd85db3b750b81b2240dc0183af37b35"}, + {file = "retrying-1.3.4.tar.gz", hash = "sha256:345da8c5765bd982b1d1915deb9102fd3d1f7ad16bd84a9700b85f64d24e8f3e"}, +] + +[package.dependencies] +six = ">=1.7.0" + [[package]] name = "scikit-learn" version = "1.3.2" description = "A set of python modules for machine learning and data mining" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2186,6 +2504,7 @@ tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc ( name = "scipy" version = "1.11.4" description = "Fundamental algorithms for scientific computing in Python" +category = "main" optional = false python-versions = ">=3.9" files = [ @@ -2224,10 +2543,28 @@ dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyl doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "setuptools" +version = "69.0.2" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "setuptools-69.0.2-py3-none-any.whl", hash = "sha256:1e8fdff6797d3865f37397be788a4e3cba233608e9b509382a2777d25ebde7f2"}, + {file = "setuptools-69.0.2.tar.gz", hash = "sha256:735896e78a4742605974de002ac60562d286fa8051a7e2299445e8e8fbb01aa6"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + [[package]] name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -2239,6 +2576,7 @@ files = [ name = "stack-data" version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" +category = "dev" optional = false python-versions = "*" files = [ @@ -2258,6 +2596,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2268,10 +2607,26 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tenacity" +version = "8.2.3" +description = "Retry code until it succeeds" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, + {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + [[package]] name = "threadpoolctl" version = "3.2.0" description = "threadpoolctl" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2283,6 +2638,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2294,6 +2650,7 @@ files = [ name = "tornado" version = "6.3.3" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -2314,6 +2671,7 @@ files = [ name = "tqdm" version = "4.66.1" description = "Fast, Extensible Progress Meter" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2334,6 +2692,7 @@ telegram = ["requests"] name = "traitlets" version = "5.13.0" description = "Traitlets Python configuration system" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2349,6 +2708,7 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.6.0)", "pre-commit", "pytest (>=7.0, name = "types-pytz" version = "2023.3.1.1" description = "Typing stubs for pytz" +category = "dev" optional = false python-versions = "*" files = [ @@ -2360,6 +2720,7 @@ files = [ name = "typing-extensions" version = "4.8.0" description = "Backported and Experimental Type Hints for Python 3.8+" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2371,6 +2732,7 @@ files = [ name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" +category = "main" optional = false python-versions = ">=2" files = [ @@ -2382,6 +2744,7 @@ files = [ name = "uc-micro-py" version = "1.0.2" description = "Micro subset of unicode data files for linkify-it-py projects." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2396,6 +2759,7 @@ test = ["coverage", "pytest", "pytest-cov"] name = "urllib3" version = "2.0.7" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2413,6 +2777,7 @@ zstd = ["zstandard (>=0.18.0)"] name = "wcwidth" version = "0.2.9" description = "Measures the displayed width of unicode strings in a terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -2424,6 +2789,7 @@ files = [ name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "dev" optional = false python-versions = "*" files = [ @@ -2431,10 +2797,29 @@ files = [ {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, ] +[[package]] +name = "werkzeug" +version = "3.0.1" +description = "The comprehensive WSGI web application library." +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, + {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "widgetsnbextension" version = "4.0.9" description = "Jupyter interactive widgets for Jupyter Notebook" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2446,6 +2831,7 @@ files = [ name = "xlrd" version = "2.0.1" description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -2462,6 +2848,7 @@ test = ["pytest", "pytest-cov"] name = "xyzservices" version = "2023.10.1" description = "Source of XYZ tiles providers" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2469,7 +2856,23 @@ files = [ {file = "xyzservices-2023.10.1.tar.gz", hash = "sha256:091229269043bc8258042edbedad4fcb44684b0473ede027b5672ad40dc9fa02"}, ] +[[package]] +name = "zipp" +version = "3.17.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, + {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "036900e0883a559f83033add6bf8aa6f3178bd1c8fe8ea68bf504610157a1543" +content-hash = "244fa10a77087a9c49906733a2af33d1feacb62e1a32f9de580d531378d3eceb" diff --git a/pyproject.toml b/pyproject.toml index 36cec09..4a7e446 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ abstention = "^0.1.3.1" main = "quacc.main:main" run = "run:run" panel = "qcpanel.run:run" +dash = "qcdash.app:run" sync_up = "remote:sync_code" sync_down = "remote:sync_output" merge_data = "merge_data:run" @@ -27,6 +28,7 @@ poetry_command = "" [tool.poe.tasks] ilona = "ssh volpi@ilona.isti.cnr.it" +dash = "gunicorn qcdash.app:server -b ilona.isti.cnr.it:33421" [tool.poe.tasks.logr] shell = """ @@ -48,6 +50,9 @@ ipympl = "^0.9.3" ipykernel = "^6.26.0" ipywidgets-bokeh = "^1.5.0" pandas-stubs = "^2.1.1.230928" +dash = "^2.14.1" +dash-bootstrap-components = "^1.5.0" +gunicorn = "^21.2.0" [tool.pytest.ini_options] addopts = "--cov=quacc --capture=tee-sys" diff --git a/qcdash/__init__.py b/qcdash/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/qcdash/app.py b/qcdash/app.py new file mode 100644 index 0000000..d91dae4 --- /dev/null +++ b/qcdash/app.py @@ -0,0 +1,413 @@ +import json +import os +from collections import defaultdict +from json import JSONDecodeError +from pathlib import Path +from typing import List +from urllib.parse import parse_qsl, quote, urlencode, urlparse + +import dash_bootstrap_components as dbc +import numpy as np +from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html +from dash.dash_table.Format import Format, Scheme + +from quacc import plot +from quacc.evaluation.estimators import CE +from quacc.evaluation.report import CompReport, DatasetReport +from quacc.evaluation.stats import ttest_rel + +backend = plot.get_backend("plotly") + +valid_plot_modes = defaultdict(lambda: CompReport._default_modes) +valid_plot_modes["avg"] = DatasetReport._default_dr_modes + + +def get_datasets(root: str | Path) -> List[DatasetReport]: + def load_dataset(dataset): + dataset = Path(dataset) + return DatasetReport.unpickle(dataset) + + def explore_datasets(root: str | Path) -> List[Path]: + if isinstance(root, str): + root = Path(root) + + if root.name == "plot": + return [] + + if not root.exists(): + return [] + + dr_paths = [] + for f in os.listdir(root): + if (root / f).is_dir(): + dr_paths += explore_datasets(root / f) + elif f == f"{root.name}.pickle": + dr_paths.append(root / f) + + return dr_paths + + dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t)) + return {str(drp.parent): load_dataset(drp) for drp in dr_paths} + + +def get_fig(dr: DatasetReport, metric, estimators, view, mode): + estimators = CE.name[estimators] + match (view, mode): + case ("avg", _): + return dr.get_plots( + mode=mode, + metric=metric, + estimators=estimators, + conf="plotly", + save_fig=False, + backend=backend, + ) + case (_, _): + cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)] + return cr.get_plots( + mode=mode, + metric=metric, + estimators=estimators, + conf="plotly", + save_fig=False, + backend=backend, + ) + + +def get_table(dr: DatasetReport, metric, estimators, view, mode): + estimators = CE.name[estimators] + _prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs] + match (view, mode): + case ("avg", "train_table"): + return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean() + case ("avg", "test_table"): + return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + case ("avg", "shift_table"): + return ( + dr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + case ("avg", "stats_table"): + return ttest_rel(dr, metric=metric, estimators=estimators) + case (_, "train_table"): + cr = dr.crs[_prevs.index(view)] + return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean() + case (_, "shift_table"): + cr = dr.crs[_prevs.index(view)] + return ( + cr.shift_data(metric=metric, estimators=estimators) + .groupby(level=0) + .mean() + ) + + +def get_DataTable(df): + _primary = "#0d6efd" + if df.empty: + return None + + df = df.reset_index() + columns = { + c: dict( + id=c, + name=c, + type="numeric", + format=Format(precision=6, scheme=Scheme.exponent), + ) + for c in df.columns + } + columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed) + columns = list(columns.values()) + data = df.to_dict("records") + + return html.Div( + [ + dash_table.DataTable( + data=data, + columns=columns, + id="table1", + style_cell={ + "padding": "0 12px", + "border": "0", + "border-bottom": f"1px solid {_primary}", + }, + style_table={ + "margin": "6vh 15px", + "padding": "15px", + "maxWidth": "80vw", + "overflowX": "auto", + "border": f"0px solid {_primary}", + "border-radius": "6px", + }, + ) + ], + style={ + "display": "flex", + "flex-direction": "column", + # "justify-content": "center", + "align-items": "center", + "height": "100vh", + }, + ) + + +def get_Graph(fig): + if fig is None: + return None + + return dcc.Graph( + id="graph1", + figure=fig, + style={ + "margin": 0, + "height": "100vh", + }, + ) + + +datasets = get_datasets("output") + +app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) +# app.config.suppress_callback_exceptions = True +sidebar_style = { + "top": 0, + "left": 0, + "bottom": 0, + "padding": "1vw", + "padding-top": "2vw", + "margin": "0px", + "flex": 1, + "overflow": "scroll", + "height": "100vh", +} + +content_style = { + "flex": 5, + "maxWidth": "84vw", +} + + +def parse_href(href: str): + parse_result = urlparse(href) + params = parse_qsl(parse_result.query) + return dict(params) + + +def get_sidebar(): + return [ + html.H4("Parameters:", style={"margin-bottom": "1vw"}), + dbc.Select( + # options=list(datasets.keys()), + # value=list(datasets.keys())[0], + id="dataset", + ), + dbc.Select( + id="metric", + style={"margin-top": "1vh"}, + ), + html.Div( + [ + dbc.RadioItems( + id="view", + class_name="btn-group mt-3", + input_class_name="btn-check", + label_class_name="btn btn-outline-primary", + label_checked_class_name="active", + ), + dbc.RadioItems( + id="mode", + class_name="btn-group mt-3", + input_class_name="btn-check", + label_class_name="btn btn-outline-primary", + label_checked_class_name="active", + ), + ], + className="radio-group-v d-flex justify-content-around", + ), + html.Div( + [ + dbc.Checklist( + id="estimators", + className="btn-group mt-3", + inputClassName="btn-check", + labelClassName="btn btn-outline-primary", + labelCheckedClassName="active", + ), + ], + className="radio-group-wide", + ), + ] + + +app.layout = html.Div( + [ + dcc.Interval(id="reload", interval=10 * 60 * 1000), + dcc.Location(id="url", refresh=False), + html.Div( + [ + html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style), + html.Div(id="app_content", style=content_style), + ], + id="page_layout", + style={"display": "flex", "flexDirection": "row"}, + ), + ] +) + +server = app.server + + +def apply_param(href, triggered_id, id, curr): + match triggered_id: + case "url": + params = parse_href(href) + return params.get(id, None) + case _: + return curr + + +@callback( + Output("dataset", "value"), + Output("dataset", "options"), + Input("url", "href"), + Input("reload", "n_intervals"), + State("dataset", "value"), +) +def update_dataset(href, n_intervals, dataset): + match ctx.triggered_id: + case "reload": + new_datasets = get_datasets("output") + global datasets + datasets = new_datasets + req_dataset = dataset + case "url": + params = parse_href(href) + req_dataset = params.get("dataset", None) + + available_datasets = list(datasets.keys()) + new_dataset = ( + req_dataset if req_dataset in available_datasets else available_datasets[0] + ) + return new_dataset, available_datasets + + +@callback( + Output("metric", "options"), + Output("metric", "value"), + Input("url", "href"), + Input("dataset", "value"), + State("metric", "value"), +) +def update_metrics(href, dataset, curr_metric): + dr = datasets[dataset] + old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric) + valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")] + new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0] + return valid_metrics, new_metric + + +@callback( + Output("estimators", "options"), + Output("estimators", "value"), + Input("url", "href"), + Input("dataset", "value"), + Input("metric", "value"), + State("estimators", "value"), +) +def update_estimators(href, dataset, metric, curr_estimators): + dr = datasets[dataset] + old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators) + if isinstance(old_estimators, str): + try: + old_estimators = json.loads(old_estimators) + except JSONDecodeError: + old_estimators = [] + valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy() + new_estimators = valid_estimators[ + np.isin(valid_estimators, old_estimators) + ].tolist() + return valid_estimators, new_estimators + + +@callback( + Output("view", "options"), + Output("view", "value"), + Input("url", "href"), + Input("dataset", "value"), + State("view", "value"), +) +def update_view(href, dataset, curr_view): + dr = datasets[dataset] + old_view = apply_param(href, ctx.triggered_id, "view", curr_view) + valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs] + new_view = old_view if old_view in valid_views else valid_views[0] + return valid_views, new_view + + +@callback( + Output("mode", "options"), + Output("mode", "value"), + Input("url", "href"), + Input("view", "value"), + State("mode", "value"), +) +def update_mode(href, view, curr_mode): + old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode) + valid_modes = valid_plot_modes[view] + new_mode = old_mode if old_mode in valid_modes else valid_modes[0] + return valid_modes, new_mode + + +@callback( + Output("app_content", "children"), + Output("url", "search"), + Input("dataset", "value"), + Input("metric", "value"), + Input("estimators", "value"), + Input("view", "value"), + Input("mode", "value"), +) +def update_content(dataset, metric, estimators, view, mode): + search = urlencode( + dict( + dataset=dataset, + metric=metric, + estimators=json.dumps(estimators), + view=view, + mode=mode, + ), + quote_via=quote, + ) + dr = datasets[dataset] + match mode: + case m if m.endswith("table"): + df = get_table( + dr=dr, + metric=metric, + estimators=estimators, + view=view, + mode=mode, + ) + dt = get_DataTable(df) + app_content = [] if dt is None else [dt] + case _: + fig = get_fig( + dr=dr, + metric=metric, + estimators=estimators, + view=view, + mode=mode, + ) + g = get_Graph(fig) + app_content = [] if g is None else [g] + + return app_content, f"?{search}" + + +def run(): + app.run(debug=True) + + +if __name__ == "__main__": + run() diff --git a/qcdash/assets/radio_group.css b/qcdash/assets/radio_group.css new file mode 100644 index 0000000..b6d9e67 --- /dev/null +++ b/qcdash/assets/radio_group.css @@ -0,0 +1,86 @@ +/* restyle radio items */ +.radio-group .form-check { + padding-left: 0; +} + +.radio-group .btn-group > .form-check:not(:last-child) > .btn { + border-top-right-radius: 0; + border-bottom-right-radius: 0; +} + +.radio-group .btn-group > .form-check:not(:first-child) > .btn { + border-top-left-radius: 0; + border-bottom-left-radius: 0; + margin-left: -1px; +} + +.radio-group-v{ + padding: 0 10px; +} +.radio-group-v .form-check { + padding-top: 0; + padding-left: 2px; +} + +.radio-group-v .btn-group { + flex-direction: column; + justify-content: center; + align-items: stretch; + flex-grow: 1; +} + +.radio-group-v > .btn-group:first-child { + flex-grow: 0; +} + +.radio-group-v > .btn-group:last-child { + margin-left: 20px; +} + +.radio-group-v .btn-group > .form-check:not(:last-child) > .btn { + border-bottom-right-radius: 0; + border-bottom-left-radius: 0; +} + +.radio-group-v .btn-group > .form-check:not(:first-child) > .btn { + border-top-right-radius: 0; + border-top-left-radius: 0; + margin-top: -3px; +} + + +.radio-group-v .btn-group .btn{ + width: 100%; +} + +.radio-group-wide .form-check { + padding-left: 0px; +} + +.radio-group-wide .btn-group{ + flex-wrap: wrap; +} + +.radio-group-wide .btn-group .form-check{ + flex: 1; + margin-top: -3px; + margin-left: -1px; +} + + +.radio-group-wide .btn-group .form-check .btn{ + width: 100%; + border-radius: 0; +} + +.radio-group-wide .btn-group .form-check:first-child > .btn{ + border-top-left-radius: 10px; +} + +.radio-group-wide .btn-group .form-check:last-child > .btn{ + border-bottom-right-radius: 10px; +} + +div#app-sidebar{ + border-right: 2px solid var(--primary); +} \ No newline at end of file diff --git a/qcpanel/old_run.py b/qcpanel/old_run.py deleted file mode 100644 index 1089e7a..0000000 --- a/qcpanel/old_run.py +++ /dev/null @@ -1,290 +0,0 @@ -import argparse -import os -from pathlib import Path - -import panel as pn -import param - -from quacc.evaluation.estimators import CE -from quacc.evaluation.report import DatasetReport - -pn.extension(design="bootstrap") - - -def create_cr_plots( - dr: DatasetReport, - mode="delta", - metric="acc", - estimators=None, - prev=None, -): - idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev) - cr = dr.crs[idx] - estimators = CE.name[estimators] - _dpi = 112 - return pn.pane.Matplotlib( - cr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="panel", - return_fig=True, - ), - tight=True, - format="png", - sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - - -def create_avg_plots( - dr: DatasetReport, - mode="delta", - metric="acc", - estimators=None, - prev=None, -): - estimators = CE.name[estimators] - return pn.pane.Matplotlib( - dr.get_plots( - mode=mode, - metric=metric, - estimators=estimators, - conf="panel", - return_fig=True, - ), - tight=True, - format="png", - sizing_mode="scale_height", - # sizing_mode="scale_both", - ) - - -def build_cr_tab(dr: DatasetReport): - _data = dr.data() - _metrics = _data.columns.unique(0) - _estimators = _data.columns.unique(1) - - valid_metrics = [m for m in _metrics if not m.endswith("_score")] - metric_widget = pn.widgets.Select( - name="metric", - value="acc", - options=valid_metrics, - align="center", - ) - - valid_estimators = [e for e in _estimators if e != "ref"] - estimators_widget = pn.widgets.CheckButtonGroup( - name="estimators", - options=valid_estimators, - value=valid_estimators, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"] - plot_mode_widget = pn.widgets.RadioButtonGroup( - name="mode", - value=valid_plot_modes[0], - options=valid_plot_modes, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs] - prevs_widget = pn.widgets.RadioButtonGroup( - name="train prevalence", - value=valid_prevs[0], - options=valid_prevs, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - ) - - plot_pane = pn.bind( - create_cr_plots, - dr=dr, - mode=plot_mode_widget, - metric=metric_widget, - estimators=estimators_widget, - prev=prevs_widget, - ) - - return pn.Row( - pn.Spacer(width=20), - pn.Column( - metric_widget, - pn.Row( - prevs_widget, - plot_mode_widget, - ), - estimators_widget, - align="center", - ), - pn.Spacer(sizing_mode="scale_width"), - plot_pane, - ) - - -def build_avg_tab(dr: DatasetReport): - _data = dr.data() - _metrics = _data.columns.unique(0) - _estimators = _data.columns.unique(1) - - valid_metrics = [m for m in _metrics if not m.endswith("_score")] - metric_widget = pn.widgets.Select( - name="metric", - value="acc", - options=valid_metrics, - align="center", - ) - - valid_estimators = [e for e in _estimators if e != "ref"] - estimators_widget = pn.widgets.CheckButtonGroup( - name="estimators", - options=valid_estimators, - value=valid_estimators, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - valid_plot_modes = [ - "delta_train", - "stdev_train", - "delta_test", - "stdev_test", - "shift", - ] - plot_mode_widget = pn.widgets.RadioButtonGroup( - name="mode", - value=valid_plot_modes[0], - options=valid_plot_modes, - button_style="outline", - button_type="primary", - align="center", - orientation="vertical", - sizing_mode="scale_width", - ) - - plot_pane = pn.bind( - create_avg_plots, - dr=dr, - mode=plot_mode_widget, - metric=metric_widget, - estimators=estimators_widget, - ) - - return pn.Row( - pn.Spacer(width=20), - pn.Column( - metric_widget, - plot_mode_widget, - estimators_widget, - align="center", - ), - pn.Spacer(sizing_mode="scale_width"), - plot_pane, - ) - - -def build_dataset(dataset_path: Path): - dr: DatasetReport = DatasetReport.unpickle(dataset_path) - - prevs_tab = ("train prevs.", build_cr_tab(dr)) - avg_tab = ("avg", build_avg_tab(dr)) - - app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False) - app.servable() - return app - - -def explore_datasets(root: Path | str): - if isinstance(root, str): - root = Path(root) - - if root.name == "plot": - return [] - - if not root.exists(): - return [] - - drs = [] - for f in os.listdir(root): - if (root / f).is_dir(): - drs += explore_datasets(root / f) - elif f == f"{root.name}.pickle": - drs.append((root, build_dataset(root / f))) - # drs.append((str(root),)) - - return drs - - -class PlotSelector(param.Parameterized): - metric = param.Selector(objects=["acc", "f1"]) - view = param.Selector(objects=["train prevs", "avg"]) - - -def plot_selector_widget(): - return pn.Param( - PlotSelector.param, - widgets={ - "metric": pn.widgets.Select, - "view": pn.widgets.Select, - }, - ) - - -def serve(address="localhost"): - # app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle")) - __base_path = "output" - __tabs = sorted( - explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0]) - ) - __tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs] - if len(__tabs) > 0: - app = pn.Tabs( - objects=__tabs, - tabs_location="left", - dynamic=False, - ) - else: - app = pn.Column( - pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}), - align="center", - ) - - __port = 33420 - __allowed = [address] - if address == "localhost": - __allowed.append("127.0.0.1") - - pn.serve( - app, - autoreload=True, - port=__port, - show=False, - address=address, - websocket_origin=[f"{_a}:{__port}" for _a in __allowed], - ) - - -def run(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--address", - action="store", - dest="address", - default="localhost", - ) - args = parser.parse_args() - serve(address=args.address) diff --git a/qcpanel/run.py b/qcpanel/run.py index fdaf2b0..ca6a477 100644 --- a/qcpanel/run.py +++ b/qcpanel/run.py @@ -1,10 +1,11 @@ import argparse import panel as pn +from panel.theme.fast import FastDarkTheme, FastDefaultTheme from qcpanel.viewer import QuaccTestViewer -# pn.config.design = pn.theme.Bootstrap +# pn.config.design = Fast # pn.config.theme = "dark" pn.config.notifications = True @@ -59,8 +60,8 @@ def app_instance(): ], main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")], modal=[qtv.modal_pane], - theme=pn.theme.DarkTheme, - theme_toggle=False, + # theme=FastDefaultTheme, + theme_toggle=True, ) app.servable() diff --git a/qcpanel/util.py b/qcpanel/util.py index 6d1473d..3286c2f 100644 --- a/qcpanel/util.py +++ b/qcpanel/util.py @@ -52,7 +52,7 @@ def create_plots( metric=metric, estimators=estimators, conf="panel", - return_fig=True, + save_fig=False, ) return ( pn.pane.Matplotlib( @@ -91,7 +91,7 @@ def create_plots( metric=metric, estimators=estimators, conf="panel", - return_fig=True, + save_fig=False, ) return ( pn.pane.Matplotlib( diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py index e61e709..ef57345 100644 --- a/quacc/evaluation/report.py +++ b/quacc/evaluation/report.py @@ -6,7 +6,7 @@ from typing import List, Tuple import numpy as np import pandas as pd -from quacc import plot +import quacc.plot as plot from quacc.utils import fmt_line_md @@ -215,16 +215,17 @@ class CompReport: def get_plots( self, - mode="delta", + mode="delta_train", metric="acc", estimators=None, conf="default", - return_fig=False, + save_fig=True, base_path=None, + backend=None, ) -> List[Tuple[str, Path]]: if mode == "delta_train": avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) - if avg_data.empty is True: + if avg_data.empty: return None return plot.plot_delta( @@ -234,8 +235,9 @@ class CompReport: metric=metric, name=conf, train_prev=self.train_prev, - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "stdev_train": avg_data = self.avg_by_prevs(metric=metric, estimators=estimators) @@ -251,8 +253,9 @@ class CompReport: name=conf, train_prev=self.train_prev, stdevs=st_data.T.to_numpy(), - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "diagonal": f_data = self.data(metric=metric + "_score", estimators=estimators) @@ -268,8 +271,9 @@ class CompReport: metric=metric, name=conf, train_prev=self.train_prev, - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "shift": _shift_data = self.shift_data(metric=metric, estimators=estimators) @@ -290,8 +294,9 @@ class CompReport: name=conf, train_prev=self.train_prev, counts=shift_counts.T.to_numpy(), - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) def to_md( @@ -323,11 +328,12 @@ class CompReport: plot_modes = [m for m in modes if not m.endswith("table")] for mode in plot_modes: res += f"### {mode}\n" - op = self.get_plots( + _, op = self.get_plots( mode=mode, metric=metric, estimators=estimators, conf=conf, + save_fig=True, base_path=plot_path, ) res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n" @@ -424,12 +430,15 @@ class DatasetReport: metric="acc", estimators=None, conf="default", - return_fig=False, + save_fig=True, base_path=None, + backend=None, ): if mode == "delta_train": _data = self.data(metric, estimators) if data is None else data avg_on_train = _data.groupby(level=1).mean() + if avg_on_train.empty: + return None prevs_on_train = np.sort(avg_on_train.index.unique(0)) return plot.plot_delta( base_prevs=np.around( @@ -441,12 +450,15 @@ class DatasetReport: name=conf, train_prev=None, avg="train", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "stdev_train": _data = self.data(metric, estimators) if data is None else data avg_on_train = _data.groupby(level=1).mean() + if avg_on_train.empty: + return None prevs_on_train = np.sort(avg_on_train.index.unique(0)) stdev_on_train = _data.groupby(level=1).std() return plot.plot_delta( @@ -460,12 +472,15 @@ class DatasetReport: train_prev=None, stdevs=stdev_on_train.T.to_numpy(), avg="train", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "delta_test": _data = self.data(metric, estimators) if data is None else data avg_on_test = _data.groupby(level=0).mean() + if avg_on_test.empty: + return None prevs_on_test = np.sort(avg_on_test.index.unique(0)) return plot.plot_delta( base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2), @@ -475,12 +490,15 @@ class DatasetReport: name=conf, train_prev=None, avg="test", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "stdev_test": _data = self.data(metric, estimators) if data is None else data avg_on_test = _data.groupby(level=0).mean() + if avg_on_test.empty: + return None prevs_on_test = np.sort(avg_on_test.index.unique(0)) stdev_on_test = _data.groupby(level=0).std() return plot.plot_delta( @@ -492,12 +510,15 @@ class DatasetReport: train_prev=None, stdevs=stdev_on_test.T.to_numpy(), avg="test", - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) elif mode == "shift": _shift_data = self.shift_data(metric, estimators) if data is None else data avg_shift = _shift_data.groupby(level=0).mean() + if avg_shift.empty: + return None count_shift = _shift_data.groupby(level=0).count() prevs_shift = np.sort(avg_shift.index.unique(0)) return plot.plot_shift( @@ -508,8 +529,9 @@ class DatasetReport: name=conf, train_prev=None, counts=count_shift.T.to_numpy(), - return_fig=return_fig, + save_fig=save_fig, base_path=base_path, + backend=backend, ) def to_md( @@ -545,24 +567,26 @@ class DatasetReport: res += avg_on_train_tbl.to_html() + "\n\n" if "delta_train" in dr_modes: - delta_op = self.get_plots( + _, delta_op = self.get_plots( data=_data, mode="delta_train", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" if "stdev_train" in dr_modes: - delta_stdev_op = self.get_plots( + _, delta_stdev_op = self.get_plots( data=_data, mode="stdev_train", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" @@ -575,24 +599,26 @@ class DatasetReport: res += avg_on_test_tbl.to_html() + "\n\n" if "delta_test" in dr_modes: - delta_op = self.get_plots( + _, delta_op = self.get_plots( data=_data, mode="delta_test", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n" if "stdev_test" in dr_modes: - delta_stdev_op = self.get_plots( + _, delta_stdev_op = self.get_plots( data=_data, mode="stdev_test", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n" @@ -605,13 +631,14 @@ class DatasetReport: res += shift_on_train_tbl.to_html() + "\n\n" if "shift" in dr_modes: - shift_op = self.get_plots( + _, shift_op = self.get_plots( data=_shift_data, mode="shift", metric=metric, estimators=estimators, conf=conf, base_path=plot_path, + save_fig=True, ) res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n" diff --git a/quacc/plot.py b/quacc/plot.py deleted file mode 100644 index e0bbefa..0000000 --- a/quacc/plot.py +++ /dev/null @@ -1,265 +0,0 @@ -from pathlib import Path - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from cycler import cycler - -from quacc import utils - -matplotlib.use("agg") - - -def _get_markers(n: int): - ls = "ovx+sDph*^1234X><.Pd" - if n > len(ls): - ls = ls * (n / len(ls) + 1) - return list(ls)[:n] - - -def plot_delta( - base_prevs, - columns, - data, - *, - stdevs=None, - pos_class=1, - metric="acc", - name="default", - train_prev=None, - legend=True, - avg=None, - return_fig=False, - base_path=None, -) -> Path: - _base_title = "delta_stdev" if stdevs is not None else "delta" - if train_prev is not None: - t_prev_pos = int(round(train_prev[pos_class] * 100)) - title = f"{_base_title}_{name}_{t_prev_pos}_{metric}" - else: - title = f"{_base_title}_{name}_avg_{avg}_{metric}" - - if base_path is None: - base_path = utils.get_quacc_home() / "plots" - - fig, ax = plt.subplots() - ax.set_aspect("auto") - ax.grid() - - NUM_COLORS = len(data) - cm = plt.get_cmap("tab10") - if NUM_COLORS > 10: - cm = plt.get_cmap("tab20") - cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) - - base_prevs = base_prevs[:, pos_class] - for method, deltas, _cy in zip(columns, data, cy): - ax.plot( - base_prevs, - deltas, - label=method, - color=_cy["color"], - linestyle="-", - marker="o", - markersize=3, - zorder=2, - ) - if stdevs is not None: - _col_idx = np.where(columns == method)[0] - stdev = stdevs[_col_idx].flatten() - nn_idx = np.intersect1d( - np.where(deltas != np.nan)[0], - np.where(stdev != np.nan)[0], - ) - _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx] - ax.fill_between( - _bps, - _ds - _st, - _ds + _st, - color=_cy["color"], - alpha=0.25, - ) - - x_label = "test" if avg is None or avg == "train" else "train" - ax.set( - xlabel=f"{x_label} prevalence", - ylabel=metric, - title=title, - ) - - if legend: - ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) - - if return_fig: - return fig - - output_path = base_path / f"{title}.png" - fig.savefig(output_path, bbox_inches="tight") - return output_path - - -def plot_diagonal( - reference, - columns, - data, - *, - pos_class=1, - metric="acc", - name="default", - train_prev=None, - legend=True, - return_fig=False, - base_path=None, -): - if train_prev is not None: - t_prev_pos = int(round(train_prev[pos_class] * 100)) - title = f"diagonal_{name}_{t_prev_pos}_{metric}" - else: - title = f"diagonal_{name}_{metric}" - - if base_path is None: - base_path = utils.get_quacc_home() / "plots" - - fig, ax = plt.subplots() - ax.set_aspect("auto") - ax.grid() - ax.set_aspect("equal") - - NUM_COLORS = len(data) - cm = plt.get_cmap("tab10") - if NUM_COLORS > 10: - cm = plt.get_cmap("tab20") - cy = cycler( - color=[cm(i) for i in range(NUM_COLORS)], - marker=_get_markers(NUM_COLORS), - ) - - reference = np.array(reference) - x_ticks = np.unique(reference) - x_ticks.sort() - - for deltas, _cy in zip(data, cy): - ax.plot( - reference, - deltas, - color=_cy["color"], - linestyle="None", - marker=_cy["marker"], - markersize=3, - zorder=2, - alpha=0.25, - ) - - # ensure limits are equal for both axes - _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1) - _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)]) - ax.set(xlim=tuple(_lims), ylim=tuple(_lims)) - - for method, deltas, _cy in zip(columns, data, cy): - slope, interc = np.polyfit(reference, deltas, 1) - y_lr = np.array([slope * x + interc for x in _lims]) - ax.plot( - _lims, - y_lr, - label=method, - color=_cy["color"], - linestyle="-", - markersize="0", - zorder=1, - ) - - # plot reference line - ax.plot( - _lims, - _lims, - color="black", - linestyle="--", - markersize=0, - zorder=1, - ) - - ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title) - - if legend: - ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) - - if return_fig: - return fig - - output_path = base_path / f"{title}.png" - fig.savefig(output_path, bbox_inches="tight") - return output_path - - -def plot_shift( - shift_prevs, - columns, - data, - *, - counts=None, - pos_class=1, - metric="acc", - name="default", - train_prev=None, - legend=True, - return_fig=False, - base_path=None, -) -> Path: - if train_prev is not None: - t_prev_pos = int(round(train_prev[pos_class] * 100)) - title = f"shift_{name}_{t_prev_pos}_{metric}" - else: - title = f"shift_{name}_avg_{metric}" - - if base_path is None: - base_path = utils.get_quacc_home() / "plots" - - fig, ax = plt.subplots() - ax.set_aspect("auto") - ax.grid() - - NUM_COLORS = len(data) - cm = plt.get_cmap("tab10") - if NUM_COLORS > 10: - cm = plt.get_cmap("tab20") - cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) - - shift_prevs = shift_prevs[:, pos_class] - for method, shifts, _cy in zip(columns, data, cy): - ax.plot( - shift_prevs, - shifts, - label=method, - color=_cy["color"], - linestyle="-", - marker="o", - markersize=3, - zorder=2, - ) - if counts is not None: - _col_idx = np.where(columns == method)[0] - count = counts[_col_idx].flatten() - for prev, shift, cnt in zip(shift_prevs, shifts, count): - label = f"{cnt}" - plt.annotate( - label, - (prev, shift), - textcoords="offset points", - xytext=(0, 10), - ha="center", - color=_cy["color"], - fontsize=12.0, - ) - - ax.set(xlabel="dataset shift", ylabel=metric, title=title) - - if legend: - ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) - - if return_fig: - return fig - - output_path = base_path / f"{title}.png" - fig.savefig(output_path, bbox_inches="tight") - - return output_path diff --git a/quacc/plot/__init__.py b/quacc/plot/__init__.py new file mode 100644 index 0000000..6a182c5 --- /dev/null +++ b/quacc/plot/__init__.py @@ -0,0 +1 @@ +from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift diff --git a/quacc/plot/base.py b/quacc/plot/base.py new file mode 100644 index 0000000..36a58b2 --- /dev/null +++ b/quacc/plot/base.py @@ -0,0 +1,54 @@ +from pathlib import Path + + +class BasePlot: + @classmethod + def save_fig(cls, fig, base_path, title) -> Path: + ... + + @classmethod + def plot_diagonal( + cls, + reference, + columns, + data, + *, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ): + ... + + @classmethod + def plot_delta( + cls, + base_prevs, + columns, + data, + *, + stdevs=None, + pos_class=1, + title="default", + x_label="prevs.", + y_label="error", + legend=True, + ): + ... + + @classmethod + def plot_shift( + cls, + shift_prevs, + columns, + data, + *, + counts=None, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ): + ... diff --git a/quacc/plot/mpl.py b/quacc/plot/mpl.py new file mode 100644 index 0000000..dd84b7a --- /dev/null +++ b/quacc/plot/mpl.py @@ -0,0 +1,222 @@ +from pathlib import Path + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from cycler import cycler + +from quacc import utils +from quacc.plot.base import BasePlot + +matplotlib.use("agg") + + +class MplPlot(BasePlot): + def _get_markers(self, n: int): + ls = "ovx+sDph*^1234X><.Pd" + if n > len(ls): + ls = ls * (n / len(ls) + 1) + return list(ls)[:n] + + def save_fig(self, fig, base_path, title) -> Path: + if base_path is None: + base_path = utils.get_quacc_home() / "plots" + output_path = base_path / f"{title}.png" + fig.savefig(output_path, bbox_inches="tight") + return output_path + + def plot_delta( + self, + base_prevs, + columns, + data, + *, + stdevs=None, + pos_class=1, + title="default", + x_label="prevs.", + y_label="error", + legend=True, + ): + fig, ax = plt.subplots() + ax.set_aspect("auto") + ax.grid() + + NUM_COLORS = len(data) + cm = plt.get_cmap("tab10") + if NUM_COLORS > 10: + cm = plt.get_cmap("tab20") + cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) + + base_prevs = base_prevs[:, pos_class] + for method, deltas, _cy in zip(columns, data, cy): + ax.plot( + base_prevs, + deltas, + label=method, + color=_cy["color"], + linestyle="-", + marker="o", + markersize=3, + zorder=2, + ) + if stdevs is not None: + _col_idx = np.where(columns == method)[0] + stdev = stdevs[_col_idx].flatten() + nn_idx = np.intersect1d( + np.where(deltas != np.nan)[0], + np.where(stdev != np.nan)[0], + ) + _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx] + ax.fill_between( + _bps, + _ds - _st, + _ds + _st, + color=_cy["color"], + alpha=0.25, + ) + + ax.set( + xlabel=f"{x_label} prevalence", + ylabel=y_label, + title=title, + ) + + if legend: + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + return fig + + def plot_diagonal( + self, + reference, + columns, + data, + *, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ): + fig, ax = plt.subplots() + ax.set_aspect("auto") + ax.grid() + ax.set_aspect("equal") + + NUM_COLORS = len(data) + cm = plt.get_cmap("tab10") + if NUM_COLORS > 10: + cm = plt.get_cmap("tab20") + cy = cycler( + color=[cm(i) for i in range(NUM_COLORS)], + marker=self._get_markers(NUM_COLORS), + ) + + reference = np.array(reference) + x_ticks = np.unique(reference) + x_ticks.sort() + + for deltas, _cy in zip(data, cy): + ax.plot( + reference, + deltas, + color=_cy["color"], + linestyle="None", + marker=_cy["marker"], + markersize=3, + zorder=2, + alpha=0.25, + ) + + # ensure limits are equal for both axes + _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1) + _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)]) + ax.set(xlim=tuple(_lims), ylim=tuple(_lims)) + + for method, deltas, _cy in zip(columns, data, cy): + slope, interc = np.polyfit(reference, deltas, 1) + y_lr = np.array([slope * x + interc for x in _lims]) + ax.plot( + _lims, + y_lr, + label=method, + color=_cy["color"], + linestyle="-", + markersize="0", + zorder=1, + ) + + # plot reference line + ax.plot( + _lims, + _lims, + color="black", + linestyle="--", + markersize=0, + zorder=1, + ) + + ax.set(xlabel=x_label, ylabel=y_label, title=title) + + if legend: + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + return fig + + def plot_shift( + self, + shift_prevs, + columns, + data, + *, + counts=None, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ): + fig, ax = plt.subplots() + ax.set_aspect("auto") + ax.grid() + + NUM_COLORS = len(data) + cm = plt.get_cmap("tab10") + if NUM_COLORS > 10: + cm = plt.get_cmap("tab20") + cy = cycler(color=[cm(i) for i in range(NUM_COLORS)]) + + shift_prevs = shift_prevs[:, pos_class] + for method, shifts, _cy in zip(columns, data, cy): + ax.plot( + shift_prevs, + shifts, + label=method, + color=_cy["color"], + linestyle="-", + marker="o", + markersize=3, + zorder=2, + ) + if counts is not None: + _col_idx = np.where(columns == method)[0] + count = counts[_col_idx].flatten() + for prev, shift, cnt in zip(shift_prevs, shifts, count): + label = f"{cnt}" + plt.annotate( + label, + (prev, shift), + textcoords="offset points", + xytext=(0, 10), + ha="center", + color=_cy["color"], + fontsize=12.0, + ) + + ax.set(xlabel=x_label, ylabel=y_label, title=title) + + if legend: + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + return fig diff --git a/quacc/plot/plot.py b/quacc/plot/plot.py new file mode 100644 index 0000000..1bd2369 --- /dev/null +++ b/quacc/plot/plot.py @@ -0,0 +1,144 @@ +from quacc.plot.base import BasePlot +from quacc.plot.mpl import MplPlot +from quacc.plot.plotly import PlotlyPlot + +__backend: BasePlot = MplPlot() + + +def get_backend(be, theme=None): + match be: + case "matplotlib" | "mpl": + return MplPlot() + case "plotly": + return PlotlyPlot(theme=theme) + case _: + return MplPlot() + + +def plot_delta( + base_prevs, + columns, + data, + *, + stdevs=None, + pos_class=1, + metric="acc", + name="default", + train_prev=None, + legend=True, + avg=None, + save_fig=False, + base_path=None, + backend=None, +): + backend = __backend if backend is None else backend + _base_title = "delta_stdev" if stdevs is not None else "delta" + if train_prev is not None: + t_prev_pos = int(round(train_prev[pos_class] * 100)) + title = f"{_base_title}_{name}_{t_prev_pos}_{metric}" + else: + title = f"{_base_title}_{name}_avg_{avg}_{metric}" + + x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence" + y_label = f"{metric} error" + fig = backend.plot_delta( + base_prevs, + columns, + data, + stdevs=stdevs, + pos_class=pos_class, + title=title, + x_label=x_label, + y_label=y_label, + legend=legend, + ) + + if save_fig: + output_path = backend.save_fig(fig, base_path, title) + return fig, output_path + + return fig + + +def plot_diagonal( + reference, + columns, + data, + *, + pos_class=1, + metric="acc", + name="default", + train_prev=None, + legend=True, + save_fig=False, + base_path=None, + backend=None, +): + backend = __backend if backend is None else backend + if train_prev is not None: + t_prev_pos = int(round(train_prev[pos_class] * 100)) + title = f"diagonal_{name}_{t_prev_pos}_{metric}" + else: + title = f"diagonal_{name}_{metric}" + + x_label = f"true {metric}" + y_label = f"estim. {metric}" + fig = backend.plot_diagonal( + reference, + columns, + data, + pos_class=pos_class, + title=title, + x_label=x_label, + y_label=y_label, + legend=legend, + ) + + if save_fig: + output_path = backend.save_fig(fig, base_path, title) + return fig, output_path + + return fig + + +def plot_shift( + shift_prevs, + columns, + data, + *, + counts=None, + pos_class=1, + metric="acc", + name="default", + train_prev=None, + legend=True, + save_fig=False, + base_path=None, + backend=None, +): + backend = __backend if backend is None else backend + if train_prev is not None: + t_prev_pos = int(round(train_prev[pos_class] * 100)) + title = f"shift_{name}_{t_prev_pos}_{metric}" + else: + title = f"shift_{name}_avg_{metric}" + + x_label = "dataset shift" + y_label = f"{metric} error" + fig = backend.plot_shift( + shift_prevs, + columns, + data, + counts=counts, + pos_class=pos_class, + title=title, + x_label=x_label, + y_label=y_label, + legend=legend, + ) + + if save_fig: + output_path = backend.save_fig(fig, base_path, title) + return fig, output_path + + return fig diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py new file mode 100644 index 0000000..074c277 --- /dev/null +++ b/quacc/plot/plotly.py @@ -0,0 +1,201 @@ +from collections import defaultdict +from pathlib import Path + +import numpy as np +import plotly +import plotly.graph_objects as go + +from quacc.plot.base import BasePlot + + +class PlotlyPlot(BasePlot): + __themes = defaultdict( + lambda: { + "template": "seaborn", + } + ) + __themes = __themes | { + "dark": { + "template": "plotly_dark", + }, + } + + def __init__(self, theme=None): + self.theme = PlotlyPlot.__themes[theme] + + def hex_to_rgb(self, hex: str, t: float | None = None): + hex = hex.lstrip("#") + rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]] + if t is not None: + rgb.append(t) + return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}" + + def get_colors(self, num): + match num: + case v if v > 10: + __colors = plotly.colors.qualitative.Light24 + case _: + __colors = plotly.colors.qualitative.Plotly + + def __generator(cs): + while True: + for c in cs: + yield c + + return __generator(__colors) + + def update_layout(self, fig, title, x_label, y_label): + fig.update_layout( + title=title, + xaxis_title=x_label, + yaxis_title=y_label, + template=self.theme["template"], + ) + + def save_fig(self, fig, base_path, title) -> Path: + return None + + def plot_delta( + self, + base_prevs, + columns, + data, + *, + stdevs=None, + pos_class=1, + title="default", + x_label="prevs.", + y_label="error", + legend=True, + ) -> go.Figure: + fig = go.Figure() + x = base_prevs[:, pos_class] + line_colors = self.get_colors(len(columns)) + for name, delta in zip(columns, data): + color = next(line_colors) + _line = [ + go.Scatter( + x=x, + y=delta, + mode="lines+markers", + name=name, + line=dict(color=self.hex_to_rgb(color)), + hovertemplate="prev.: %{x}
error: %{y:,.4f}", + ) + ] + _error = [] + if stdevs is not None: + _col_idx = np.where(columns == name)[0] + stdev = stdevs[_col_idx].flatten() + _error = [ + go.Scatter( + x=np.concatenate([x, x[::-1]]), + y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]), + name=int(_col_idx[0]), + fill="toself", + fillcolor=self.hex_to_rgb(color, t=0.2), + line=dict(color="rgba(255, 255, 255, 0)"), + hoverinfo="skip", + showlegend=False, + ) + ] + fig.add_traces(_line + _error) + + self.update_layout(fig, title, x_label, y_label) + return fig + + def plot_diagonal( + self, + reference, + columns, + data, + *, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ) -> go.Figure: + fig = go.Figure() + x = reference + line_colors = self.get_colors(len(columns)) + + _edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)])) + _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]]) + + for name, val in zip(columns, data): + color = next(line_colors) + slope, interc = np.polyfit(x, val, 1) + y_lr = np.array([slope * _x + interc for _x in _lims[0]]) + fig.add_traces( + [ + go.Scatter( + x=x, + y=val, + customdata=np.stack((val - x,), axis=-1), + mode="markers", + name=name, + line=dict(color=self.hex_to_rgb(color, t=0.5)), + hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}", + ), + go.Scatter( + x=_lims[0], + y=y_lr, + mode="lines", + name=name, + line=dict(color=self.hex_to_rgb(color), width=3), + showlegend=False, + ), + ] + ) + fig.add_trace( + go.Scatter( + x=_lims[0], + y=_lims[1], + mode="lines", + name="reference", + showlegend=False, + line=dict(color=self.hex_to_rgb("#000000"), dash="dash"), + ) + ) + + self.update_layout(fig, title, x_label, y_label) + fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0) + return fig + + def plot_shift( + self, + shift_prevs, + columns, + data, + *, + counts=None, + pos_class=1, + title="default", + x_label="true", + y_label="estim.", + legend=True, + ) -> go.Figure: + fig = go.Figure() + x = shift_prevs[:, pos_class] + line_colors = self.get_colors(len(columns)) + for name, delta in zip(columns, data): + col_idx = (columns == name).nonzero()[0][0] + color = next(line_colors) + fig.add_trace( + go.Scatter( + x=x, + y=delta, + customdata=np.stack((counts[col_idx],), axis=-1), + mode="lines+markers", + name=name, + line=dict(color=self.hex_to_rgb(color)), + hovertemplate="shift: %{x}
error: %{y}" + + "
count: %{customdata[0]}" + if counts is not None + else "", + ) + ) + + self.update_layout(fig, title, x_label, y_label) + return fig diff --git a/remote.py b/remote.py index 9b8af19..1f3bf32 100644 --- a/remote.py +++ b/remote.py @@ -24,6 +24,7 @@ __to_sync_up = { "quacc", "baselines", "qcpanel", + "qcdash", ], "file": [ "conf.yaml",