diff --git a/.gitignore b/.gitignore
index 88fb8aa..6a4219b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,11 +6,13 @@ quavenv/*
__pycache__/*
baselines/__pycache__/*
baselines/densratio/__pycache__/*
+qcdash/__pycache__/*
qcpanel/__pycache__/*
quacc/__pycache__/*
quacc/evaluation/__pycache__/*
quacc/method/__pycache__/*
quacc/quantification/__pycache__/*
+quacc/plot/__pycache__/*
tests/__pycache__/*
tests/*/__pycache__/*
tests/*/*/__pycache__/*
diff --git a/poetry.lock b/poetry.lock
index 41efb57..07d4bff 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,9 +1,10 @@
-# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
[[package]]
name = "abstention"
version = "0.1.3.1"
description = "Functions for abstention, calibration and label shift domain adaptation"
+category = "main"
optional = false
python-versions = "*"
files = [
@@ -15,10 +16,27 @@ numpy = ">=1.9"
scikit-learn = ">=0.20.0"
scipy = ">=1.1.0"
+[[package]]
+name = "ansi2html"
+version = "1.8.0"
+description = ""
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "ansi2html-1.8.0-py3-none-any.whl", hash = "sha256:ef9cc9682539dbe524fbf8edad9c9462a308e04bce1170c32daa8fdfd0001785"},
+ {file = "ansi2html-1.8.0.tar.gz", hash = "sha256:38b82a298482a1fa2613f0f9c9beb3db72a8f832eeac58eb2e47bf32cd37f6d5"},
+]
+
+[package.extras]
+docs = ["Sphinx", "setuptools-scm", "sphinx-rtd-theme"]
+test = ["pytest", "pytest-cov"]
+
[[package]]
name = "appnope"
version = "0.1.3"
description = "Disable App Nap on macOS >= 10.9"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -30,6 +48,7 @@ files = [
name = "asttokens"
version = "2.4.1"
description = "Annotate AST trees with source code positions"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -48,6 +67,7 @@ test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"]
name = "bcrypt"
version = "4.0.1"
description = "Modern password hashing for your software and your servers"
+category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@@ -82,6 +102,7 @@ typecheck = ["mypy"]
name = "bleach"
version = "6.1.0"
description = "An easy safelist-based HTML-sanitizing tool."
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -96,10 +117,23 @@ webencodings = "*"
[package.extras]
css = ["tinycss2 (>=1.1.0,<1.3)"]
+[[package]]
+name = "blinker"
+version = "1.7.0"
+description = "Fast, simple object-to-object and broadcast signaling"
+category = "dev"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "blinker-1.7.0-py3-none-any.whl", hash = "sha256:c3f865d4d54db7abc53758a01601cf343fe55b84c1de4e3fa910e420b438d5b9"},
+ {file = "blinker-1.7.0.tar.gz", hash = "sha256:e6820ff6fa4e4d1d8e2747c2283749c3f547e4fee112b98555cdcdae32996182"},
+]
+
[[package]]
name = "bokeh"
version = "3.3.1"
description = "Interactive plots and applications in the browser from Python"
+category = "dev"
optional = false
python-versions = ">=3.9"
files = [
@@ -122,6 +156,7 @@ xyzservices = ">=2021.09.1"
name = "certifi"
version = "2023.7.22"
description = "Python package for providing Mozilla's CA Bundle."
+category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@@ -133,6 +168,7 @@ files = [
name = "cffi"
version = "1.16.0"
description = "Foreign Function Interface for Python calling C code."
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -197,6 +233,7 @@ pycparser = "*"
name = "charset-normalizer"
version = "3.3.2"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+category = "dev"
optional = false
python-versions = ">=3.7.0"
files = [
@@ -292,10 +329,26 @@ files = [
{file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
]
+[[package]]
+name = "click"
+version = "8.1.7"
+description = "Composable command line interface toolkit"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
+ {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
[[package]]
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
+category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
@@ -307,6 +360,7 @@ files = [
name = "comm"
version = "0.2.0"
description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc."
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -324,6 +378,7 @@ test = ["pytest"]
name = "contourpy"
version = "1.2.0"
description = "Python library for calculating contours of 2D quadrilateral grids"
+category = "main"
optional = false
python-versions = ">=3.9"
files = [
@@ -387,6 +442,7 @@ test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
name = "coverage"
version = "7.3.2"
description = "Code coverage measurement for Python"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -454,6 +510,7 @@ toml = ["tomli"]
name = "cryptography"
version = "41.0.5"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -499,6 +556,7 @@ test-randomorder = ["pytest-randomly"]
name = "cycler"
version = "0.12.1"
description = "Composable style cycles"
+category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -510,10 +568,100 @@ files = [
docs = ["ipython", "matplotlib", "numpydoc", "sphinx"]
tests = ["pytest", "pytest-cov", "pytest-xdist"]
+[[package]]
+name = "dash"
+version = "2.14.1"
+description = "A Python framework for building reactive web-apps. Developed by Plotly."
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "dash-2.14.1-py3-none-any.whl", hash = "sha256:ce440ef7416945c9daa8274948483a0aac928a4fec768c0384fd4e9a6196eaf2"},
+ {file = "dash-2.14.1.tar.gz", hash = "sha256:93dc9d665ec5d3720647d4cef4520a1c7cd1bde57e893ffeb7e6cd59781d3294"},
+]
+
+[package.dependencies]
+ansi2html = "*"
+dash-core-components = "2.0.0"
+dash-html-components = "2.0.0"
+dash-table = "5.0.0"
+Flask = ">=1.0.4,<3.1"
+importlib-metadata = {version = "*", markers = "python_version >= \"3.7\""}
+nest-asyncio = "*"
+plotly = ">=5.0.0"
+requests = "*"
+retrying = "*"
+setuptools = "*"
+typing-extensions = ">=4.1.1"
+Werkzeug = "<3.1"
+
+[package.extras]
+celery = ["celery[redis] (>=5.1.2)", "importlib-metadata (<5)", "redis (>=3.5.3)"]
+ci = ["black (==21.6b0)", "black (==22.3.0)", "dash-dangerously-set-inner-html", "dash-flow-example (==0.0.5)", "flake8 (==3.9.2)", "flaky (==3.7.0)", "flask-talisman (==1.0.0)", "isort (==4.3.21)", "jupyterlab (<4.0.0)", "mimesis", "mock (==4.0.3)", "numpy (<=1.25.2)", "openpyxl", "orjson (==3.5.4)", "orjson (==3.6.7)", "pandas (==1.1.5)", "pandas (>=1.4.0)", "preconditions", "pyarrow", "pyarrow (<3)", "pylint (==2.13.5)", "pytest-mock", "pytest-rerunfailures", "pytest-sugar (==0.9.6)", "xlrd (<2)", "xlrd (>=2.0.1)"]
+compress = ["flask-compress"]
+dev = ["PyYAML (>=5.4.1)", "coloredlogs (>=15.0.1)", "fire (>=0.4.0)"]
+diskcache = ["diskcache (>=5.2.1)", "multiprocess (>=0.70.12)", "psutil (>=5.8.0)"]
+testing = ["beautifulsoup4 (>=4.8.2)", "cryptography (<3.4)", "dash-testing-stub (>=0.0.2)", "lxml (>=4.6.2)", "multiprocess (>=0.70.12)", "percy (>=2.0.2)", "psutil (>=5.8.0)", "pytest (>=6.0.2)", "requests[security] (>=2.21.0)", "selenium (>=3.141.0,<=4.2.0)", "waitress (>=1.4.4)"]
+
+[[package]]
+name = "dash-bootstrap-components"
+version = "1.5.0"
+description = "Bootstrap themed components for use in Plotly Dash"
+category = "dev"
+optional = false
+python-versions = ">=3.7, <4"
+files = [
+ {file = "dash-bootstrap-components-1.5.0.tar.gz", hash = "sha256:083158c07434b9965e2d6c3e8ca72dbbe47dab23e676258cef9bf0ad47d2e250"},
+ {file = "dash_bootstrap_components-1.5.0-py3-none-any.whl", hash = "sha256:b487fec1a85e3d6a8564fe04c0a9cd9e846f75ea9e563456ed3879592889c591"},
+]
+
+[package.dependencies]
+dash = ">=2.0.0"
+
+[package.extras]
+pandas = ["numpy", "pandas"]
+
+[[package]]
+name = "dash-core-components"
+version = "2.0.0"
+description = "Core component suite for Dash"
+category = "dev"
+optional = false
+python-versions = "*"
+files = [
+ {file = "dash_core_components-2.0.0-py3-none-any.whl", hash = "sha256:52b8e8cce13b18d0802ee3acbc5e888cb1248a04968f962d63d070400af2e346"},
+ {file = "dash_core_components-2.0.0.tar.gz", hash = "sha256:c6733874af975e552f95a1398a16c2ee7df14ce43fa60bb3718a3c6e0b63ffee"},
+]
+
+[[package]]
+name = "dash-html-components"
+version = "2.0.0"
+description = "Vanilla HTML components for Dash"
+category = "dev"
+optional = false
+python-versions = "*"
+files = [
+ {file = "dash_html_components-2.0.0-py3-none-any.whl", hash = "sha256:b42cc903713c9706af03b3f2548bda4be7307a7cf89b7d6eae3da872717d1b63"},
+ {file = "dash_html_components-2.0.0.tar.gz", hash = "sha256:8703a601080f02619a6390998e0b3da4a5daabe97a1fd7a9cebc09d015f26e50"},
+]
+
+[[package]]
+name = "dash-table"
+version = "5.0.0"
+description = "Dash table"
+category = "dev"
+optional = false
+python-versions = "*"
+files = [
+ {file = "dash_table-5.0.0-py3-none-any.whl", hash = "sha256:19036fa352bb1c11baf38068ec62d172f0515f73ca3276c79dee49b95ddc16c9"},
+ {file = "dash_table-5.0.0.tar.gz", hash = "sha256:18624d693d4c8ef2ddec99a6f167593437a7ea0bf153aa20f318c170c5bc7308"},
+]
+
[[package]]
name = "debugpy"
version = "1.8.0"
description = "An implementation of the Debug Adapter Protocol for Python"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -541,6 +689,7 @@ files = [
name = "decorator"
version = "5.1.1"
description = "Decorators for Humans"
+category = "dev"
optional = false
python-versions = ">=3.5"
files = [
@@ -552,6 +701,7 @@ files = [
name = "exceptiongroup"
version = "1.1.3"
description = "Backport of PEP 654 (exception groups)"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -566,6 +716,7 @@ test = ["pytest (>=6)"]
name = "executing"
version = "2.0.1"
description = "Get the currently executing AST node of a frame, and other information"
+category = "dev"
optional = false
python-versions = ">=3.5"
files = [
@@ -576,10 +727,34 @@ files = [
[package.extras]
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]
+[[package]]
+name = "flask"
+version = "3.0.0"
+description = "A simple framework for building complex web applications."
+category = "dev"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "flask-3.0.0-py3-none-any.whl", hash = "sha256:21128f47e4e3b9d597a3e8521a329bf56909b690fcc3fa3e477725aa81367638"},
+ {file = "flask-3.0.0.tar.gz", hash = "sha256:cfadcdb638b609361d29ec22360d6070a77d7463dcb3ab08d2c2f2f168845f58"},
+]
+
+[package.dependencies]
+blinker = ">=1.6.2"
+click = ">=8.1.3"
+itsdangerous = ">=2.1.2"
+Jinja2 = ">=3.1.2"
+Werkzeug = ">=3.0.0"
+
+[package.extras]
+async = ["asgiref (>=3.2)"]
+dotenv = ["python-dotenv"]
+
[[package]]
name = "fonttools"
version = "4.44.0"
description = "Tools to manipulate font files"
+category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -641,10 +816,32 @@ ufo = ["fs (>=2.2.0,<3)"]
unicode = ["unicodedata2 (>=15.1.0)"]
woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
+[[package]]
+name = "gunicorn"
+version = "21.2.0"
+description = "WSGI HTTP Server for UNIX"
+category = "dev"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "gunicorn-21.2.0-py3-none-any.whl", hash = "sha256:3213aa5e8c24949e792bcacfc176fef362e7aac80b76c56f6b5122bf350722f0"},
+ {file = "gunicorn-21.2.0.tar.gz", hash = "sha256:88ec8bff1d634f98e61b9f65bc4bf3cd918a90806c6f5c48bc5603849ec81033"},
+]
+
+[package.dependencies]
+packaging = "*"
+
+[package.extras]
+eventlet = ["eventlet (>=0.24.1)"]
+gevent = ["gevent (>=1.4.0)"]
+setproctitle = ["setproctitle"]
+tornado = ["tornado (>=0.2)"]
+
[[package]]
name = "idna"
version = "3.4"
description = "Internationalized Domain Names in Applications (IDNA)"
+category = "dev"
optional = false
python-versions = ">=3.5"
files = [
@@ -652,10 +849,31 @@ files = [
{file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
]
+[[package]]
+name = "importlib-metadata"
+version = "6.8.0"
+description = "Read metadata from Python packages"
+category = "dev"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"},
+ {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"},
+]
+
+[package.dependencies]
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+perf = ["ipython"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
+
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -667,6 +885,7 @@ files = [
name = "ipykernel"
version = "6.26.0"
description = "IPython Kernel for Jupyter"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -680,7 +899,7 @@ comm = ">=0.1.1"
debugpy = ">=1.6.5"
ipython = ">=7.23.1"
jupyter-client = ">=6.1.12"
-jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+jupyter-core = ">=4.12,<5.0.0 || >=5.1.0"
matplotlib-inline = ">=0.1"
nest-asyncio = "*"
packaging = "*"
@@ -700,6 +919,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio"
name = "ipympl"
version = "0.9.3"
description = "Matplotlib Jupyter Extension"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -723,6 +943,7 @@ docs = ["Sphinx (>=1.5)", "myst-nb", "sphinx-book-theme", "sphinx-copybutton", "
name = "ipython"
version = "8.17.2"
description = "IPython: Productive Interactive Computing"
+category = "dev"
optional = false
python-versions = ">=3.9"
files = [
@@ -760,6 +981,7 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pa
name = "ipython-genutils"
version = "0.2.0"
description = "Vestigial utilities from IPython"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -771,6 +993,7 @@ files = [
name = "ipywidgets"
version = "8.1.1"
description = "Jupyter interactive widgets"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -792,6 +1015,7 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
name = "ipywidgets-bokeh"
version = "1.5.0"
description = "Allows embedding of Jupyter widgets in Bokeh layouts."
+category = "dev"
optional = false
python-versions = ">=3.9"
files = [
@@ -800,17 +1024,30 @@ files = [
]
[package.dependencies]
-bokeh = "==3.*"
-ipykernel = ">=6.dev0,<6.18.0 || >6.18.0,<7.dev0"
-ipywidgets = "==8.*"
+bokeh = ">=3.0.0,<4.0.0"
+ipykernel = ">=6.0.0,<6.18.0 || >6.18.0,<7.0.0"
+ipywidgets = ">=8.0.0,<9.0.0"
[package.extras]
dev = ["anywidget (>=0.3.0)", "panel (>=1.0.4)", "pytest (>=7.3.1)", "pytest-cov (>=4.1.0)", "pytest-playwright (>=0.3.3)"]
+[[package]]
+name = "itsdangerous"
+version = "2.1.2"
+description = "Safely pass data to untrusted environments and back."
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "itsdangerous-2.1.2-py3-none-any.whl", hash = "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44"},
+ {file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"},
+]
+
[[package]]
name = "jedi"
version = "0.19.1"
description = "An autocompletion tool for Python that can be used for text editors."
+category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@@ -830,6 +1067,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
name = "jinja2"
version = "3.1.2"
description = "A very fast and expressive template engine."
+category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -847,6 +1085,7 @@ i18n = ["Babel (>=2.7)"]
name = "joblib"
version = "1.3.2"
description = "Lightweight pipelining with Python functions"
+category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -858,6 +1097,7 @@ files = [
name = "jupyter-client"
version = "8.6.0"
description = "Jupyter protocol implementation and client libraries"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -866,7 +1106,7 @@ files = [
]
[package.dependencies]
-jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+jupyter-core = ">=4.12,<5.0.0 || >=5.1.0"
python-dateutil = ">=2.8.2"
pyzmq = ">=23.0"
tornado = ">=6.2"
@@ -880,6 +1120,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt
name = "jupyter-core"
version = "5.5.0"
description = "Jupyter core package. A base package on which Jupyter projects rely."
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -900,6 +1141,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
name = "jupyterlab-widgets"
version = "3.0.9"
description = "Jupyter interactive widgets for JupyterLab"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -911,6 +1153,7 @@ files = [
name = "kiwisolver"
version = "1.4.5"
description = "A fast implementation of the Cassowary constraint solver"
+category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -1024,6 +1267,7 @@ files = [
name = "linkify-it-py"
version = "2.0.2"
description = "Links recognition library with FULL unicode support."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -1044,6 +1288,7 @@ test = ["coverage", "pytest", "pytest-cov"]
name = "logging"
version = "0.4.9.6"
description = "A logging module for Python"
+category = "main"
optional = false
python-versions = "*"
files = [
@@ -1054,6 +1299,7 @@ files = [
name = "markdown"
version = "3.5.1"
description = "Python implementation of John Gruber's Markdown."
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1069,6 +1315,7 @@ testing = ["coverage", "pyyaml"]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1093,6 +1340,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
name = "markupsafe"
version = "2.1.3"
description = "Safely add untrusted strings to HTML/XML markup."
+category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -1162,6 +1410,7 @@ files = [
name = "matplotlib"
version = "3.8.1"
description = "Python plotting package"
+category = "main"
optional = false
python-versions = ">=3.9"
files = [
@@ -1210,6 +1459,7 @@ python-dateutil = ">=2.7"
name = "matplotlib-inline"
version = "0.1.6"
description = "Inline Matplotlib backend for Jupyter"
+category = "dev"
optional = false
python-versions = ">=3.5"
files = [
@@ -1224,6 +1474,7 @@ traitlets = "*"
name = "mdit-py-plugins"
version = "0.4.0"
description = "Collection of plugins for markdown-it-py"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1243,6 +1494,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -1254,6 +1506,7 @@ files = [
name = "nest-asyncio"
version = "1.5.8"
description = "Patch asyncio to allow nested event loops"
+category = "dev"
optional = false
python-versions = ">=3.5"
files = [
@@ -1265,6 +1518,7 @@ files = [
name = "numpy"
version = "1.26.2"
description = "Fundamental package for array computing in Python"
+category = "main"
optional = false
python-versions = ">=3.9"
files = [
@@ -1310,6 +1564,7 @@ files = [
name = "packaging"
version = "23.2"
description = "Core utilities for Python packages"
+category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -1321,6 +1576,7 @@ files = [
name = "pandas"
version = "2.1.2"
description = "Powerful data structures for data analysis, time series, and statistics"
+category = "main"
optional = false
python-versions = ">=3.9"
files = [
@@ -1389,6 +1645,7 @@ xml = ["lxml (>=4.8.0)"]
name = "pandas-stubs"
version = "2.1.1.230928"
description = "Type annotations for pandas"
+category = "dev"
optional = false
python-versions = ">=3.9"
files = [
@@ -1404,6 +1661,7 @@ types-pytz = ">=2022.1.1"
name = "panel"
version = "1.3.1"
description = "The powerful data exploration & web app framework for Python."
+category = "dev"
optional = false
python-versions = ">=3.9"
files = [
@@ -1441,6 +1699,7 @@ ui = ["jupyter-server", "playwright", "pytest-playwright"]
name = "param"
version = "2.0.1"
description = "Make your Python code clearer and more reliable by declaring Parameters."
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1462,6 +1721,7 @@ tests-full = ["cloudpickle", "gmpy", "ipython", "jsonschema", "nest-asyncio", "n
name = "paramiko"
version = "3.3.1"
description = "SSH2 protocol library"
+category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@@ -1483,6 +1743,7 @@ invoke = ["invoke (>=2.0)"]
name = "parso"
version = "0.8.3"
description = "A Python Parser"
+category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@@ -1498,6 +1759,7 @@ testing = ["docopt", "pytest (<6.0.0)"]
name = "pexpect"
version = "4.8.0"
description = "Pexpect allows easy control of interactive console applications."
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -1512,6 +1774,7 @@ ptyprocess = ">=0.5"
name = "pillow"
version = "10.1.0"
description = "Python Imaging Library (Fork)"
+category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -1579,6 +1842,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa
name = "platformdirs"
version = "4.0.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -1590,10 +1854,27 @@ files = [
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
+[[package]]
+name = "plotly"
+version = "5.18.0"
+description = "An open-source, interactive data visualization library for Python"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "plotly-5.18.0-py3-none-any.whl", hash = "sha256:23aa8ea2f4fb364a20d34ad38235524bd9d691bf5299e800bca608c31e8db8de"},
+ {file = "plotly-5.18.0.tar.gz", hash = "sha256:360a31e6fbb49d12b007036eb6929521343d6bee2236f8459915821baefa2cbb"},
+]
+
+[package.dependencies]
+packaging = "*"
+tenacity = ">=6.2.0"
+
[[package]]
name = "pluggy"
version = "1.3.0"
description = "plugin and hook calling mechanisms for python"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1609,6 +1890,7 @@ testing = ["pytest", "pytest-benchmark"]
name = "prompt-toolkit"
version = "3.0.40"
description = "Library for building powerful interactive command lines in Python"
+category = "dev"
optional = false
python-versions = ">=3.7.0"
files = [
@@ -1623,6 +1905,7 @@ wcwidth = "*"
name = "psutil"
version = "5.9.6"
description = "Cross-platform lib for process and system monitoring in Python."
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [
@@ -1651,6 +1934,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
name = "ptyprocess"
version = "0.7.0"
description = "Run a subprocess in a pseudo terminal"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -1662,6 +1946,7 @@ files = [
name = "pure-eval"
version = "0.2.2"
description = "Safely evaluate AST nodes without side effects"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -1676,6 +1961,7 @@ tests = ["pytest"]
name = "pyarrow"
version = "14.0.1"
description = "Python library for Apache Arrow"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1724,6 +2010,7 @@ numpy = ">=1.16.6"
name = "pycparser"
version = "2.21"
description = "C parser in Python"
+category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
@@ -1735,6 +2022,7 @@ files = [
name = "pygments"
version = "2.16.1"
description = "Pygments is a syntax highlighting package written in Python."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -1749,6 +2037,7 @@ plugins = ["importlib-metadata"]
name = "pylance"
version = "0.5.10"
description = "python wrapper for lance-rs"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1770,6 +2059,7 @@ tests = ["duckdb", "ml_dtypes", "pandas (>=1.4)", "polars[pandas,pyarrow]", "pyt
name = "pynacl"
version = "1.5.0"
description = "Python binding to the Networking and Cryptography (NaCl) library"
+category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@@ -1796,6 +2086,7 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
name = "pyparsing"
version = "3.1.1"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
+category = "main"
optional = false
python-versions = ">=3.6.8"
files = [
@@ -1810,6 +2101,7 @@ diagrams = ["jinja2", "railroad-diagrams"]
name = "pytest"
version = "7.4.3"
description = "pytest: simple powerful testing with Python"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -1832,6 +2124,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no
name = "pytest-cov"
version = "4.1.0"
description = "Pytest plugin for measuring coverage."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -1850,6 +2143,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale
name = "pytest-mock"
version = "3.12.0"
description = "Thin-wrapper around the mock package for easier use with pytest"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1867,6 +2161,7 @@ dev = ["pre-commit", "pytest-asyncio", "tox"]
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
+category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
@@ -1881,6 +2176,7 @@ six = ">=1.5"
name = "pytz"
version = "2023.3.post1"
description = "World timezone definitions, modern and historical"
+category = "main"
optional = false
python-versions = "*"
files = [
@@ -1892,6 +2188,7 @@ files = [
name = "pyviz-comms"
version = "3.0.0"
description = "A JupyterLab extension for rendering HoloViz content."
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -1911,6 +2208,7 @@ tests = ["flake8", "pytest"]
name = "pywin32"
version = "306"
description = "Python for Window Extensions"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -1934,6 +2232,7 @@ files = [
name = "pyyaml"
version = "6.0.1"
description = "YAML parser and emitter for Python"
+category = "main"
optional = false
python-versions = ">=3.6"
files = [
@@ -1993,6 +2292,7 @@ files = [
name = "pyzmq"
version = "25.1.1"
description = "Python bindings for 0MQ"
+category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@@ -2098,6 +2398,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""}
name = "quapy"
version = "0.1.7"
description = "QuaPy: a framework for Quantification in Python"
+category = "main"
optional = false
python-versions = ">=3.6, <4"
files = [
@@ -2118,6 +2419,7 @@ xlrd = "*"
name = "requests"
version = "2.31.0"
description = "Python HTTP for Humans."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -2135,10 +2437,26 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+[[package]]
+name = "retrying"
+version = "1.3.4"
+description = "Retrying"
+category = "dev"
+optional = false
+python-versions = "*"
+files = [
+ {file = "retrying-1.3.4-py3-none-any.whl", hash = "sha256:8cc4d43cb8e1125e0ff3344e9de678fefd85db3b750b81b2240dc0183af37b35"},
+ {file = "retrying-1.3.4.tar.gz", hash = "sha256:345da8c5765bd982b1d1915deb9102fd3d1f7ad16bd84a9700b85f64d24e8f3e"},
+]
+
+[package.dependencies]
+six = ">=1.7.0"
+
[[package]]
name = "scikit-learn"
version = "1.3.2"
description = "A set of python modules for machine learning and data mining"
+category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -2186,6 +2504,7 @@ tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (
name = "scipy"
version = "1.11.4"
description = "Fundamental algorithms for scientific computing in Python"
+category = "main"
optional = false
python-versions = ">=3.9"
files = [
@@ -2224,10 +2543,28 @@ dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyl
doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+[[package]]
+name = "setuptools"
+version = "69.0.2"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+category = "dev"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "setuptools-69.0.2-py3-none-any.whl", hash = "sha256:1e8fdff6797d3865f37397be788a4e3cba233608e9b509382a2777d25ebde7f2"},
+ {file = "setuptools-69.0.2.tar.gz", hash = "sha256:735896e78a4742605974de002ac60562d286fa8051a7e2299445e8e8fbb01aa6"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
+
[[package]]
name = "six"
version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
files = [
@@ -2239,6 +2576,7 @@ files = [
name = "stack-data"
version = "0.6.3"
description = "Extract data from python stack frames and tracebacks for informative displays"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -2258,6 +2596,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
name = "tabulate"
version = "0.9.0"
description = "Pretty-print tabular data"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -2268,10 +2607,26 @@ files = [
[package.extras]
widechars = ["wcwidth"]
+[[package]]
+name = "tenacity"
+version = "8.2.3"
+description = "Retry code until it succeeds"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"},
+ {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"},
+]
+
+[package.extras]
+doc = ["reno", "sphinx", "tornado (>=4.5)"]
+
[[package]]
name = "threadpoolctl"
version = "3.2.0"
description = "threadpoolctl"
+category = "main"
optional = false
python-versions = ">=3.8"
files = [
@@ -2283,6 +2638,7 @@ files = [
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -2294,6 +2650,7 @@ files = [
name = "tornado"
version = "6.3.3"
description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
+category = "dev"
optional = false
python-versions = ">= 3.8"
files = [
@@ -2314,6 +2671,7 @@ files = [
name = "tqdm"
version = "4.66.1"
description = "Fast, Extensible Progress Meter"
+category = "main"
optional = false
python-versions = ">=3.7"
files = [
@@ -2334,6 +2692,7 @@ telegram = ["requests"]
name = "traitlets"
version = "5.13.0"
description = "Traitlets Python configuration system"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -2349,6 +2708,7 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.6.0)", "pre-commit", "pytest (>=7.0,
name = "types-pytz"
version = "2023.3.1.1"
description = "Typing stubs for pytz"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -2360,6 +2720,7 @@ files = [
name = "typing-extensions"
version = "4.8.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -2371,6 +2732,7 @@ files = [
name = "tzdata"
version = "2023.3"
description = "Provider of IANA time zone data"
+category = "main"
optional = false
python-versions = ">=2"
files = [
@@ -2382,6 +2744,7 @@ files = [
name = "uc-micro-py"
version = "1.0.2"
description = "Micro subset of unicode data files for linkify-it-py projects."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -2396,6 +2759,7 @@ test = ["coverage", "pytest", "pytest-cov"]
name = "urllib3"
version = "2.0.7"
description = "HTTP library with thread-safe connection pooling, file post, and more."
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -2413,6 +2777,7 @@ zstd = ["zstandard (>=0.18.0)"]
name = "wcwidth"
version = "0.2.9"
description = "Measures the displayed width of unicode strings in a terminal"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -2424,6 +2789,7 @@ files = [
name = "webencodings"
version = "0.5.1"
description = "Character encoding aliases for legacy web content"
+category = "dev"
optional = false
python-versions = "*"
files = [
@@ -2431,10 +2797,29 @@ files = [
{file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"},
]
+[[package]]
+name = "werkzeug"
+version = "3.0.1"
+description = "The comprehensive WSGI web application library."
+category = "dev"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"},
+ {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.1.1"
+
+[package.extras]
+watchdog = ["watchdog (>=2.3)"]
+
[[package]]
name = "widgetsnbextension"
version = "4.0.9"
description = "Jupyter interactive widgets for Jupyter Notebook"
+category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@@ -2446,6 +2831,7 @@ files = [
name = "xlrd"
version = "2.0.1"
description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files"
+category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [
@@ -2462,6 +2848,7 @@ test = ["pytest", "pytest-cov"]
name = "xyzservices"
version = "2023.10.1"
description = "Source of XYZ tiles providers"
+category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@@ -2469,7 +2856,23 @@ files = [
{file = "xyzservices-2023.10.1.tar.gz", hash = "sha256:091229269043bc8258042edbedad4fcb44684b0473ede027b5672ad40dc9fa02"},
]
+[[package]]
+name = "zipp"
+version = "3.17.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+category = "dev"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
+ {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "036900e0883a559f83033add6bf8aa6f3178bd1c8fe8ea68bf504610157a1543"
+content-hash = "244fa10a77087a9c49906733a2af33d1feacb62e1a32f9de580d531378d3eceb"
diff --git a/pyproject.toml b/pyproject.toml
index 36cec09..4a7e446 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ abstention = "^0.1.3.1"
main = "quacc.main:main"
run = "run:run"
panel = "qcpanel.run:run"
+dash = "qcdash.app:run"
sync_up = "remote:sync_code"
sync_down = "remote:sync_output"
merge_data = "merge_data:run"
@@ -27,6 +28,7 @@ poetry_command = ""
[tool.poe.tasks]
ilona = "ssh volpi@ilona.isti.cnr.it"
+dash = "gunicorn qcdash.app:server -b ilona.isti.cnr.it:33421"
[tool.poe.tasks.logr]
shell = """
@@ -48,6 +50,9 @@ ipympl = "^0.9.3"
ipykernel = "^6.26.0"
ipywidgets-bokeh = "^1.5.0"
pandas-stubs = "^2.1.1.230928"
+dash = "^2.14.1"
+dash-bootstrap-components = "^1.5.0"
+gunicorn = "^21.2.0"
[tool.pytest.ini_options]
addopts = "--cov=quacc --capture=tee-sys"
diff --git a/qcdash/__init__.py b/qcdash/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/qcdash/app.py b/qcdash/app.py
new file mode 100644
index 0000000..d91dae4
--- /dev/null
+++ b/qcdash/app.py
@@ -0,0 +1,413 @@
+import json
+import os
+from collections import defaultdict
+from json import JSONDecodeError
+from pathlib import Path
+from typing import List
+from urllib.parse import parse_qsl, quote, urlencode, urlparse
+
+import dash_bootstrap_components as dbc
+import numpy as np
+from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
+from dash.dash_table.Format import Format, Scheme
+
+from quacc import plot
+from quacc.evaluation.estimators import CE
+from quacc.evaluation.report import CompReport, DatasetReport
+from quacc.evaluation.stats import ttest_rel
+
+backend = plot.get_backend("plotly")
+
+valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
+valid_plot_modes["avg"] = DatasetReport._default_dr_modes
+
+
+def get_datasets(root: str | Path) -> List[DatasetReport]:
+ def load_dataset(dataset):
+ dataset = Path(dataset)
+ return DatasetReport.unpickle(dataset)
+
+ def explore_datasets(root: str | Path) -> List[Path]:
+ if isinstance(root, str):
+ root = Path(root)
+
+ if root.name == "plot":
+ return []
+
+ if not root.exists():
+ return []
+
+ dr_paths = []
+ for f in os.listdir(root):
+ if (root / f).is_dir():
+ dr_paths += explore_datasets(root / f)
+ elif f == f"{root.name}.pickle":
+ dr_paths.append(root / f)
+
+ return dr_paths
+
+ dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
+ return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
+
+
+def get_fig(dr: DatasetReport, metric, estimators, view, mode):
+ estimators = CE.name[estimators]
+ match (view, mode):
+ case ("avg", _):
+ return dr.get_plots(
+ mode=mode,
+ metric=metric,
+ estimators=estimators,
+ conf="plotly",
+ save_fig=False,
+ backend=backend,
+ )
+ case (_, _):
+ cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
+ return cr.get_plots(
+ mode=mode,
+ metric=metric,
+ estimators=estimators,
+ conf="plotly",
+ save_fig=False,
+ backend=backend,
+ )
+
+
+def get_table(dr: DatasetReport, metric, estimators, view, mode):
+ estimators = CE.name[estimators]
+ _prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
+ match (view, mode):
+ case ("avg", "train_table"):
+ return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
+ case ("avg", "test_table"):
+ return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+ case ("avg", "shift_table"):
+ return (
+ dr.shift_data(metric=metric, estimators=estimators)
+ .groupby(level=0)
+ .mean()
+ )
+ case ("avg", "stats_table"):
+ return ttest_rel(dr, metric=metric, estimators=estimators)
+ case (_, "train_table"):
+ cr = dr.crs[_prevs.index(view)]
+ return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
+ case (_, "shift_table"):
+ cr = dr.crs[_prevs.index(view)]
+ return (
+ cr.shift_data(metric=metric, estimators=estimators)
+ .groupby(level=0)
+ .mean()
+ )
+
+
+def get_DataTable(df):
+ _primary = "#0d6efd"
+ if df.empty:
+ return None
+
+ df = df.reset_index()
+ columns = {
+ c: dict(
+ id=c,
+ name=c,
+ type="numeric",
+ format=Format(precision=6, scheme=Scheme.exponent),
+ )
+ for c in df.columns
+ }
+ columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
+ columns = list(columns.values())
+ data = df.to_dict("records")
+
+ return html.Div(
+ [
+ dash_table.DataTable(
+ data=data,
+ columns=columns,
+ id="table1",
+ style_cell={
+ "padding": "0 12px",
+ "border": "0",
+ "border-bottom": f"1px solid {_primary}",
+ },
+ style_table={
+ "margin": "6vh 15px",
+ "padding": "15px",
+ "maxWidth": "80vw",
+ "overflowX": "auto",
+ "border": f"0px solid {_primary}",
+ "border-radius": "6px",
+ },
+ )
+ ],
+ style={
+ "display": "flex",
+ "flex-direction": "column",
+ # "justify-content": "center",
+ "align-items": "center",
+ "height": "100vh",
+ },
+ )
+
+
+def get_Graph(fig):
+ if fig is None:
+ return None
+
+ return dcc.Graph(
+ id="graph1",
+ figure=fig,
+ style={
+ "margin": 0,
+ "height": "100vh",
+ },
+ )
+
+
+datasets = get_datasets("output")
+
+app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
+# app.config.suppress_callback_exceptions = True
+sidebar_style = {
+ "top": 0,
+ "left": 0,
+ "bottom": 0,
+ "padding": "1vw",
+ "padding-top": "2vw",
+ "margin": "0px",
+ "flex": 1,
+ "overflow": "scroll",
+ "height": "100vh",
+}
+
+content_style = {
+ "flex": 5,
+ "maxWidth": "84vw",
+}
+
+
+def parse_href(href: str):
+ parse_result = urlparse(href)
+ params = parse_qsl(parse_result.query)
+ return dict(params)
+
+
+def get_sidebar():
+ return [
+ html.H4("Parameters:", style={"margin-bottom": "1vw"}),
+ dbc.Select(
+ # options=list(datasets.keys()),
+ # value=list(datasets.keys())[0],
+ id="dataset",
+ ),
+ dbc.Select(
+ id="metric",
+ style={"margin-top": "1vh"},
+ ),
+ html.Div(
+ [
+ dbc.RadioItems(
+ id="view",
+ class_name="btn-group mt-3",
+ input_class_name="btn-check",
+ label_class_name="btn btn-outline-primary",
+ label_checked_class_name="active",
+ ),
+ dbc.RadioItems(
+ id="mode",
+ class_name="btn-group mt-3",
+ input_class_name="btn-check",
+ label_class_name="btn btn-outline-primary",
+ label_checked_class_name="active",
+ ),
+ ],
+ className="radio-group-v d-flex justify-content-around",
+ ),
+ html.Div(
+ [
+ dbc.Checklist(
+ id="estimators",
+ className="btn-group mt-3",
+ inputClassName="btn-check",
+ labelClassName="btn btn-outline-primary",
+ labelCheckedClassName="active",
+ ),
+ ],
+ className="radio-group-wide",
+ ),
+ ]
+
+
+app.layout = html.Div(
+ [
+ dcc.Interval(id="reload", interval=10 * 60 * 1000),
+ dcc.Location(id="url", refresh=False),
+ html.Div(
+ [
+ html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
+ html.Div(id="app_content", style=content_style),
+ ],
+ id="page_layout",
+ style={"display": "flex", "flexDirection": "row"},
+ ),
+ ]
+)
+
+server = app.server
+
+
+def apply_param(href, triggered_id, id, curr):
+ match triggered_id:
+ case "url":
+ params = parse_href(href)
+ return params.get(id, None)
+ case _:
+ return curr
+
+
+@callback(
+ Output("dataset", "value"),
+ Output("dataset", "options"),
+ Input("url", "href"),
+ Input("reload", "n_intervals"),
+ State("dataset", "value"),
+)
+def update_dataset(href, n_intervals, dataset):
+ match ctx.triggered_id:
+ case "reload":
+ new_datasets = get_datasets("output")
+ global datasets
+ datasets = new_datasets
+ req_dataset = dataset
+ case "url":
+ params = parse_href(href)
+ req_dataset = params.get("dataset", None)
+
+ available_datasets = list(datasets.keys())
+ new_dataset = (
+ req_dataset if req_dataset in available_datasets else available_datasets[0]
+ )
+ return new_dataset, available_datasets
+
+
+@callback(
+ Output("metric", "options"),
+ Output("metric", "value"),
+ Input("url", "href"),
+ Input("dataset", "value"),
+ State("metric", "value"),
+)
+def update_metrics(href, dataset, curr_metric):
+ dr = datasets[dataset]
+ old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
+ valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
+ new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
+ return valid_metrics, new_metric
+
+
+@callback(
+ Output("estimators", "options"),
+ Output("estimators", "value"),
+ Input("url", "href"),
+ Input("dataset", "value"),
+ Input("metric", "value"),
+ State("estimators", "value"),
+)
+def update_estimators(href, dataset, metric, curr_estimators):
+ dr = datasets[dataset]
+ old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
+ if isinstance(old_estimators, str):
+ try:
+ old_estimators = json.loads(old_estimators)
+ except JSONDecodeError:
+ old_estimators = []
+ valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
+ new_estimators = valid_estimators[
+ np.isin(valid_estimators, old_estimators)
+ ].tolist()
+ return valid_estimators, new_estimators
+
+
+@callback(
+ Output("view", "options"),
+ Output("view", "value"),
+ Input("url", "href"),
+ Input("dataset", "value"),
+ State("view", "value"),
+)
+def update_view(href, dataset, curr_view):
+ dr = datasets[dataset]
+ old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
+ valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
+ new_view = old_view if old_view in valid_views else valid_views[0]
+ return valid_views, new_view
+
+
+@callback(
+ Output("mode", "options"),
+ Output("mode", "value"),
+ Input("url", "href"),
+ Input("view", "value"),
+ State("mode", "value"),
+)
+def update_mode(href, view, curr_mode):
+ old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
+ valid_modes = valid_plot_modes[view]
+ new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
+ return valid_modes, new_mode
+
+
+@callback(
+ Output("app_content", "children"),
+ Output("url", "search"),
+ Input("dataset", "value"),
+ Input("metric", "value"),
+ Input("estimators", "value"),
+ Input("view", "value"),
+ Input("mode", "value"),
+)
+def update_content(dataset, metric, estimators, view, mode):
+ search = urlencode(
+ dict(
+ dataset=dataset,
+ metric=metric,
+ estimators=json.dumps(estimators),
+ view=view,
+ mode=mode,
+ ),
+ quote_via=quote,
+ )
+ dr = datasets[dataset]
+ match mode:
+ case m if m.endswith("table"):
+ df = get_table(
+ dr=dr,
+ metric=metric,
+ estimators=estimators,
+ view=view,
+ mode=mode,
+ )
+ dt = get_DataTable(df)
+ app_content = [] if dt is None else [dt]
+ case _:
+ fig = get_fig(
+ dr=dr,
+ metric=metric,
+ estimators=estimators,
+ view=view,
+ mode=mode,
+ )
+ g = get_Graph(fig)
+ app_content = [] if g is None else [g]
+
+ return app_content, f"?{search}"
+
+
+def run():
+ app.run(debug=True)
+
+
+if __name__ == "__main__":
+ run()
diff --git a/qcdash/assets/radio_group.css b/qcdash/assets/radio_group.css
new file mode 100644
index 0000000..b6d9e67
--- /dev/null
+++ b/qcdash/assets/radio_group.css
@@ -0,0 +1,86 @@
+/* restyle radio items */
+.radio-group .form-check {
+ padding-left: 0;
+}
+
+.radio-group .btn-group > .form-check:not(:last-child) > .btn {
+ border-top-right-radius: 0;
+ border-bottom-right-radius: 0;
+}
+
+.radio-group .btn-group > .form-check:not(:first-child) > .btn {
+ border-top-left-radius: 0;
+ border-bottom-left-radius: 0;
+ margin-left: -1px;
+}
+
+.radio-group-v{
+ padding: 0 10px;
+}
+.radio-group-v .form-check {
+ padding-top: 0;
+ padding-left: 2px;
+}
+
+.radio-group-v .btn-group {
+ flex-direction: column;
+ justify-content: center;
+ align-items: stretch;
+ flex-grow: 1;
+}
+
+.radio-group-v > .btn-group:first-child {
+ flex-grow: 0;
+}
+
+.radio-group-v > .btn-group:last-child {
+ margin-left: 20px;
+}
+
+.radio-group-v .btn-group > .form-check:not(:last-child) > .btn {
+ border-bottom-right-radius: 0;
+ border-bottom-left-radius: 0;
+}
+
+.radio-group-v .btn-group > .form-check:not(:first-child) > .btn {
+ border-top-right-radius: 0;
+ border-top-left-radius: 0;
+ margin-top: -3px;
+}
+
+
+.radio-group-v .btn-group .btn{
+ width: 100%;
+}
+
+.radio-group-wide .form-check {
+ padding-left: 0px;
+}
+
+.radio-group-wide .btn-group{
+ flex-wrap: wrap;
+}
+
+.radio-group-wide .btn-group .form-check{
+ flex: 1;
+ margin-top: -3px;
+ margin-left: -1px;
+}
+
+
+.radio-group-wide .btn-group .form-check .btn{
+ width: 100%;
+ border-radius: 0;
+}
+
+.radio-group-wide .btn-group .form-check:first-child > .btn{
+ border-top-left-radius: 10px;
+}
+
+.radio-group-wide .btn-group .form-check:last-child > .btn{
+ border-bottom-right-radius: 10px;
+}
+
+div#app-sidebar{
+ border-right: 2px solid var(--primary);
+}
\ No newline at end of file
diff --git a/qcpanel/old_run.py b/qcpanel/old_run.py
deleted file mode 100644
index 1089e7a..0000000
--- a/qcpanel/old_run.py
+++ /dev/null
@@ -1,290 +0,0 @@
-import argparse
-import os
-from pathlib import Path
-
-import panel as pn
-import param
-
-from quacc.evaluation.estimators import CE
-from quacc.evaluation.report import DatasetReport
-
-pn.extension(design="bootstrap")
-
-
-def create_cr_plots(
- dr: DatasetReport,
- mode="delta",
- metric="acc",
- estimators=None,
- prev=None,
-):
- idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev)
- cr = dr.crs[idx]
- estimators = CE.name[estimators]
- _dpi = 112
- return pn.pane.Matplotlib(
- cr.get_plots(
- mode=mode,
- metric=metric,
- estimators=estimators,
- conf="panel",
- return_fig=True,
- ),
- tight=True,
- format="png",
- sizing_mode="scale_height",
- # sizing_mode="scale_both",
- )
-
-
-def create_avg_plots(
- dr: DatasetReport,
- mode="delta",
- metric="acc",
- estimators=None,
- prev=None,
-):
- estimators = CE.name[estimators]
- return pn.pane.Matplotlib(
- dr.get_plots(
- mode=mode,
- metric=metric,
- estimators=estimators,
- conf="panel",
- return_fig=True,
- ),
- tight=True,
- format="png",
- sizing_mode="scale_height",
- # sizing_mode="scale_both",
- )
-
-
-def build_cr_tab(dr: DatasetReport):
- _data = dr.data()
- _metrics = _data.columns.unique(0)
- _estimators = _data.columns.unique(1)
-
- valid_metrics = [m for m in _metrics if not m.endswith("_score")]
- metric_widget = pn.widgets.Select(
- name="metric",
- value="acc",
- options=valid_metrics,
- align="center",
- )
-
- valid_estimators = [e for e in _estimators if e != "ref"]
- estimators_widget = pn.widgets.CheckButtonGroup(
- name="estimators",
- options=valid_estimators,
- value=valid_estimators,
- button_style="outline",
- button_type="primary",
- align="center",
- orientation="vertical",
- sizing_mode="scale_width",
- )
-
- valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"]
- plot_mode_widget = pn.widgets.RadioButtonGroup(
- name="mode",
- value=valid_plot_modes[0],
- options=valid_plot_modes,
- button_style="outline",
- button_type="primary",
- align="center",
- orientation="vertical",
- sizing_mode="scale_width",
- )
-
- valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
- prevs_widget = pn.widgets.RadioButtonGroup(
- name="train prevalence",
- value=valid_prevs[0],
- options=valid_prevs,
- button_style="outline",
- button_type="primary",
- align="center",
- orientation="vertical",
- )
-
- plot_pane = pn.bind(
- create_cr_plots,
- dr=dr,
- mode=plot_mode_widget,
- metric=metric_widget,
- estimators=estimators_widget,
- prev=prevs_widget,
- )
-
- return pn.Row(
- pn.Spacer(width=20),
- pn.Column(
- metric_widget,
- pn.Row(
- prevs_widget,
- plot_mode_widget,
- ),
- estimators_widget,
- align="center",
- ),
- pn.Spacer(sizing_mode="scale_width"),
- plot_pane,
- )
-
-
-def build_avg_tab(dr: DatasetReport):
- _data = dr.data()
- _metrics = _data.columns.unique(0)
- _estimators = _data.columns.unique(1)
-
- valid_metrics = [m for m in _metrics if not m.endswith("_score")]
- metric_widget = pn.widgets.Select(
- name="metric",
- value="acc",
- options=valid_metrics,
- align="center",
- )
-
- valid_estimators = [e for e in _estimators if e != "ref"]
- estimators_widget = pn.widgets.CheckButtonGroup(
- name="estimators",
- options=valid_estimators,
- value=valid_estimators,
- button_style="outline",
- button_type="primary",
- align="center",
- orientation="vertical",
- sizing_mode="scale_width",
- )
-
- valid_plot_modes = [
- "delta_train",
- "stdev_train",
- "delta_test",
- "stdev_test",
- "shift",
- ]
- plot_mode_widget = pn.widgets.RadioButtonGroup(
- name="mode",
- value=valid_plot_modes[0],
- options=valid_plot_modes,
- button_style="outline",
- button_type="primary",
- align="center",
- orientation="vertical",
- sizing_mode="scale_width",
- )
-
- plot_pane = pn.bind(
- create_avg_plots,
- dr=dr,
- mode=plot_mode_widget,
- metric=metric_widget,
- estimators=estimators_widget,
- )
-
- return pn.Row(
- pn.Spacer(width=20),
- pn.Column(
- metric_widget,
- plot_mode_widget,
- estimators_widget,
- align="center",
- ),
- pn.Spacer(sizing_mode="scale_width"),
- plot_pane,
- )
-
-
-def build_dataset(dataset_path: Path):
- dr: DatasetReport = DatasetReport.unpickle(dataset_path)
-
- prevs_tab = ("train prevs.", build_cr_tab(dr))
- avg_tab = ("avg", build_avg_tab(dr))
-
- app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False)
- app.servable()
- return app
-
-
-def explore_datasets(root: Path | str):
- if isinstance(root, str):
- root = Path(root)
-
- if root.name == "plot":
- return []
-
- if not root.exists():
- return []
-
- drs = []
- for f in os.listdir(root):
- if (root / f).is_dir():
- drs += explore_datasets(root / f)
- elif f == f"{root.name}.pickle":
- drs.append((root, build_dataset(root / f)))
- # drs.append((str(root),))
-
- return drs
-
-
-class PlotSelector(param.Parameterized):
- metric = param.Selector(objects=["acc", "f1"])
- view = param.Selector(objects=["train prevs", "avg"])
-
-
-def plot_selector_widget():
- return pn.Param(
- PlotSelector.param,
- widgets={
- "metric": pn.widgets.Select,
- "view": pn.widgets.Select,
- },
- )
-
-
-def serve(address="localhost"):
- # app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle"))
- __base_path = "output"
- __tabs = sorted(
- explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0])
- )
- __tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs]
- if len(__tabs) > 0:
- app = pn.Tabs(
- objects=__tabs,
- tabs_location="left",
- dynamic=False,
- )
- else:
- app = pn.Column(
- pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}),
- align="center",
- )
-
- __port = 33420
- __allowed = [address]
- if address == "localhost":
- __allowed.append("127.0.0.1")
-
- pn.serve(
- app,
- autoreload=True,
- port=__port,
- show=False,
- address=address,
- websocket_origin=[f"{_a}:{__port}" for _a in __allowed],
- )
-
-
-def run():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--address",
- action="store",
- dest="address",
- default="localhost",
- )
- args = parser.parse_args()
- serve(address=args.address)
diff --git a/qcpanel/run.py b/qcpanel/run.py
index fdaf2b0..ca6a477 100644
--- a/qcpanel/run.py
+++ b/qcpanel/run.py
@@ -1,10 +1,11 @@
import argparse
import panel as pn
+from panel.theme.fast import FastDarkTheme, FastDefaultTheme
from qcpanel.viewer import QuaccTestViewer
-# pn.config.design = pn.theme.Bootstrap
+# pn.config.design = Fast
# pn.config.theme = "dark"
pn.config.notifications = True
@@ -59,8 +60,8 @@ def app_instance():
],
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
modal=[qtv.modal_pane],
- theme=pn.theme.DarkTheme,
- theme_toggle=False,
+ # theme=FastDefaultTheme,
+ theme_toggle=True,
)
app.servable()
diff --git a/qcpanel/util.py b/qcpanel/util.py
index 6d1473d..3286c2f 100644
--- a/qcpanel/util.py
+++ b/qcpanel/util.py
@@ -52,7 +52,7 @@ def create_plots(
metric=metric,
estimators=estimators,
conf="panel",
- return_fig=True,
+ save_fig=False,
)
return (
pn.pane.Matplotlib(
@@ -91,7 +91,7 @@ def create_plots(
metric=metric,
estimators=estimators,
conf="panel",
- return_fig=True,
+ save_fig=False,
)
return (
pn.pane.Matplotlib(
diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py
index e61e709..ef57345 100644
--- a/quacc/evaluation/report.py
+++ b/quacc/evaluation/report.py
@@ -6,7 +6,7 @@ from typing import List, Tuple
import numpy as np
import pandas as pd
-from quacc import plot
+import quacc.plot as plot
from quacc.utils import fmt_line_md
@@ -215,16 +215,17 @@ class CompReport:
def get_plots(
self,
- mode="delta",
+ mode="delta_train",
metric="acc",
estimators=None,
conf="default",
- return_fig=False,
+ save_fig=True,
base_path=None,
+ backend=None,
) -> List[Tuple[str, Path]]:
if mode == "delta_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
- if avg_data.empty is True:
+ if avg_data.empty:
return None
return plot.plot_delta(
@@ -234,8 +235,9 @@ class CompReport:
metric=metric,
name=conf,
train_prev=self.train_prev,
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
elif mode == "stdev_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
@@ -251,8 +253,9 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
stdevs=st_data.T.to_numpy(),
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
elif mode == "diagonal":
f_data = self.data(metric=metric + "_score", estimators=estimators)
@@ -268,8 +271,9 @@ class CompReport:
metric=metric,
name=conf,
train_prev=self.train_prev,
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
elif mode == "shift":
_shift_data = self.shift_data(metric=metric, estimators=estimators)
@@ -290,8 +294,9 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
counts=shift_counts.T.to_numpy(),
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
def to_md(
@@ -323,11 +328,12 @@ class CompReport:
plot_modes = [m for m in modes if not m.endswith("table")]
for mode in plot_modes:
res += f"### {mode}\n"
- op = self.get_plots(
+ _, op = self.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf=conf,
+ save_fig=True,
base_path=plot_path,
)
res += f".as_posix()})\n"
@@ -424,12 +430,15 @@ class DatasetReport:
metric="acc",
estimators=None,
conf="default",
- return_fig=False,
+ save_fig=True,
base_path=None,
+ backend=None,
):
if mode == "delta_train":
_data = self.data(metric, estimators) if data is None else data
avg_on_train = _data.groupby(level=1).mean()
+ if avg_on_train.empty:
+ return None
prevs_on_train = np.sort(avg_on_train.index.unique(0))
return plot.plot_delta(
base_prevs=np.around(
@@ -441,12 +450,15 @@ class DatasetReport:
name=conf,
train_prev=None,
avg="train",
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
elif mode == "stdev_train":
_data = self.data(metric, estimators) if data is None else data
avg_on_train = _data.groupby(level=1).mean()
+ if avg_on_train.empty:
+ return None
prevs_on_train = np.sort(avg_on_train.index.unique(0))
stdev_on_train = _data.groupby(level=1).std()
return plot.plot_delta(
@@ -460,12 +472,15 @@ class DatasetReport:
train_prev=None,
stdevs=stdev_on_train.T.to_numpy(),
avg="train",
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
elif mode == "delta_test":
_data = self.data(metric, estimators) if data is None else data
avg_on_test = _data.groupby(level=0).mean()
+ if avg_on_test.empty:
+ return None
prevs_on_test = np.sort(avg_on_test.index.unique(0))
return plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
@@ -475,12 +490,15 @@ class DatasetReport:
name=conf,
train_prev=None,
avg="test",
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
elif mode == "stdev_test":
_data = self.data(metric, estimators) if data is None else data
avg_on_test = _data.groupby(level=0).mean()
+ if avg_on_test.empty:
+ return None
prevs_on_test = np.sort(avg_on_test.index.unique(0))
stdev_on_test = _data.groupby(level=0).std()
return plot.plot_delta(
@@ -492,12 +510,15 @@ class DatasetReport:
train_prev=None,
stdevs=stdev_on_test.T.to_numpy(),
avg="test",
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
elif mode == "shift":
_shift_data = self.shift_data(metric, estimators) if data is None else data
avg_shift = _shift_data.groupby(level=0).mean()
+ if avg_shift.empty:
+ return None
count_shift = _shift_data.groupby(level=0).count()
prevs_shift = np.sort(avg_shift.index.unique(0))
return plot.plot_shift(
@@ -508,8 +529,9 @@ class DatasetReport:
name=conf,
train_prev=None,
counts=count_shift.T.to_numpy(),
- return_fig=return_fig,
+ save_fig=save_fig,
base_path=base_path,
+ backend=backend,
)
def to_md(
@@ -545,24 +567,26 @@ class DatasetReport:
res += avg_on_train_tbl.to_html() + "\n\n"
if "delta_train" in dr_modes:
- delta_op = self.get_plots(
+ _, delta_op = self.get_plots(
data=_data,
mode="delta_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
+ save_fig=True,
)
res += f".as_posix()})\n"
if "stdev_train" in dr_modes:
- delta_stdev_op = self.get_plots(
+ _, delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
+ save_fig=True,
)
res += f".as_posix()})\n"
@@ -575,24 +599,26 @@ class DatasetReport:
res += avg_on_test_tbl.to_html() + "\n\n"
if "delta_test" in dr_modes:
- delta_op = self.get_plots(
+ _, delta_op = self.get_plots(
data=_data,
mode="delta_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
+ save_fig=True,
)
res += f".as_posix()})\n"
if "stdev_test" in dr_modes:
- delta_stdev_op = self.get_plots(
+ _, delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
+ save_fig=True,
)
res += f".as_posix()})\n"
@@ -605,13 +631,14 @@ class DatasetReport:
res += shift_on_train_tbl.to_html() + "\n\n"
if "shift" in dr_modes:
- shift_op = self.get_plots(
+ _, shift_op = self.get_plots(
data=_shift_data,
mode="shift",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
+ save_fig=True,
)
res += f".as_posix()})\n"
diff --git a/quacc/plot.py b/quacc/plot.py
deleted file mode 100644
index e0bbefa..0000000
--- a/quacc/plot.py
+++ /dev/null
@@ -1,265 +0,0 @@
-from pathlib import Path
-
-import matplotlib
-import matplotlib.pyplot as plt
-import numpy as np
-from cycler import cycler
-
-from quacc import utils
-
-matplotlib.use("agg")
-
-
-def _get_markers(n: int):
- ls = "ovx+sDph*^1234X><.Pd"
- if n > len(ls):
- ls = ls * (n / len(ls) + 1)
- return list(ls)[:n]
-
-
-def plot_delta(
- base_prevs,
- columns,
- data,
- *,
- stdevs=None,
- pos_class=1,
- metric="acc",
- name="default",
- train_prev=None,
- legend=True,
- avg=None,
- return_fig=False,
- base_path=None,
-) -> Path:
- _base_title = "delta_stdev" if stdevs is not None else "delta"
- if train_prev is not None:
- t_prev_pos = int(round(train_prev[pos_class] * 100))
- title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
- else:
- title = f"{_base_title}_{name}_avg_{avg}_{metric}"
-
- if base_path is None:
- base_path = utils.get_quacc_home() / "plots"
-
- fig, ax = plt.subplots()
- ax.set_aspect("auto")
- ax.grid()
-
- NUM_COLORS = len(data)
- cm = plt.get_cmap("tab10")
- if NUM_COLORS > 10:
- cm = plt.get_cmap("tab20")
- cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
-
- base_prevs = base_prevs[:, pos_class]
- for method, deltas, _cy in zip(columns, data, cy):
- ax.plot(
- base_prevs,
- deltas,
- label=method,
- color=_cy["color"],
- linestyle="-",
- marker="o",
- markersize=3,
- zorder=2,
- )
- if stdevs is not None:
- _col_idx = np.where(columns == method)[0]
- stdev = stdevs[_col_idx].flatten()
- nn_idx = np.intersect1d(
- np.where(deltas != np.nan)[0],
- np.where(stdev != np.nan)[0],
- )
- _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
- ax.fill_between(
- _bps,
- _ds - _st,
- _ds + _st,
- color=_cy["color"],
- alpha=0.25,
- )
-
- x_label = "test" if avg is None or avg == "train" else "train"
- ax.set(
- xlabel=f"{x_label} prevalence",
- ylabel=metric,
- title=title,
- )
-
- if legend:
- ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
-
- if return_fig:
- return fig
-
- output_path = base_path / f"{title}.png"
- fig.savefig(output_path, bbox_inches="tight")
- return output_path
-
-
-def plot_diagonal(
- reference,
- columns,
- data,
- *,
- pos_class=1,
- metric="acc",
- name="default",
- train_prev=None,
- legend=True,
- return_fig=False,
- base_path=None,
-):
- if train_prev is not None:
- t_prev_pos = int(round(train_prev[pos_class] * 100))
- title = f"diagonal_{name}_{t_prev_pos}_{metric}"
- else:
- title = f"diagonal_{name}_{metric}"
-
- if base_path is None:
- base_path = utils.get_quacc_home() / "plots"
-
- fig, ax = plt.subplots()
- ax.set_aspect("auto")
- ax.grid()
- ax.set_aspect("equal")
-
- NUM_COLORS = len(data)
- cm = plt.get_cmap("tab10")
- if NUM_COLORS > 10:
- cm = plt.get_cmap("tab20")
- cy = cycler(
- color=[cm(i) for i in range(NUM_COLORS)],
- marker=_get_markers(NUM_COLORS),
- )
-
- reference = np.array(reference)
- x_ticks = np.unique(reference)
- x_ticks.sort()
-
- for deltas, _cy in zip(data, cy):
- ax.plot(
- reference,
- deltas,
- color=_cy["color"],
- linestyle="None",
- marker=_cy["marker"],
- markersize=3,
- zorder=2,
- alpha=0.25,
- )
-
- # ensure limits are equal for both axes
- _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
- _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
- ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
-
- for method, deltas, _cy in zip(columns, data, cy):
- slope, interc = np.polyfit(reference, deltas, 1)
- y_lr = np.array([slope * x + interc for x in _lims])
- ax.plot(
- _lims,
- y_lr,
- label=method,
- color=_cy["color"],
- linestyle="-",
- markersize="0",
- zorder=1,
- )
-
- # plot reference line
- ax.plot(
- _lims,
- _lims,
- color="black",
- linestyle="--",
- markersize=0,
- zorder=1,
- )
-
- ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
-
- if legend:
- ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
-
- if return_fig:
- return fig
-
- output_path = base_path / f"{title}.png"
- fig.savefig(output_path, bbox_inches="tight")
- return output_path
-
-
-def plot_shift(
- shift_prevs,
- columns,
- data,
- *,
- counts=None,
- pos_class=1,
- metric="acc",
- name="default",
- train_prev=None,
- legend=True,
- return_fig=False,
- base_path=None,
-) -> Path:
- if train_prev is not None:
- t_prev_pos = int(round(train_prev[pos_class] * 100))
- title = f"shift_{name}_{t_prev_pos}_{metric}"
- else:
- title = f"shift_{name}_avg_{metric}"
-
- if base_path is None:
- base_path = utils.get_quacc_home() / "plots"
-
- fig, ax = plt.subplots()
- ax.set_aspect("auto")
- ax.grid()
-
- NUM_COLORS = len(data)
- cm = plt.get_cmap("tab10")
- if NUM_COLORS > 10:
- cm = plt.get_cmap("tab20")
- cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
-
- shift_prevs = shift_prevs[:, pos_class]
- for method, shifts, _cy in zip(columns, data, cy):
- ax.plot(
- shift_prevs,
- shifts,
- label=method,
- color=_cy["color"],
- linestyle="-",
- marker="o",
- markersize=3,
- zorder=2,
- )
- if counts is not None:
- _col_idx = np.where(columns == method)[0]
- count = counts[_col_idx].flatten()
- for prev, shift, cnt in zip(shift_prevs, shifts, count):
- label = f"{cnt}"
- plt.annotate(
- label,
- (prev, shift),
- textcoords="offset points",
- xytext=(0, 10),
- ha="center",
- color=_cy["color"],
- fontsize=12.0,
- )
-
- ax.set(xlabel="dataset shift", ylabel=metric, title=title)
-
- if legend:
- ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
-
- if return_fig:
- return fig
-
- output_path = base_path / f"{title}.png"
- fig.savefig(output_path, bbox_inches="tight")
-
- return output_path
diff --git a/quacc/plot/__init__.py b/quacc/plot/__init__.py
new file mode 100644
index 0000000..6a182c5
--- /dev/null
+++ b/quacc/plot/__init__.py
@@ -0,0 +1 @@
+from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift
diff --git a/quacc/plot/base.py b/quacc/plot/base.py
new file mode 100644
index 0000000..36a58b2
--- /dev/null
+++ b/quacc/plot/base.py
@@ -0,0 +1,54 @@
+from pathlib import Path
+
+
+class BasePlot:
+ @classmethod
+ def save_fig(cls, fig, base_path, title) -> Path:
+ ...
+
+ @classmethod
+ def plot_diagonal(
+ cls,
+ reference,
+ columns,
+ data,
+ *,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ):
+ ...
+
+ @classmethod
+ def plot_delta(
+ cls,
+ base_prevs,
+ columns,
+ data,
+ *,
+ stdevs=None,
+ pos_class=1,
+ title="default",
+ x_label="prevs.",
+ y_label="error",
+ legend=True,
+ ):
+ ...
+
+ @classmethod
+ def plot_shift(
+ cls,
+ shift_prevs,
+ columns,
+ data,
+ *,
+ counts=None,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ):
+ ...
diff --git a/quacc/plot/mpl.py b/quacc/plot/mpl.py
new file mode 100644
index 0000000..dd84b7a
--- /dev/null
+++ b/quacc/plot/mpl.py
@@ -0,0 +1,222 @@
+from pathlib import Path
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+from cycler import cycler
+
+from quacc import utils
+from quacc.plot.base import BasePlot
+
+matplotlib.use("agg")
+
+
+class MplPlot(BasePlot):
+ def _get_markers(self, n: int):
+ ls = "ovx+sDph*^1234X><.Pd"
+ if n > len(ls):
+ ls = ls * (n / len(ls) + 1)
+ return list(ls)[:n]
+
+ def save_fig(self, fig, base_path, title) -> Path:
+ if base_path is None:
+ base_path = utils.get_quacc_home() / "plots"
+ output_path = base_path / f"{title}.png"
+ fig.savefig(output_path, bbox_inches="tight")
+ return output_path
+
+ def plot_delta(
+ self,
+ base_prevs,
+ columns,
+ data,
+ *,
+ stdevs=None,
+ pos_class=1,
+ title="default",
+ x_label="prevs.",
+ y_label="error",
+ legend=True,
+ ):
+ fig, ax = plt.subplots()
+ ax.set_aspect("auto")
+ ax.grid()
+
+ NUM_COLORS = len(data)
+ cm = plt.get_cmap("tab10")
+ if NUM_COLORS > 10:
+ cm = plt.get_cmap("tab20")
+ cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
+
+ base_prevs = base_prevs[:, pos_class]
+ for method, deltas, _cy in zip(columns, data, cy):
+ ax.plot(
+ base_prevs,
+ deltas,
+ label=method,
+ color=_cy["color"],
+ linestyle="-",
+ marker="o",
+ markersize=3,
+ zorder=2,
+ )
+ if stdevs is not None:
+ _col_idx = np.where(columns == method)[0]
+ stdev = stdevs[_col_idx].flatten()
+ nn_idx = np.intersect1d(
+ np.where(deltas != np.nan)[0],
+ np.where(stdev != np.nan)[0],
+ )
+ _bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
+ ax.fill_between(
+ _bps,
+ _ds - _st,
+ _ds + _st,
+ color=_cy["color"],
+ alpha=0.25,
+ )
+
+ ax.set(
+ xlabel=f"{x_label} prevalence",
+ ylabel=y_label,
+ title=title,
+ )
+
+ if legend:
+ ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+ return fig
+
+ def plot_diagonal(
+ self,
+ reference,
+ columns,
+ data,
+ *,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ):
+ fig, ax = plt.subplots()
+ ax.set_aspect("auto")
+ ax.grid()
+ ax.set_aspect("equal")
+
+ NUM_COLORS = len(data)
+ cm = plt.get_cmap("tab10")
+ if NUM_COLORS > 10:
+ cm = plt.get_cmap("tab20")
+ cy = cycler(
+ color=[cm(i) for i in range(NUM_COLORS)],
+ marker=self._get_markers(NUM_COLORS),
+ )
+
+ reference = np.array(reference)
+ x_ticks = np.unique(reference)
+ x_ticks.sort()
+
+ for deltas, _cy in zip(data, cy):
+ ax.plot(
+ reference,
+ deltas,
+ color=_cy["color"],
+ linestyle="None",
+ marker=_cy["marker"],
+ markersize=3,
+ zorder=2,
+ alpha=0.25,
+ )
+
+ # ensure limits are equal for both axes
+ _alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
+ _lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
+ ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
+
+ for method, deltas, _cy in zip(columns, data, cy):
+ slope, interc = np.polyfit(reference, deltas, 1)
+ y_lr = np.array([slope * x + interc for x in _lims])
+ ax.plot(
+ _lims,
+ y_lr,
+ label=method,
+ color=_cy["color"],
+ linestyle="-",
+ markersize="0",
+ zorder=1,
+ )
+
+ # plot reference line
+ ax.plot(
+ _lims,
+ _lims,
+ color="black",
+ linestyle="--",
+ markersize=0,
+ zorder=1,
+ )
+
+ ax.set(xlabel=x_label, ylabel=y_label, title=title)
+
+ if legend:
+ ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+ return fig
+
+ def plot_shift(
+ self,
+ shift_prevs,
+ columns,
+ data,
+ *,
+ counts=None,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ):
+ fig, ax = plt.subplots()
+ ax.set_aspect("auto")
+ ax.grid()
+
+ NUM_COLORS = len(data)
+ cm = plt.get_cmap("tab10")
+ if NUM_COLORS > 10:
+ cm = plt.get_cmap("tab20")
+ cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
+
+ shift_prevs = shift_prevs[:, pos_class]
+ for method, shifts, _cy in zip(columns, data, cy):
+ ax.plot(
+ shift_prevs,
+ shifts,
+ label=method,
+ color=_cy["color"],
+ linestyle="-",
+ marker="o",
+ markersize=3,
+ zorder=2,
+ )
+ if counts is not None:
+ _col_idx = np.where(columns == method)[0]
+ count = counts[_col_idx].flatten()
+ for prev, shift, cnt in zip(shift_prevs, shifts, count):
+ label = f"{cnt}"
+ plt.annotate(
+ label,
+ (prev, shift),
+ textcoords="offset points",
+ xytext=(0, 10),
+ ha="center",
+ color=_cy["color"],
+ fontsize=12.0,
+ )
+
+ ax.set(xlabel=x_label, ylabel=y_label, title=title)
+
+ if legend:
+ ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+
+ return fig
diff --git a/quacc/plot/plot.py b/quacc/plot/plot.py
new file mode 100644
index 0000000..1bd2369
--- /dev/null
+++ b/quacc/plot/plot.py
@@ -0,0 +1,144 @@
+from quacc.plot.base import BasePlot
+from quacc.plot.mpl import MplPlot
+from quacc.plot.plotly import PlotlyPlot
+
+__backend: BasePlot = MplPlot()
+
+
+def get_backend(be, theme=None):
+ match be:
+ case "matplotlib" | "mpl":
+ return MplPlot()
+ case "plotly":
+ return PlotlyPlot(theme=theme)
+ case _:
+ return MplPlot()
+
+
+def plot_delta(
+ base_prevs,
+ columns,
+ data,
+ *,
+ stdevs=None,
+ pos_class=1,
+ metric="acc",
+ name="default",
+ train_prev=None,
+ legend=True,
+ avg=None,
+ save_fig=False,
+ base_path=None,
+ backend=None,
+):
+ backend = __backend if backend is None else backend
+ _base_title = "delta_stdev" if stdevs is not None else "delta"
+ if train_prev is not None:
+ t_prev_pos = int(round(train_prev[pos_class] * 100))
+ title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
+ else:
+ title = f"{_base_title}_{name}_avg_{avg}_{metric}"
+
+ x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
+ y_label = f"{metric} error"
+ fig = backend.plot_delta(
+ base_prevs,
+ columns,
+ data,
+ stdevs=stdevs,
+ pos_class=pos_class,
+ title=title,
+ x_label=x_label,
+ y_label=y_label,
+ legend=legend,
+ )
+
+ if save_fig:
+ output_path = backend.save_fig(fig, base_path, title)
+ return fig, output_path
+
+ return fig
+
+
+def plot_diagonal(
+ reference,
+ columns,
+ data,
+ *,
+ pos_class=1,
+ metric="acc",
+ name="default",
+ train_prev=None,
+ legend=True,
+ save_fig=False,
+ base_path=None,
+ backend=None,
+):
+ backend = __backend if backend is None else backend
+ if train_prev is not None:
+ t_prev_pos = int(round(train_prev[pos_class] * 100))
+ title = f"diagonal_{name}_{t_prev_pos}_{metric}"
+ else:
+ title = f"diagonal_{name}_{metric}"
+
+ x_label = f"true {metric}"
+ y_label = f"estim. {metric}"
+ fig = backend.plot_diagonal(
+ reference,
+ columns,
+ data,
+ pos_class=pos_class,
+ title=title,
+ x_label=x_label,
+ y_label=y_label,
+ legend=legend,
+ )
+
+ if save_fig:
+ output_path = backend.save_fig(fig, base_path, title)
+ return fig, output_path
+
+ return fig
+
+
+def plot_shift(
+ shift_prevs,
+ columns,
+ data,
+ *,
+ counts=None,
+ pos_class=1,
+ metric="acc",
+ name="default",
+ train_prev=None,
+ legend=True,
+ save_fig=False,
+ base_path=None,
+ backend=None,
+):
+ backend = __backend if backend is None else backend
+ if train_prev is not None:
+ t_prev_pos = int(round(train_prev[pos_class] * 100))
+ title = f"shift_{name}_{t_prev_pos}_{metric}"
+ else:
+ title = f"shift_{name}_avg_{metric}"
+
+ x_label = "dataset shift"
+ y_label = f"{metric} error"
+ fig = backend.plot_shift(
+ shift_prevs,
+ columns,
+ data,
+ counts=counts,
+ pos_class=pos_class,
+ title=title,
+ x_label=x_label,
+ y_label=y_label,
+ legend=legend,
+ )
+
+ if save_fig:
+ output_path = backend.save_fig(fig, base_path, title)
+ return fig, output_path
+
+ return fig
diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py
new file mode 100644
index 0000000..074c277
--- /dev/null
+++ b/quacc/plot/plotly.py
@@ -0,0 +1,201 @@
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import plotly
+import plotly.graph_objects as go
+
+from quacc.plot.base import BasePlot
+
+
+class PlotlyPlot(BasePlot):
+ __themes = defaultdict(
+ lambda: {
+ "template": "seaborn",
+ }
+ )
+ __themes = __themes | {
+ "dark": {
+ "template": "plotly_dark",
+ },
+ }
+
+ def __init__(self, theme=None):
+ self.theme = PlotlyPlot.__themes[theme]
+
+ def hex_to_rgb(self, hex: str, t: float | None = None):
+ hex = hex.lstrip("#")
+ rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
+ if t is not None:
+ rgb.append(t)
+ return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
+
+ def get_colors(self, num):
+ match num:
+ case v if v > 10:
+ __colors = plotly.colors.qualitative.Light24
+ case _:
+ __colors = plotly.colors.qualitative.Plotly
+
+ def __generator(cs):
+ while True:
+ for c in cs:
+ yield c
+
+ return __generator(__colors)
+
+ def update_layout(self, fig, title, x_label, y_label):
+ fig.update_layout(
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ template=self.theme["template"],
+ )
+
+ def save_fig(self, fig, base_path, title) -> Path:
+ return None
+
+ def plot_delta(
+ self,
+ base_prevs,
+ columns,
+ data,
+ *,
+ stdevs=None,
+ pos_class=1,
+ title="default",
+ x_label="prevs.",
+ y_label="error",
+ legend=True,
+ ) -> go.Figure:
+ fig = go.Figure()
+ x = base_prevs[:, pos_class]
+ line_colors = self.get_colors(len(columns))
+ for name, delta in zip(columns, data):
+ color = next(line_colors)
+ _line = [
+ go.Scatter(
+ x=x,
+ y=delta,
+ mode="lines+markers",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color)),
+ hovertemplate="prev.: %{x}
error: %{y:,.4f}",
+ )
+ ]
+ _error = []
+ if stdevs is not None:
+ _col_idx = np.where(columns == name)[0]
+ stdev = stdevs[_col_idx].flatten()
+ _error = [
+ go.Scatter(
+ x=np.concatenate([x, x[::-1]]),
+ y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
+ name=int(_col_idx[0]),
+ fill="toself",
+ fillcolor=self.hex_to_rgb(color, t=0.2),
+ line=dict(color="rgba(255, 255, 255, 0)"),
+ hoverinfo="skip",
+ showlegend=False,
+ )
+ ]
+ fig.add_traces(_line + _error)
+
+ self.update_layout(fig, title, x_label, y_label)
+ return fig
+
+ def plot_diagonal(
+ self,
+ reference,
+ columns,
+ data,
+ *,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ) -> go.Figure:
+ fig = go.Figure()
+ x = reference
+ line_colors = self.get_colors(len(columns))
+
+ _edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
+ _lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
+
+ for name, val in zip(columns, data):
+ color = next(line_colors)
+ slope, interc = np.polyfit(x, val, 1)
+ y_lr = np.array([slope * _x + interc for _x in _lims[0]])
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=x,
+ y=val,
+ customdata=np.stack((val - x,), axis=-1),
+ mode="markers",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color, t=0.5)),
+ hovertemplate="true acc: %{x:,.4f}
estim. acc: %{y:,.4f}
acc err.: %{customdata[0]:,.4f}",
+ ),
+ go.Scatter(
+ x=_lims[0],
+ y=y_lr,
+ mode="lines",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color), width=3),
+ showlegend=False,
+ ),
+ ]
+ )
+ fig.add_trace(
+ go.Scatter(
+ x=_lims[0],
+ y=_lims[1],
+ mode="lines",
+ name="reference",
+ showlegend=False,
+ line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
+ )
+ )
+
+ self.update_layout(fig, title, x_label, y_label)
+ fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
+ return fig
+
+ def plot_shift(
+ self,
+ shift_prevs,
+ columns,
+ data,
+ *,
+ counts=None,
+ pos_class=1,
+ title="default",
+ x_label="true",
+ y_label="estim.",
+ legend=True,
+ ) -> go.Figure:
+ fig = go.Figure()
+ x = shift_prevs[:, pos_class]
+ line_colors = self.get_colors(len(columns))
+ for name, delta in zip(columns, data):
+ col_idx = (columns == name).nonzero()[0][0]
+ color = next(line_colors)
+ fig.add_trace(
+ go.Scatter(
+ x=x,
+ y=delta,
+ customdata=np.stack((counts[col_idx],), axis=-1),
+ mode="lines+markers",
+ name=name,
+ line=dict(color=self.hex_to_rgb(color)),
+ hovertemplate="shift: %{x}
error: %{y}"
+ + "
count: %{customdata[0]}"
+ if counts is not None
+ else "",
+ )
+ )
+
+ self.update_layout(fig, title, x_label, y_label)
+ return fig
diff --git a/remote.py b/remote.py
index 9b8af19..1f3bf32 100644
--- a/remote.py
+++ b/remote.py
@@ -24,6 +24,7 @@ __to_sync_up = {
"quacc",
"baselines",
"qcpanel",
+ "qcdash",
],
"file": [
"conf.yaml",