Binary quantifier added, support added and tested.

This commit is contained in:
Lorenzo Volpi 2023-07-26 00:38:23 +02:00
parent b969234244
commit 1347ac3c9d
12 changed files with 371 additions and 112 deletions

0
.editorconfig Normal file
View File

1
.gitignore vendored
View File

@ -2,3 +2,4 @@
quavenv/* quavenv/*
*.pdf *.pdf
quacc/__pycache__/* quacc/__pycache__/*
tests/__pycache__/*

16
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "main",
"type": "python",
"request": "launch",
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main.py",
"console": "integratedTerminal",
"justMyCode": true
}
]
}

129
poetry.lock generated
View File

@ -1,10 +1,9 @@
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]] [[package]]
name = "abstention" name = "abstention"
version = "0.1.3.1" version = "0.1.3.1"
description = "Functions for abstention, calibration and label shift domain adaptation" description = "Functions for abstention, calibration and label shift domain adaptation"
category = "main"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
@ -20,7 +19,6 @@ scipy = ">=1.1.0"
name = "colorama" name = "colorama"
version = "0.4.6" version = "0.4.6"
description = "Cross-platform colored terminal text." description = "Cross-platform colored terminal text."
category = "main"
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [ files = [
@ -32,7 +30,6 @@ files = [
name = "contourpy" name = "contourpy"
version = "1.0.7" version = "1.0.7"
description = "Python library for calculating contours of 2D quadrilateral grids" description = "Python library for calculating contours of 2D quadrilateral grids"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -107,7 +104,6 @@ test-no-images = ["pytest"]
name = "cycler" name = "cycler"
version = "0.11.0" version = "0.11.0"
description = "Composable style cycles" description = "Composable style cycles"
category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
@ -119,7 +115,6 @@ files = [
name = "fonttools" name = "fonttools"
version = "4.39.4" version = "4.39.4"
description = "Tools to manipulate font files" description = "Tools to manipulate font files"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -141,11 +136,21 @@ ufo = ["fs (>=2.2.0,<3)"]
unicode = ["unicodedata2 (>=15.0.0)"] unicode = ["unicodedata2 (>=15.0.0)"]
woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]] [[package]]
name = "joblib" name = "joblib"
version = "1.2.0" version = "1.2.0"
description = "Lightweight pipelining with Python functions" description = "Lightweight pipelining with Python functions"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -157,7 +162,6 @@ files = [
name = "kiwisolver" name = "kiwisolver"
version = "1.4.4" version = "1.4.4"
description = "A fast implementation of the Cassowary constraint solver" description = "A fast implementation of the Cassowary constraint solver"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -235,7 +239,6 @@ files = [
name = "matplotlib" name = "matplotlib"
version = "3.7.1" version = "3.7.1"
description = "Python plotting package" description = "Python plotting package"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -297,7 +300,6 @@ python-dateutil = ">=2.7"
name = "numpy" name = "numpy"
version = "1.24.3" version = "1.24.3"
description = "Fundamental package for array computing in Python" description = "Fundamental package for array computing in Python"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -335,7 +337,6 @@ files = [
name = "packaging" name = "packaging"
version = "23.1" version = "23.1"
description = "Core utilities for Python packages" description = "Core utilities for Python packages"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -345,37 +346,36 @@ files = [
[[package]] [[package]]
name = "pandas" name = "pandas"
version = "2.0.1" version = "2.0.3"
description = "Powerful data structures for data analysis, time series, and statistics" description = "Powerful data structures for data analysis, time series, and statistics"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pandas-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70a996a1d2432dadedbb638fe7d921c88b0cc4dd90374eab51bb33dc6c0c2a12"}, {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"},
{file = "pandas-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:909a72b52175590debbf1d0c9e3e6bce2f1833c80c76d80bd1aa09188be768e5"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"},
{file = "pandas-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe7914d8ddb2d54b900cec264c090b88d141a1eed605c9539a187dbc2547f022"}, {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"},
{file = "pandas-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a514ae436b23a92366fbad8365807fc0eed15ca219690b3445dcfa33597a5cc"}, {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"},
{file = "pandas-2.0.1-cp310-cp310-win32.whl", hash = "sha256:12bd6618e3cc737c5200ecabbbb5eaba8ab645a4b0db508ceeb4004bb10b060e"}, {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"},
{file = "pandas-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:2b6fe5f7ce1cba0e74188c8473c9091ead9b293ef0a6794939f8cc7947057abd"}, {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"},
{file = "pandas-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:00959a04a1d7bbc63d75a768540fb20ecc9e65fd80744c930e23768345a362a7"}, {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"},
{file = "pandas-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af2449e9e984dfad39276b885271ba31c5e0204ffd9f21f287a245980b0e4091"}, {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"},
{file = "pandas-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910df06feaf9935d05247db6de452f6d59820e432c18a2919a92ffcd98f8f79b"}, {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"},
{file = "pandas-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa0067f2419f933101bdc6001bcea1d50812afbd367b30943417d67fbb99678"}, {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"},
{file = "pandas-2.0.1-cp311-cp311-win32.whl", hash = "sha256:7b8395d335b08bc8b050590da264f94a439b4770ff16bb51798527f1dd840388"}, {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"},
{file = "pandas-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:8db5a644d184a38e6ed40feeb12d410d7fcc36648443defe4707022da127fc35"}, {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"},
{file = "pandas-2.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7bbf173d364130334e0159a9a034f573e8b44a05320995127cf676b85fd8ce86"}, {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"},
{file = "pandas-2.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6c0853d487b6c868bf107a4b270a823746175b1932093b537b9b76c639fc6f7e"}, {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"},
{file = "pandas-2.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25e23a03f7ad7211ffa30cb181c3e5f6d96a8e4cb22898af462a7333f8a74eb"}, {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"},
{file = "pandas-2.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e09a53a4fe8d6ae2149959a2d02e1ef2f4d2ceb285ac48f74b79798507e468b4"}, {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"},
{file = "pandas-2.0.1-cp38-cp38-win32.whl", hash = "sha256:a2564629b3a47b6aa303e024e3d84e850d36746f7e804347f64229f8c87416ea"}, {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"},
{file = "pandas-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:03e677c6bc9cfb7f93a8b617d44f6091613a5671ef2944818469be7b42114a00"}, {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"},
{file = "pandas-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d099ecaa5b9e977b55cd43cf842ec13b14afa1cfa51b7e1179d90b38c53ce6a"}, {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"},
{file = "pandas-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a37ee35a3eb6ce523b2c064af6286c45ea1c7ff882d46e10d0945dbda7572753"}, {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"},
{file = "pandas-2.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:320b180d125c3842c5da5889183b9a43da4ebba375ab2ef938f57bf267a3c684"}, {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"},
{file = "pandas-2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18d22cb9043b6c6804529810f492ab09d638ddf625c5dea8529239607295cb59"}, {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"},
{file = "pandas-2.0.1-cp39-cp39-win32.whl", hash = "sha256:90d1d365d77d287063c5e339f49b27bd99ef06d10a8843cf00b1a49326d492c1"}, {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"},
{file = "pandas-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:99f7192d8b0e6daf8e0d0fd93baa40056684e4b4aaaef9ea78dff34168e1f2f0"}, {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"},
{file = "pandas-2.0.1.tar.gz", hash = "sha256:19b8e5270da32b41ebf12f0e7165efa7024492e9513fb46fb631c5022ae5709d"}, {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"},
] ]
[package.dependencies] [package.dependencies]
@ -388,7 +388,7 @@ pytz = ">=2020.1"
tzdata = ">=2022.1" tzdata = ">=2022.1"
[package.extras] [package.extras]
all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"]
aws = ["s3fs (>=2021.08.0)"] aws = ["s3fs (>=2021.08.0)"]
clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"]
compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"]
@ -407,14 +407,13 @@ plot = ["matplotlib (>=3.6.1)"]
postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"] postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"]
spss = ["pyreadstat (>=1.1.2)"] spss = ["pyreadstat (>=1.1.2)"]
sql-other = ["SQLAlchemy (>=1.4.16)"] sql-other = ["SQLAlchemy (>=1.4.16)"]
test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.6.3)"] xml = ["lxml (>=4.6.3)"]
[[package]] [[package]]
name = "pillow" name = "pillow"
version = "9.5.0" version = "9.5.0"
description = "Python Imaging Library (Fork)" description = "Python Imaging Library (Fork)"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -490,11 +489,25 @@ files = [
docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"]
tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
[[package]]
name = "pluggy"
version = "1.2.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"},
{file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]] [[package]]
name = "pyparsing" name = "pyparsing"
version = "3.0.9" version = "3.0.9"
description = "pyparsing module - Classes and methods to define and execute parsing grammars" description = "pyparsing module - Classes and methods to define and execute parsing grammars"
category = "main"
optional = false optional = false
python-versions = ">=3.6.8" python-versions = ">=3.6.8"
files = [ files = [
@ -505,11 +518,30 @@ files = [
[package.extras] [package.extras]
diagrams = ["jinja2", "railroad-diagrams"] diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pytest"
version = "7.4.0"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"},
{file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"},
]
[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.8.2" version = "2.8.2"
description = "Extensions to the standard Python datetime module" description = "Extensions to the standard Python datetime module"
category = "main"
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [ files = [
@ -524,7 +556,6 @@ six = ">=1.5"
name = "pytz" name = "pytz"
version = "2023.3" version = "2023.3"
description = "World timezone definitions, modern and historical" description = "World timezone definitions, modern and historical"
category = "main"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
@ -536,7 +567,6 @@ files = [
name = "quapy" name = "quapy"
version = "0.1.7" version = "0.1.7"
description = "QuaPy: a framework for Quantification in Python" description = "QuaPy: a framework for Quantification in Python"
category = "main"
optional = false optional = false
python-versions = ">=3.6, <4" python-versions = ">=3.6, <4"
files = [ files = [
@ -557,7 +587,6 @@ xlrd = "*"
name = "scikit-learn" name = "scikit-learn"
version = "1.2.2" version = "1.2.2"
description = "A set of python modules for machine learning and data mining" description = "A set of python modules for machine learning and data mining"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -600,7 +629,6 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy (
name = "scipy" name = "scipy"
version = "1.9.3" version = "1.9.3"
description = "Fundamental algorithms for scientific computing in Python" description = "Fundamental algorithms for scientific computing in Python"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -639,7 +667,6 @@ test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "sciki
name = "six" name = "six"
version = "1.16.0" version = "1.16.0"
description = "Python 2 and 3 compatibility utilities" description = "Python 2 and 3 compatibility utilities"
category = "main"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
files = [ files = [
@ -651,7 +678,6 @@ files = [
name = "threadpoolctl" name = "threadpoolctl"
version = "3.1.0" version = "3.1.0"
description = "threadpoolctl" description = "threadpoolctl"
category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
@ -663,7 +689,6 @@ files = [
name = "tqdm" name = "tqdm"
version = "4.65.0" version = "4.65.0"
description = "Fast, Extensible Progress Meter" description = "Fast, Extensible Progress Meter"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -684,7 +709,6 @@ telegram = ["requests"]
name = "tzdata" name = "tzdata"
version = "2023.3" version = "2023.3"
description = "Provider of IANA time zone data" description = "Provider of IANA time zone data"
category = "main"
optional = false optional = false
python-versions = ">=2" python-versions = ">=2"
files = [ files = [
@ -696,7 +720,6 @@ files = [
name = "xlrd" name = "xlrd"
version = "2.0.1" version = "2.0.1"
description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files" description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files"
category = "main"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [ files = [
@ -712,4 +735,4 @@ test = ["pytest", "pytest-cov"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "811aa60aea2ab4cf9b4bc4ad3e546e8ee1a81d78f15acd35f3f736b5f97512b4" content-hash = "834ffb619893a1fb006e1b5a3213cc772117c9000e719b95a4478f74fd5d0066"

View File

@ -8,11 +8,15 @@ readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.11" python = "^3.11"
quapy = "^0.1.7" quapy = "^0.1.7"
pandas = "^2.0.3"
[tool.poetry.scripts] [tool.poetry.scripts]
main = "quacc.main:main" main = "quacc.main:main"
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@ -1,11 +1,40 @@
from typing import List, Optional from typing import Any, List, Optional
import numpy as np import numpy as np
import math
import quapy as qp import quapy as qp
import scipy.sparse as sp import scipy.sparse as sp
from quapy.data import LabelledCollection from quapy.data import LabelledCollection
# Extended classes
#
# 0 ~ True 0
# 1 ~ False 1
# 2 ~ False 0
# 3 ~ True 1
#  _____________________
# |          |          |
# |  True 0  |  False 1 |
# |__________|__________|
# |          |          |
# |  False 0 |  True 1  |
# |__________|__________|
#
class ExClassManager:
    """Convert between (true class, predicted class) pairs and extended ids.

    An extended class id is ``true_class * n_classes + pred_class``.  With
    ``n_classes == 2`` this gives 0 = (true 0, pred 0), 1 = (true 0, pred 1),
    2 = (true 1, pred 0) and 3 = (true 1, pred 1).
    """

    @staticmethod
    def get_ex(n_classes: int, true_class: int, pred_class: int) -> int:
        """Return the extended class id of a (true, predicted) class pair."""
        return true_class * n_classes + pred_class

    @staticmethod
    def get_pred(n_classes: int, ex_class: int) -> int:
        """Return the predicted class encoded in ``ex_class``."""
        return ex_class % n_classes

    @staticmethod
    def get_true(n_classes: int, ex_class: int) -> int:
        """Return the true class encoded in ``ex_class``."""
        return ex_class // n_classes
class ExtendedCollection(LabelledCollection): class ExtendedCollection(LabelledCollection):
def __init__( def __init__(
self, self,
@ -15,6 +44,67 @@ class ExtendedCollection(LabelledCollection):
): ):
super().__init__(instances, labels, classes=classes) super().__init__(instances, labels, classes=classes)
def split_by_pred(self):
    """Split this collection into one sub-collection per predicted class.

    The number of base classes is the integer square root of
    ``self.n_classes``.  Instances are grouped with ``split_index_by_pred``
    and the labels of every group are decoded to true classes with
    ``ExClassManager.get_true``.

    Returns:
        A list holding one ``ExtendedCollection`` per base class, in class
        order.  A class that no instance is predicted as yields a collection
        with empty instances and labels.
    """
    # math.isqrt works on integers directly, avoiding the float round trip
    # of int(math.sqrt(...)).
    _ncl = math.isqrt(self.n_classes)
    _indexes = ExtendedCollection.split_index_by_pred(_ncl, self.instances)

    def _subset(ind):
        # Only index when there is something to select: an empty index
        # array built from an empty list has a float dtype and cannot be
        # used for indexing.
        has_items = len(ind) > 0
        sub_instances = (
            self.instances[ind] if has_items else np.asarray([], dtype=int)
        )
        sub_labels = self.labels[ind] if has_items else []
        true_labels = np.asarray(
            [ExClassManager.get_true(_ncl, lbl) for lbl in sub_labels],
            dtype=int,
        )
        return ExtendedCollection(
            sub_instances, true_labels, classes=range(0, _ncl)
        )

    return [_subset(ind) for ind in _indexes]
@classmethod
def split_index_by_pred(
    cls, n_classes: int, instances: np.ndarray
) -> List[np.ndarray]:
    """Group instance positions by their predicted class.

    The predicted class of an instance is the argmax over its last
    ``n_classes`` entries.

    Args:
        n_classes: number of trailing entries of each instance to inspect.
        instances: iterable of instance rows.

    Returns:
        A list of ``n_classes`` integer index arrays; the i-th array holds
        the positions of the instances whose predicted class is ``i``.
        Groups with no members are empty *integer* arrays, so every entry
        can be used directly as an index.
    """
    # dtype=int keeps the array integer-typed even when there are no rows.
    _pred_label = np.asarray(
        [np.argmax(inst[-n_classes:], axis=0) for inst in instances],
        dtype=int,
    )
    return [np.flatnonzero(_pred_label == i) for i in range(0, n_classes)]
@classmethod
def extend_instances(
    cls, instances: np.ndarray, pred_proba: np.ndarray
) -> np.ndarray:
    """Append ``pred_proba`` as extra columns to the right of ``instances``.

    Args:
        instances: a dense ``np.ndarray`` or a ``scipy.sparse.csr_matrix``.
        pred_proba: values to append, one row per instance.

    Returns:
        The column-wise concatenation ``[instances | pred_proba]``; a CSR
        matrix for CSR input, an ``np.ndarray`` for dense input.

    Raises:
        ValueError: if ``instances`` is neither of the supported types.
    """
    if isinstance(instances, sp.csr_matrix):
        _pred_proba = sp.csr_matrix(pred_proba)
        # Request CSR explicitly so the result supports row indexing
        # regardless of the default format chosen by the SciPy version.
        return sp.hstack([instances, _pred_proba], format="csr")
    if isinstance(instances, np.ndarray):
        return np.concatenate((instances, pred_proba), axis=1)
    raise ValueError("Unsupported matrix format")
@classmethod
def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any:
    """Build an ``ExtendedCollection`` from ``base`` and its predictions.

    Args:
        base: collection providing ``X``, ``y`` and ``n_classes``.
        pred_proba: one row of predicted values per instance of ``base``;
            the predicted class of a row is its argmax.

    Returns:
        An ``ExtendedCollection`` whose instances are ``[X | pred_proba]``
        and whose labels are the extended class ids computed by
        ``ExClassManager.get_ex``, over ``n_classes * n_classes`` classes.
    """
    n_classes = base.n_classes

    # n_X = [ X | predicted probs. ]
    n_x = cls.extend_instances(base.X, pred_proba)

    # n_y = (expected y, predicted y)
    pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
    n_y = np.asarray(
        [
            ExClassManager.get_ex(n_classes, true_class, pred_class)
            for (true_class, pred_class) in zip(base.y, pred)
        ]
    )

    return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
def get_dataset(name): def get_dataset(name):
datasets = { datasets = {
"spambase": lambda: qp.datasets.fetch_UCIDataset( "spambase": lambda: qp.datasets.fetch_UCIDataset(

View File

@ -1,11 +1,14 @@
from abc import abstractmethod
import math
import numpy as np import numpy as np
import scipy.sparse as sp
from quapy.data import LabelledCollection from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier from quapy.method.aggregative import SLD
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict from sklearn.model_selection import cross_val_predict
from .data import ExtendedCollection from quacc.data import ExtendedCollection as EC
def _check_prevalence_classes(true_classes, estim_classes, estim_prev): def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
@ -15,60 +18,36 @@ def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
return estim_prev return estim_prev
def _get_ex_class(classes, true_class, pred_class):
return true_class * classes + pred_class
def _extend_instances(instances, pred_proba):
if isinstance(instances, sp.csr_matrix):
_pred_proba = sp.csr_matrix(pred_proba)
n_x = sp.hstack([instances, _pred_proba])
elif isinstance(instances, np.ndarray):
n_x = np.concatenate((instances, pred_proba), axis=1)
else:
raise ValueError("Unsupported matrix format")
return n_x
def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
n_classes = base.n_classes
# n_X = [ X | predicted probs. ]
n_x = _extend_instances(base.X, pred_proba)
# n_y = (expected y, predicted y)
pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
n_y = np.asarray(
[
_get_ex_class(n_classes, true_class, pred_class)
for (true_class, pred_class) in zip(base.y, pred)
]
)
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
class AccuracyEstimator: class AccuracyEstimator:
def __init__(self, model: BaseEstimator, q_model: BaseQuantifier): def extend(self, base: LabelledCollection, pred_proba=None) -> EC:
self.model = model
self.q_model = q_model
self.e_train = None
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
if not pred_proba: if not pred_proba:
pred_proba = self.model.predict_proba(base.X) pred_proba = self.model.predict_proba(base.X)
return _extend_collection(base, pred_proba) return EC.extend_collection(base, pred_proba)
def fit(self, train: LabelledCollection | ExtendedCollection): @abstractmethod
def fit(self, train: LabelledCollection | EC):
...
@abstractmethod
def estimate(self, instances, ext=False):
...
class MulticlassAccuracyEstimator(AccuracyEstimator):
def __init__(self, c_model: BaseEstimator):
self.c_model = c_model
self.q_model = SLD(LogisticRegression())
self.e_train = None
def fit(self, train: LabelledCollection | EC):
# check if model is fit # check if model is fit
# self.model.fit(*train.Xy) # self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection): if isinstance(train, LabelledCollection):
pred_prob_train = cross_val_predict( pred_prob_train = cross_val_predict(
self.model, *train.Xy, method="predict_proba" self.c_model, *train.Xy, method="predict_proba"
) )
self.e_train = _extend_collection(train, pred_prob_train) self.e_train = EC.extend_collection(train, pred_prob_train)
else: else:
self.e_train = train self.e_train = train
@ -76,8 +55,8 @@ class AccuracyEstimator:
def estimate(self, instances, ext=False): def estimate(self, instances, ext=False):
if not ext: if not ext:
pred_prob = self.model.predict_proba(instances) pred_prob = self.c_model.predict_proba(instances)
e_inst = _extend_instances(instances, pred_prob) e_inst = EC.extend_instances(instances, pred_prob)
else: else:
e_inst = instances e_inst = instances
@ -86,3 +65,51 @@ class AccuracyEstimator:
return _check_prevalence_classes( return _check_prevalence_classes(
self.e_train.classes_, self.q_model.classes_, estim_prev self.e_train.classes_, self.q_model.classes_, estim_prev
) )
class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
    """Accuracy estimator that uses one quantifier per predicted class.

    The extended training collection is split by predicted class and a
    separate ``SLD(LogisticRegression())`` quantifier is fitted on each part.
    """

    def __init__(self, c_model: BaseEstimator):
        """Store the classifier and create the two per-class quantifiers."""
        self.c_model = c_model
        self.q_model_0 = SLD(LogisticRegression())
        self.q_model_1 = SLD(LogisticRegression())
        self.e_train: EC = None

    def fit(self, train: LabelledCollection | EC):
        """Fit both quantifiers on the extended training collection.

        Args:
            train: an already extended collection, used as is, or a plain
                ``LabelledCollection``, which is extended with the
                cross-validated ``predict_proba`` output of ``c_model``.
        """
        # NOTE: c_model is not fitted here, while estimate() calls
        # c_model.predict_proba directly.
        # EC subclasses LabelledCollection, so it has to be tested first;
        # otherwise an extended collection would be extended a second time.
        if isinstance(train, EC):
            self.e_train = train
        else:
            pred_prob_train = cross_val_predict(
                self.c_model, *train.Xy, method="predict_proba"
            )
            self.e_train = EC.extend_collection(train, pred_prob_train)

        # The two splits are local values, not attributes of the estimator.
        e_train_0, e_train_1 = self.e_train.split_by_pred()
        self.q_model_0.fit(e_train_0)
        self.q_model_1.fit(e_train_1)

    def estimate(self, instances, ext=False):
        """Estimate the extended-class prevalences of ``instances``.

        Args:
            instances: instances to quantify.
            ext: when true, ``instances`` already carry the predicted
                probability columns and are not extended again.

        Returns:
            A flat list interleaving the outputs of the two quantifiers:
            ``[q0[0], q1[0], q0[1], q1[1], ...]``.
        """
        # TODO: test
        if not ext:
            pred_prob = self.c_model.predict_proba(instances)
            e_inst = EC.extend_instances(instances, pred_prob)
        else:
            e_inst = instances

        _ncl = math.isqrt(self.e_train.n_classes)
        e_inst_0, e_inst_1 = [
            e_inst[ind] for ind in EC.split_index_by_pred(_ncl, e_inst)
        ]
        estim_prev_0 = self.q_model_0.quantify(e_inst_0)
        estim_prev_1 = self.q_model_1.quantify(e_inst_1)

        # NOTE(review): the two estimates are interleaved as returned, without
        # weighting by the share of instances in each group - confirm that
        # this is the intended result.
        return [
            prev
            for prev_row in zip(estim_prev_0, estim_prev_1)
            for prev in prev_row
        ]

View File

@ -1,13 +1,12 @@
import pandas as pd import pandas as pd
import quapy as qp import quapy as qp
from quapy.method.aggregative import SLD
from quapy.protocol import APP from quapy.protocol import APP
from sklearn.svm import SVC from sklearn.linear_model import LogisticRegression
import quacc.evaluation as eval import quacc.evaluation as eval
from quacc.estimator import AccuracyEstimator from quacc.estimator import MulticlassAccuracyEstimator
from .data import get_dataset from quacc.data import get_dataset
qp.environ["SAMPLE_SIZE"] = 100 qp.environ["SAMPLE_SIZE"] = 100
@ -17,16 +16,17 @@ pd.set_option("display.float_format", "{:.4f}".format)
def test_2(dataset_name): def test_2(dataset_name):
train, test = get_dataset(dataset_name) train, test = get_dataset(dataset_name)
model = SVC(probability=True) model = LogisticRegression()
print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True) print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True)
model.fit(*train.Xy) model.fit(*train.Xy)
print("fit") print("fit")
qmodel = SLD(SVC(probability=True)) estimator = MulticlassAccuracyEstimator(model)
estimator = AccuracyEstimator(model, qmodel)
print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ", flush=True) print(
f"fitting qmodel {estimator.q_model.__class__.__name__}...", end=" ", flush=True
)
estimator.fit(train) estimator.fit(train)
print("fit") print("fit")

0
tests/__init__.py Normal file
View File

94
tests/test_data.py Normal file
View File

@ -0,0 +1,94 @@
import pytest
from quacc.data import ExClassManager as ECM, ExtendedCollection
import numpy as np
class TestExClassManager:
    """Check the extended-class encoding and decoding for two base classes."""

    @pytest.mark.parametrize(
        "true_class,pred_class,result",
        [
            (0, 0, 0),
            (0, 1, 1),
            (1, 0, 2),
            (1, 1, 3),
        ],
    )
    def test_get_ex(self, true_class, pred_class, result):
        """A (true, predicted) pair maps to the expected extended class."""
        ncl = 2
        assert ECM.get_ex(ncl, true_class, pred_class) == result

    @pytest.mark.parametrize(
        "ex_class,result",
        [
            (0, 0),
            (1, 1),
            (2, 0),
            (3, 1),
        ],
    )
    def test_get_pred(self, ex_class, result):
        """The predicted class is recovered from an extended class."""
        ncl = 2
        assert ECM.get_pred(ncl, ex_class) == result

    @pytest.mark.parametrize(
        "ex_class,result",
        [
            (0, 0),
            (1, 0),
            (2, 1),
            (3, 1),
        ],
    )
    def test_get_true(self, ex_class, result):
        """The true class is recovered from an extended class."""
        ncl = 2
        assert ECM.get_true(ncl, ex_class) == result
class TestExtendedCollection:
    """Check how an ``ExtendedCollection`` is split by predicted class."""

    @pytest.mark.parametrize(
        "instances,labels,inst0,lbl0,inst1,lbl1",
        [
            # both predicted classes are present
            (
                [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]],
                [3, 0, 1, 2],
                [[1, 0.54, 0.46], [3, 0.6, 0.4]],
                [0, 1],
                [[0, 0.3, 0.7], [2, 0.28, 0.72]],
                [1, 0],
            ),
            # every instance is predicted as class 1
            (
                [[0, 0.3, 0.7], [2, 0.28, 0.72]],
                [3, 1],
                [],
                [],
                [[0, 0.3, 0.7], [2, 0.28, 0.72]],
                [1, 0],
            ),
            # every instance is predicted as class 0
            (
                [[1, 0.54, 0.46], [3, 0.6, 0.4]],
                [0, 2],
                [[1, 0.54, 0.46], [3, 0.6, 0.4]],
                [0, 1],
                [],
                [],
            ),
        ],
    )
    def test_split_by_pred(self, instances, labels, inst0, lbl0, inst1, lbl1):
        """Each split holds the expected instances and decoded true labels."""
        ec = ExtendedCollection(
            np.asarray(instances), np.asarray(labels), classes=range(0, 4)
        )
        ec0, ec1 = ec.split_by_pred()
        assert np.array_equal(ec0.X, np.asarray(inst0))
        assert np.array_equal(ec0.y, np.asarray(lbl0))
        assert np.array_equal(ec1.X, np.asarray(inst1))
        assert np.array_equal(ec1.y, np.asarray(lbl1))

4
tests/test_estimator.py Normal file
View File

@ -0,0 +1,4 @@
class TestBinaryQuantifierAccuracyEstimator:
    """Placeholder test suite for ``BinaryQuantifierAccuracyEstimator``."""

    def test_estimate(self):
        """Not implemented yet; passes without checking anything."""
        # TODO: exercise BinaryQuantifierAccuracyEstimator.estimate
        pass