plots, avg table, conf added; method updated

.gitignore
@@ -11,4 +11,6 @@ lipton_bbse/__pycache__/*
 elsahar19_rca/__pycache__/*
 *.coverage
 .coverage
 scp_sync.py
+out/*
+output/*
@@ -41,12 +41,12 @@
 </head>
 <body class="vscode-body vscode-light">
 <ul class="contains-task-list">
-<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> aggiungere media tabelle</li>
+<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> aggiungere media tabelle</li>
-<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> plot; 3 tipi (appunti + email + garg)</li>
+<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> plot; 3 tipi (appunti + email + garg)</li>
 <li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> sistemare kfcv baseline</li>
-<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> aggiungere metodo con CC oltre SLD</li>
+<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> aggiungere metodo con CC oltre SLD</li>
 <li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> prendere classe più popolosa di rcv1, togliere negativi fino a raggiungere 50/50; poi fare subsampling con 9 training prvalences (da 0.1-0.9 a 0.9-0.1)</li>
-<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> variare parametro recalibration in SLD</li>
+<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> variare parametro recalibration in SLD</li>
 </ul>
TODO.md
@@ -1,6 +1,6 @@
-- [ ] aggiungere media tabelle
+- [x] aggiungere media tabelle
-- [ ] plot; 3 tipi (appunti + email + garg)
+- [x] plot; 3 tipi (appunti + email + garg)
 - [ ] sistemare kfcv baseline
-- [ ] aggiungere metodo con CC oltre SLD
+- [x] aggiungere metodo con CC oltre SLD
 - [x] prendere classe più popolosa di rcv1, togliere negativi fino a raggiungere 50/50; poi fare subsampling con 9 training prvalences (da 0.1-0.9 a 0.9-0.1)
-- [ ] variare parametro recalibration in SLD
+- [x] variare parametro recalibration in SLD
conf.yaml (new file)
@@ -0,0 +1,71 @@
+
+exec: []
+
+commons:
+  - DATASET_NAME: rcv1
+    DATASET_TARGET: CCAT
+    METRICS:
+      - acc
+      - f1
+    DATASET_N_PREVS: 9
+  - DATASET_NAME: imdb
+    METRICS:
+      - acc
+      - f1
+    DATASET_N_PREVS: 9
+
+confs:
+
+  all_mul_vs_atc:
+    COMP_ESTIMATORS:
+      - our_mul_SLD
+      - our_mul_SLD_nbvs
+      - our_mul_SLD_bcts
+      - our_mul_SLD_ts
+      - our_mul_SLD_vs
+      - our_mul_CC
+      - ref
+      - atc_mc
+      - atc_ne
+
+  all_bin_vs_atc:
+    COMP_ESTIMATORS:
+      - our_bin_SLD
+      - our_bin_SLD_nbvs
+      - our_bin_SLD_bcts
+      - our_bin_SLD_ts
+      - our_bin_SLD_vs
+      - our_bin_CC
+      - ref
+      - atc_mc
+      - atc_ne
+
+  best_our_vs_atc:
+    COMP_ESTIMATORS:
+      - our_bin_SLD
+      - our_bin_SLD_bcts
+      - our_bin_SLD_vs
+      - our_bin_CC
+      - our_mul_SLD
+      - our_mul_SLD_bcts
+      - our_mul_SLD_vs
+      - our_mul_CC
+      - ref
+      - atc_mc
+      - atc_ne
+
+  best_our_vs_all:
+    COMP_ESTIMATORS:
+      - our_bin_SLD
+      - our_bin_SLD_bcts
+      - our_bin_SLD_vs
+      - our_bin_CC
+      - our_mul_SLD
+      - our_mul_SLD_bcts
+      - our_mul_SLD_vs
+      - our_mul_CC
+      - ref
+      - kfcv
+      - atc_mc
+      - atc_ne
+      - doc_feat
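Note: quacc/environ.py (below in this changeset) reads this file and builds one runnable configuration per (commons entry, confs entry) pair, named <DATASET_NAME>[_<DATASET_TARGET>]_<conf name>; an empty top-level exec list means all of them run. A minimal sketch of that expansion, assuming the YAML above is saved as conf.yaml:

import yaml

with open("conf.yaml", "r") as f:
    confs = yaml.safe_load(f)

names = []
for common in confs["commons"]:
    name = common["DATASET_NAME"]
    if "DATASET_TARGET" in common:
        name += "_" + common["DATASET_TARGET"]
    for k in confs["confs"]:
        names.append(f"{name}_{k}")

print(names)
# ['rcv1_CCAT_all_mul_vs_atc', ..., 'imdb_best_our_vs_all'] (8 names: 2 commons x 4 confs)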
[9 plot PNGs deleted (175–266 KiB each)]
out/rcv1_CCAT.md (1955 changed lines, contents not rendered)
poetry.lock
@@ -956,6 +956,65 @@ files = [
     {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
 ]
 
+[[package]]
+name = "pyyaml"
+version = "6.0.1"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
+    {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
+    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
+    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
+    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+    {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
+    {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
+    {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
+    {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
+    {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
+    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
+    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
+    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+    {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
+    {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
+    {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+    {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+    {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+    {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+    {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+    {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
+    {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
+    {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
+    {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
+    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
+    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
+    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+    {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
+    {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
+    {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
+    {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
+    {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
+    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
+    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
+    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+    {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
+    {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
+    {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
+    {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+]
+
 [[package]]
 name = "quapy"
 version = "0.1.7"

@@ -1164,4 +1223,4 @@ test = ["pytest", "pytest-cov"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "72e3afd9a24b88fc8a8f5f55e1c408f65090fce9015a442f6f41638191276b6f"
+content-hash = "0ce0e6b058900e7db2939e7eb047a1f868c88de67def370c1c1fa0ba532df0b0"
pyproject.toml
@@ -10,6 +10,7 @@ python = "^3.11"
 quapy = "^0.1.7"
 pandas = "^2.0.3"
 jinja2 = "^3.1.2"
+pyyaml = "^6.0.1"
 
 [tool.poetry.scripts]
 main = "quacc.main:main"
quacc/environ.py
@@ -1,21 +1,33 @@
-from pathlib import Path
+import yaml
 
 defalut_env = {
     "DATASET_NAME": "rcv1",
     "DATASET_TARGET": "CCAT",
+    "METRICS": ["acc", "f1"],
     "COMP_ESTIMATORS": [
-        "OUR_BIN_SLD",
-        "OUR_MUL_SLD",
-        "KFCV",
-        "ATC_MC",
-        "ATC_NE",
-        "DOC_FEAT",
-        # "RCA",
-        # "RCA_STAR",
+        "our_bin_SLD",
+        "our_bin_SLD_nbvs",
+        "our_bin_SLD_bcts",
+        "our_bin_SLD_ts",
+        "our_bin_SLD_vs",
+        "our_bin_CC",
+        "our_mul_SLD",
+        "our_mul_SLD_nbvs",
+        "our_mul_SLD_bcts",
+        "our_mul_SLD_ts",
+        "our_mul_SLD_vs",
+        "our_mul_CC",
+        "ref",
+        "kfcv",
+        "atc_mc",
+        "atc_ne",
+        "doc_feat",
+        "rca",
+        "rca_star",
     ],
     "DATASET_N_PREVS": 9,
-    "OUT_DIR": Path("out"),
-    "PLOT_OUT_DIR": Path("out/plot"),
+    "OUT_DIR_NAME": "output",
+    "PLOT_DIR_NAME": "plot",
     "PROTOCOL_N_PREVS": 21,
     "PROTOCOL_REPEATS": 100,
     "SAMPLE_SIZE": 1000,

@@ -24,8 +36,37 @@ defalut_env = {
 
 class Environ:
     def __init__(self, **kwargs):
-        for k, v in kwargs.items():
+        self.exec = []
+        self.confs = {}
+        self.__setdict(kwargs)
+
+    def __setdict(self, d):
+        for k, v in d.items():
             self.__setattr__(k, v)
+
+    def load_conf(self):
+        with open("conf.yaml", "r") as f:
+            confs = yaml.safe_load(f)
+
+        for common in confs["commons"]:
+            name = common["DATASET_NAME"]
+            if "DATASET_TARGET" in common:
+                name += "_" + common["DATASET_TARGET"]
+            for k, d in confs["confs"].items():
+                _k = f"{name}_{k}"
+                self.confs[_k] = common | d
+                self.exec.append(_k)
+
+        if "exec" in confs:
+            if len(confs["exec"]) > 0:
+                self.exec = confs["exec"]
+
+    def __iter__(self):
+        self.load_conf()
+        for _conf in self.exec:
+            if _conf in self.confs:
+                self.__setdict(self.confs[_conf])
+                yield _conf
 
 
 env = Environ(**defalut_env)
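Note: iterating env loads conf.yaml, overlays each configuration's keys on the instance, and yields the configuration name. The overlay in load_conf is plain dict union (Python 3.9+, right operand wins), so a conf entry can override anything a commons entry sets; a toy check:

common = {"DATASET_NAME": "rcv1", "DATASET_TARGET": "CCAT", "DATASET_N_PREVS": 9}
conf = {"COMP_ESTIMATORS": ["our_mul_SLD", "ref", "atc_mc"], "DATASET_N_PREVS": 5}

merged = common | conf  # same operator as in load_conf
print(merged["DATASET_N_PREVS"], merged["DATASET_TARGET"])  # 5 CCAT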
@@ -1,13 +1,15 @@
 import quapy as qp
 
+
 def from_name(err_name):
-    if err_name == 'f1e':
+    if err_name == "f1e":
         return f1e
-    elif err_name == 'f1':
+    elif err_name == "f1":
         return f1
     else:
         return qp.error.from_name(err_name)
 
+
 # def f1(prev):
 #     # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
 #     if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:

@@ -18,18 +20,21 @@ def from_name(err_name):
 #             return float('NaN')
 #         else:
 #             recall = prev[0] / (prev[0] + prev[1])
 #             precision = prev[0] / (prev[0] + prev[2])
 #             return 2 * (precision * recall) / (precision + recall)
 
+
 def f1(prev):
-    den = (2*prev[3]) + prev[1] + prev[2]
+    den = (2 * prev[3]) + prev[1] + prev[2]
     if den == 0:
         return 0.0
     else:
-        return (2*prev[3])/den
+        return (2 * prev[3]) / den
 
+
 def f1e(prev):
     return 1 - f1(prev)
 
+
 def acc(prev):
-    return (prev[1] + prev[2]) / sum(prev)
+    return (prev[0] + prev[3]) / sum(prev)
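Note: the acc fix makes the layout of prev explicit. Judging from the two formulas, prev holds the four confusion-matrix cells with prev[3] = TP, prev[0] = TN and prev[1], prev[2] the two error cells, so acc = (TN + TP) / total and F1 = 2TP / (2TP + FN + FP); the old version summed the error cells instead. A worked check on toy counts (cell ordering inferred from the formulas, not stated in the source):

prev = [50, 10, 20, 120]  # assumed [TN, err, err, TP] counts

den = (2 * prev[3]) + prev[1] + prev[2]
f1 = (2 * prev[3]) / den if den != 0 else 0.0
acc = (prev[0] + prev[3]) / sum(prev)
print(f"acc={acc:.3f} f1={f1:.3f}")  # acc=0.850 f1=0.889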
@@ -1,9 +1,9 @@
-from abc import abstractmethod
 import math
+from abc import abstractmethod
 
 import numpy as np
 from quapy.data import LabelledCollection
-from quapy.method.aggregative import SLD
+from quapy.method.aggregative import CC, SLD
 from sklearn.base import BaseEstimator
 from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import cross_val_predict

@@ -15,7 +15,7 @@ class AccuracyEstimator:
     def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
         if not pred_proba:
             pred_proba = self.c_model.predict_proba(base.X)
-        return ExtendedCollection.extend_collection(base, pred_proba)
+        return ExtendedCollection.extend_collection(base, pred_proba), pred_proba
 
     @abstractmethod
     def fit(self, train: LabelledCollection | ExtendedCollection):

@@ -27,9 +27,15 @@ class AccuracyEstimator:
 
 
 class MulticlassAccuracyEstimator(AccuracyEstimator):
-    def __init__(self, c_model: BaseEstimator):
+    def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
         self.c_model = c_model
-        self.q_model = SLD(LogisticRegression())
+        if q_model == "SLD":
+            available_args = ["recalib"]
+            sld_args = {k: v for k, v in kwargs.items() if k in available_args}
+            self.q_model = SLD(LogisticRegression(), **sld_args)
+        elif q_model == "CC":
+            self.q_model = CC(LogisticRegression())
+
         self.e_train = None
 
     def fit(self, train: LabelledCollection | ExtendedCollection):

@@ -67,10 +73,17 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
 
 
 class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
-    def __init__(self, c_model: BaseEstimator):
+    def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
         self.c_model = c_model
-        self.q_model_0 = SLD(LogisticRegression())
-        self.q_model_1 = SLD(LogisticRegression())
+        if q_model == "SLD":
+            available_args = ["recalib"]
+            sld_args = {k: v for k, v in kwargs.items() if k in available_args}
+            self.q_model_0 = SLD(LogisticRegression(), **sld_args)
+            self.q_model_1 = SLD(LogisticRegression(), **sld_args)
+        elif q_model == "CC":
+            self.q_model_0 = CC(LogisticRegression())
+            self.q_model_1 = CC(LogisticRegression())
+
         self.e_train = None
 
     def fit(self, train: LabelledCollection | ExtendedCollection):

@@ -83,7 +96,7 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
 
             self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
         elif isinstance(train, ExtendedCollection):
            self.e_train = train
 
        self.n_classes = self.e_train.n_classes
        [e_train_0, e_train_1] = self.e_train.split_by_pred()
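Note: both estimators now accept q_model ("SLD" or "CC") and forward only the whitelisted recalib kwarg to quapy's SLD. A standalone sketch of just that constructor logic (the helper name build_q_model is illustrative, not part of the changeset):

from quapy.method.aggregative import CC, SLD
from sklearn.linear_model import LogisticRegression

def build_q_model(q_model="SLD", **kwargs):
    # whitelist recalib for SLD; CC takes no extra args
    if q_model == "SLD":
        sld_args = {k: v for k, v in kwargs.items() if k in ["recalib"]}
        return SLD(LogisticRegression(), **sld_args)
    elif q_model == "CC":
        return CC(LogisticRegression())

q = build_q_model("SLD", recalib="bcts", unknown_arg=1)  # unknown_arg is dropped
print(type(q).__name__)  # prints the underlying quapy class (SLD is quapy's alias of EMQ)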
@@ -34,14 +34,14 @@ def kfcv(
     # ensure that the protocol returns a LabelledCollection for each iteration
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
 
-    report = EvaluationReport(prefix="kfcv")
+    report = EvaluationReport(name="kfcv")
     for test in protocol():
         test_preds = c_model_predict(test.X)
         meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
         meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds))
         report.append_row(
             test.prevalence(),
-            acc_score=(1.0 - acc_score),
+            acc_score=acc_score,
             f1_score=f1_score,
             acc=meta_acc,
             f1=meta_f1,

@@ -57,13 +57,13 @@ def reference(
 ):
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
     c_model_predict = getattr(c_model, "predict_proba")
-    report = EvaluationReport(prefix="ref")
+    report = EvaluationReport(name="ref")
     for test in protocol():
         test_probs = c_model_predict(test.X)
         test_preds = np.argmax(test_probs, axis=-1)
         report.append_row(
             test.prevalence(),
-            acc_score=(1 - metrics.accuracy_score(test.y, test_preds)),
+            acc_score=metrics.accuracy_score(test.y, test_preds),
             f1_score=metrics.f1_score(test.y, test_preds),
         )
 

@@ -89,7 +89,7 @@ def atc_mc(
     # ensure that the protocol returns a LabelledCollection for each iteration
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
 
-    report = EvaluationReport(prefix="atc_mc")
+    report = EvaluationReport(name="atc_mc")
     for test in protocol():
         ## Load OOD test data probs
         test_probs = c_model_predict(test.X)

@@ -102,7 +102,7 @@
         report.append_row(
             test.prevalence(),
             acc=meta_acc,
-            acc_score=1.0 - atc_accuracy,
+            acc_score=atc_accuracy,
             f1_score=f1_score,
             f1=meta_f1,
         )

@@ -129,7 +129,7 @@ def atc_ne(
     # ensure that the protocol returns a LabelledCollection for each iteration
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
 
-    report = EvaluationReport(prefix="atc_ne")
+    report = EvaluationReport(name="atc_ne")
     for test in protocol():
         ## Load OOD test data probs
         test_probs = c_model_predict(test.X)

@@ -142,7 +142,7 @@
         report.append_row(
             test.prevalence(),
             acc=meta_acc,
-            acc_score=(1.0 - atc_accuracy),
+            acc_score=atc_accuracy,
             f1_score=f1_score,
             f1=meta_f1,
         )

@@ -182,14 +182,14 @@ def doc_feat(
     # ensure that the protocol returns a LabelledCollection for each iteration
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
 
-    report = EvaluationReport(prefix="doc_feat")
+    report = EvaluationReport(name="doc_feat")
     for test in protocol():
         test_probs = c_model_predict(test.X)
         test_preds = np.argmax(test_probs, axis=-1)
         test_scores = np.max(test_probs, axis=-1)
         score = (v1acc + doc.get_doc(val_scores, test_scores)) / 100.0
         meta_acc = abs(score - metrics.accuracy_score(test.y, test_preds))
-        report.append_row(test.prevalence(), acc=meta_acc, acc_score=(1.0 - score))
+        report.append_row(test.prevalence(), acc=meta_acc, acc_score=score)
 
     return report
 

@@ -206,17 +206,15 @@ def rca_score(
     # ensure that the protocol returns a LabelledCollection for each iteration
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
 
-    report = EvaluationReport(prefix="rca")
+    report = EvaluationReport(name="rca")
     for test in protocol():
         try:
             test_pred = c_model_predict(test.X)
             c_model2 = rca.clone_fit(c_model, test.X, test_pred)
             c_model2_predict = getattr(c_model2, predict_method)
             val_pred2 = c_model2_predict(validation.X)
-            rca_score = rca.get_score(val_pred1, val_pred2, validation.y)
-            meta_score = abs(
-                rca_score - (1 - metrics.accuracy_score(test.y, test_pred))
-            )
+            rca_score = 1.0 - rca.get_score(val_pred1, val_pred2, validation.y)
+            meta_score = abs(rca_score - metrics.accuracy_score(test.y, test_pred))
             report.append_row(test.prevalence(), acc=meta_score, acc_score=rca_score)
         except ValueError:
             report.append_row(

@@ -244,17 +242,15 @@ def rca_star_score(
     # ensure that the protocol returns a LabelledCollection for each iteration
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
 
-    report = EvaluationReport(prefix="rca_star")
+    report = EvaluationReport(name="rca_star")
     for test in protocol():
         try:
             test_pred = c_model_predict(test.X)
             c_model2 = rca.clone_fit(c_model, test.X, test_pred)
             c_model2_predict = getattr(c_model2, predict_method)
             val2_pred2 = c_model2_predict(validation2.X)
-            rca_star_score = rca.get_score(val2_pred1, val2_pred2, validation2.y)
-            meta_score = abs(
-                rca_star_score - (1 - metrics.accuracy_score(test.y, test_pred))
-            )
+            rca_star_score = 1.0 - rca.get_score(val2_pred1, val2_pred2, validation2.y)
+            meta_score = abs(rca_star_score - metrics.accuracy_score(test.y, test_pred))
             report.append_row(
                 test.prevalence(), acc=meta_score, acc_score=rca_star_score
             )
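Note: after this change every baseline stores the raw estimated accuracy in acc_score (previously 1 - score, an error rate) and keeps the absolute estimation error in acc, so report columns are directly comparable across methods. The row layout, in miniature:

estimated_acc = 0.87  # what a method predicts for one test sample
true_acc = 0.82       # measured on that labelled sample
row = {
    "acc_score": estimated_acc,                      # raw estimate, no longer 1 - estimate
    "acc": round(abs(estimated_acc - true_acc), 3),  # absolute error (rounded for display)
}
print(row)  # {'acc_score': 0.87, 'acc': 0.05}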
quacc/evaluation/comp.py
@@ -1,5 +1,6 @@
 import multiprocessing
 import time
+import traceback
 from typing import List
 
 import pandas as pd

@@ -19,14 +20,25 @@ pd.set_option("display.float_format", "{:.4f}".format)
 
 class CompEstimator:
     __dict = {
-        "OUR_BIN_SLD": method.evaluate_bin_sld,
-        "OUR_MUL_SLD": method.evaluate_mul_sld,
-        "KFCV": baseline.kfcv,
-        "ATC_MC": baseline.atc_mc,
-        "ATC_NE": baseline.atc_ne,
-        "DOC_FEAT": baseline.doc_feat,
-        "RCA": baseline.rca_score,
-        "RCA_STAR": baseline.rca_star_score,
+        "our_bin_SLD": method.evaluate_bin_sld,
+        "our_mul_SLD": method.evaluate_mul_sld,
+        "our_bin_SLD_nbvs": method.evaluate_bin_sld_nbvs,
+        "our_mul_SLD_nbvs": method.evaluate_mul_sld_nbvs,
+        "our_bin_SLD_bcts": method.evaluate_bin_sld_bcts,
+        "our_mul_SLD_bcts": method.evaluate_mul_sld_bcts,
+        "our_bin_SLD_ts": method.evaluate_bin_sld_ts,
+        "our_mul_SLD_ts": method.evaluate_mul_sld_ts,
+        "our_bin_SLD_vs": method.evaluate_bin_sld_vs,
+        "our_mul_SLD_vs": method.evaluate_mul_sld_vs,
+        "our_bin_CC": method.evaluate_bin_cc,
+        "our_mul_CC": method.evaluate_mul_cc,
+        "ref": baseline.reference,
+        "kfcv": baseline.kfcv,
+        "atc_mc": baseline.atc_mc,
+        "atc_ne": baseline.atc_ne,
+        "doc_feat": baseline.doc_feat,
+        "rca": baseline.rca_score,
+        "rca_star": baseline.rca_star_score,
     }
 
     def __class_getitem__(cls, e: str | List[str]):

@@ -55,7 +67,17 @@ def fit_and_estimate(_estimate, train, validation, test):
         test, n_prevalences=env.PROTOCOL_N_PREVS, repeats=env.PROTOCOL_REPEATS
     )
     start = time.time()
-    result = _estimate(model, validation, protocol)
+    try:
+        result = _estimate(model, validation, protocol)
+    except Exception as e:
+        print(f"Method {_estimate.__name__} failed.")
+        traceback.print_exception(e)
+        return {
+            "name": _estimate.__name__,
+            "result": None,
+            "time": 0,
+        }
+
     end = time.time()
     print(f"{_estimate.__name__}: {end-start:.2f}s")

@@ -69,22 +91,33 @@ def fit_and_estimate(_estimate, train, validation, test):
 def evaluate_comparison(
     dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
 ) -> EvaluationReport:
-    with multiprocessing.Pool(8) as pool:
+    with multiprocessing.Pool(len(estimators)) as pool:
         dr = DatasetReport(dataset.name)
         for d in dataset():
             print(f"train prev.: {d.train_prev}")
             start = time.time()
             tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]]
             results = [pool.apply_async(fit_and_estimate, t) for t in tasks]
-            results = list(map(lambda r: r.get(), results))
+
+            results_got = []
+            for _r in results:
+                try:
+                    r = _r.get()
+                    if r["result"] is not None:
+                        results_got.append(r)
+                except Exception as e:
+                    print(e)
+
             er = EvaluationReport.combine_reports(
-                *list(map(lambda r: r["result"], results)), name=dataset.name
+                *[r["result"] for r in results_got],
+                name=dataset.name,
+                train_prev=d.train_prev,
+                valid_prev=d.validation_prev,
             )
-            times = {r["name"]: r["time"] for r in results}
+            times = {r["name"]: r["time"] for r in results_got}
             end = time.time()
             times["tot"] = end - start
             er.times = times
-            er.train_prevs = d.prevs
             dr.add(er)
             print()
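Note: fit_and_estimate and evaluate_comparison now fail soft: a crashing estimator yields result None, and the gather loop skips it instead of killing the whole batch. A self-contained sketch of the same pattern, with stand-in task names:

import multiprocessing

def work(x):
    # stand-in for fit_and_estimate: one task fails, the rest succeed
    if x == 2:
        raise ValueError("boom")
    return {"name": f"task{x}", "result": x * x, "time": 0.0}

if __name__ == "__main__":
    with multiprocessing.Pool(3) as pool:
        async_results = [pool.apply_async(work, (x,)) for x in range(4)]
        got = []
        for ar in async_results:
            try:
                r = ar.get()
                if r["result"] is not None:
                    got.append(r)
            except Exception as e:
                print(e)  # the failed task is reported and skipped
    print([r["name"] for r in got])  # ['task0', 'task1', 'task3']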
@@ -1,3 +1,5 @@
+import numpy as np
+import sklearn.metrics as metrics
 from quapy.data import LabelledCollection
 from quapy.protocol import (
     AbstractStochasticSeededProtocol,

@@ -22,15 +24,17 @@ def estimate(
     # ensure that the protocol returns a LabelledCollection for each iteration
     protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
 
-    base_prevs, true_prevs, estim_prevs = [], [], []
+    base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
     for sample in protocol():
-        e_sample = estimator.extend(sample)
+        e_sample, pred_proba = estimator.extend(sample)
         estim_prev = estimator.estimate(e_sample.X, ext=True)
         base_prevs.append(sample.prevalence())
         true_prevs.append(e_sample.prevalence())
         estim_prevs.append(estim_prev)
+        pred_probas.append(pred_proba)
+        labels.append(sample.y)
 
-    return base_prevs, true_prevs, estim_prevs
+    return base_prevs, true_prevs, estim_prevs, pred_probas, labels
 
 
 def evaluation_report(

@@ -38,16 +42,21 @@ def evaluation_report(
     protocol: AbstractStochasticSeededProtocol,
     method: str,
 ) -> EvaluationReport:
-    base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
-    report = EvaluationReport(prefix=method)
+    base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
+        estimator, protocol
+    )
+    report = EvaluationReport(name=method)
 
-    for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
+    for base_prev, true_prev, estim_prev, pred_proba, label in zip(
+        base_prevs, true_prevs, estim_prevs, pred_probas, labels
+    ):
+        pred = np.argmax(pred_proba, axis=-1)
         acc_score = error.acc(estim_prev)
         f1_score = error.f1(estim_prev)
         report.append_row(
             base_prev,
-            acc_score=1.0 - acc_score,
-            acc=abs(error.acc(true_prev) - acc_score),
+            acc_score=acc_score,
+            acc=abs(metrics.accuracy_score(label, pred) - acc_score),
             f1_score=f1_score,
             f1=abs(error.f1(true_prev) - f1_score),
         )

@@ -60,13 +69,18 @@ def evaluate(
     validation: LabelledCollection,
     protocol: AbstractStochasticSeededProtocol,
     method: str,
+    q_model: str,
+    **kwargs,
 ):
     estimator: AccuracyEstimator = {
         "bin": BinaryQuantifierAccuracyEstimator,
         "mul": MulticlassAccuracyEstimator,
-    }[method](c_model)
+    }[method](c_model, q_model=q_model, **kwargs)
     estimator.fit(validation)
-    return evaluation_report(estimator, protocol, method)
+    _method = f"{method}_{q_model}"
+    for k, v in kwargs.items():
+        _method += f"_{v}"
+    return evaluation_report(estimator, protocol, _method)
 
 
 def evaluate_bin_sld(

@@ -74,7 +88,7 @@ def evaluate_bin_sld(
     validation: LabelledCollection,
     protocol: AbstractStochasticSeededProtocol,
 ) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "bin")
+    return evaluate(c_model, validation, protocol, "bin", "SLD")
 
 
 def evaluate_mul_sld(

@@ -82,4 +96,84 @@ def evaluate_mul_sld(
     validation: LabelledCollection,
     protocol: AbstractStochasticSeededProtocol,
 ) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "mul")
+    return evaluate(c_model, validation, protocol, "mul", "SLD")
+
+
+def evaluate_bin_sld_nbvs(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="nbvs")
+
+
+def evaluate_mul_sld_nbvs(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="nbvs")
+
+
+def evaluate_bin_sld_bcts(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="bcts")
+
+
+def evaluate_mul_sld_bcts(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="bcts")
+
+
+def evaluate_bin_sld_ts(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="ts")
+
+
+def evaluate_mul_sld_ts(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="ts")
+
+
+def evaluate_bin_sld_vs(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="vs")
+
+
+def evaluate_mul_sld_vs(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="vs")
+
+
+def evaluate_bin_cc(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "bin", "CC")
+
+
+def evaluate_mul_cc(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "mul", "CC")
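Note: report names now encode the full method configuration: f"{method}_{q_model}" plus one suffix per extra kwarg value, so evaluate(..., "bin", "SLD", recalib="bcts") reports as bin_SLD_bcts. In miniature:

def report_name(method: str, q_model: str, **kwargs) -> str:
    # mirrors the _method construction in evaluate()
    name = f"{method}_{q_model}"
    for v in kwargs.values():
        name += f"_{v}"
    return name

print(report_name("bin", "SLD", recalib="bcts"))  # bin_SLD_bcts
print(report_name("mul", "CC"))                   # mul_CC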
@@ -1,22 +1,24 @@
-import statistics as stats
+from pathlib import Path
 from typing import List, Tuple
 
 import numpy as np
 import pandas as pd
 
 from quacc import plot
+from quacc.environ import env
 from quacc.utils import fmt_line_md
 
 
 class EvaluationReport:
-    def __init__(self, prefix=None):
+    def __init__(self, name=None):
         self._prevs = []
         self._dict = {}
         self._g_prevs = None
         self._g_dict = None
-        self.name = prefix if prefix is not None else "default"
+        self.name = name if name is not None else "default"
         self.times = {}
-        self.train_prevs = {}
+        self.train_prev = None
+        self.valid_prev = None
         self.target = "default"
 
     def append_row(self, base: np.ndarray | Tuple, **row):

@@ -34,23 +36,40 @@ class EvaluationReport:
     def columns(self):
         return self._dict.keys()
 
-    def groupby_prevs(self, metric: str = None):
+    def group_by_prevs(self, metric: str = None):
         if self._g_dict is None:
             self._g_prevs = []
             self._g_dict = {k: [] for k in self._dict.keys()}
 
-            last_end = 0
-            for ind, bp in enumerate(self._prevs):
-                if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
-                    continue
-
-                self._g_prevs.append(bp)
-                for col in self._dict.keys():
-                    self._g_dict[col].append(
-                        stats.mean(self._dict[col][last_end : ind + 1])
-                    )
-
-                last_end = ind + 1
+            for col, vals in self._dict.items():
+                col_grouped = {}
+                for bp, v in zip(self._prevs, vals):
+                    if bp not in col_grouped:
+                        col_grouped[bp] = []
+                    col_grouped[bp].append(v)
+
+                self._g_dict[col] = [
+                    vs
+                    for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])
+                ]
+
+            self._g_prevs = sorted(
+                [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()],
+                key=lambda bp: bp[1],
+            )
+
+            # last_end = 0
+            # for ind, bp in enumerate(self._prevs):
+            #     if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
+            #         continue
+
+            #     self._g_prevs.append(bp)
+            #     for col in self._dict.keys():
+            #         self._g_dict[col].append(
+            #             stats.mean(self._dict[col][last_end : ind + 1])
+            #         )
+
+            #     last_end = ind + 1
 
         filtered_g_dict = self._g_dict
         if metric is not None:

@@ -60,30 +79,83 @@ class EvaluationReport:
 
         return self._g_prevs, filtered_g_dict
 
+    def avg_by_prevs(self, metric: str = None):
+        g_prevs, g_dict = self.group_by_prevs(metric=metric)
+
+        a_dict = {}
+        for col, vals in g_dict.items():
+            a_dict[col] = [np.mean(vs) for vs in vals]
+
+        return g_prevs, a_dict
+
+    def avg_all(self, metric: str = None):
+        f_dict = self._dict
+        if metric is not None:
+            f_dict = {c1: ls for ((c0, c1), ls) in self._dict.items() if c0 == metric}
+
+        a_dict = {}
+        for col, vals in f_dict.items():
+            a_dict[col] = [np.mean(vals)]
+
+        return a_dict
+
     def get_dataframe(self, metric="acc"):
-        g_prevs, g_dict = self.groupby_prevs(metric=metric)
+        g_prevs, g_dict = self.avg_by_prevs(metric=metric)
+        a_dict = self.avg_all(metric=metric)
+        for col in g_dict.keys():
+            g_dict[col].extend(a_dict[col])
         return pd.DataFrame(
             g_dict,
-            index=g_prevs,
+            index=g_prevs + ["tot"],
             columns=g_dict.keys(),
         )
 
-    def get_plot(self, mode="delta", metric="acc"):
-        g_prevs, g_dict = self.groupby_prevs(metric=metric)
-        t_prev = int(round(self.train_prevs["train"][0] * 100))
-        title = f"{self.name}_{t_prev}_{metric}"
-        plot.plot_delta(g_prevs, g_dict, metric, title)
+    def get_plot(self, mode="delta", metric="acc") -> Path:
+        if mode == "delta":
+            g_prevs, g_dict = self.group_by_prevs(metric=metric)
+            return plot.plot_delta(
+                g_prevs,
+                g_dict,
+                metric=metric,
+                name=self.name,
+                train_prev=self.train_prev,
+            )
+        elif mode == "diagonal":
+            _, g_dict = self.avg_by_prevs(metric=metric + "_score")
+            f_dict = {k: v for k, v in g_dict.items() if k != "ref"}
+            reference = g_dict["ref"]
+            return plot.plot_diagonal(
+                reference,
+                f_dict,
+                metric=metric,
+                name=self.name,
+                train_prev=self.train_prev,
+            )
+        elif mode == "shift":
+            g_prevs, g_dict = self.avg_by_prevs(metric=metric)
+            return plot.plot_shift(
+                g_prevs,
+                g_dict,
+                metric=metric,
+                name=self.name,
+                train_prev=self.train_prev,
+            )
 
     def to_md(self, *metrics):
         res = ""
-        for k, v in self.train_prevs.items():
-            res += fmt_line_md(f"{k}: {str(v)}")
+        res += fmt_line_md(f"train: {str(self.train_prev)}")
+        res += fmt_line_md(f"validation: {str(self.valid_prev)}")
         for k, v in self.times.items():
            res += fmt_line_md(f"{k}: {v:.3f}s")
         res += "\n"
         for m in metrics:
             res += self.get_dataframe(metric=m).to_html() + "\n\n"
-            self.get_plot(metric=m)
+            op_delta = self.get_plot(mode="delta", metric=m)
+            res += f"![]({op_delta.relative_to(env.OUT_DIR)})\n"
+            op_diag = self.get_plot(mode="diagonal", metric=m)
+            res += f"![]({op_diag.relative_to(env.OUT_DIR)})\n"
+            op_shift = self.get_plot(mode="shift", metric=m)
+            res += f"![]({op_shift.relative_to(env.OUT_DIR)})\n"
 
         return res
 

@@ -91,8 +163,9 @@ class EvaluationReport:
         if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
             raise ValueError("other has not same base prevalences of self")
 
-        if len(set(self._dict.keys()).intersection(set(other._dict.keys()))) > 0:
-            raise ValueError("self and other have matching keys")
+        inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys()))
+        if len(inters_keys) > 0:
+            raise ValueError(f"self and other have matching keys {str(inters_keys)}.")
 
         report = EvaluationReport()
         report._prevs = self._prevs

@@ -100,12 +173,14 @@ class EvaluationReport:
         return report
 
     @staticmethod
-    def combine_reports(*args, name="default"):
+    def combine_reports(*args, name="default", train_prev=None, valid_prev=None):
         er = args[0]
         for r in args[1:]:
             er = er.merge(r)
 
         er.name = name
+        er.train_prev = train_prev
+        er.valid_prev = valid_prev
         return er
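Note: group_by_prevs replaces the old run-length pass (which assumed equal prevalences were adjacent) with true grouping: rows sharing a base prevalence are collected into lists sorted by positive-class prevalence, and avg_by_prevs then means each list. The core of it on toy data (setdefault used for brevity):

import numpy as np

prevs = [(0.9, 0.1), (0.5, 0.5), (0.9, 0.1), (0.5, 0.5)]
vals = [0.80, 0.60, 0.84, 0.58]

col_grouped = {}
for bp, v in zip(prevs, vals):
    col_grouped.setdefault(bp, []).append(v)

grouped = [vs for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])]
g_prevs = sorted(set(prevs), key=lambda bp: bp[1])

print(g_prevs)  # [(0.9, 0.1), (0.5, 0.5)]
print(grouped)  # [[0.8, 0.84], [0.6, 0.58]]
print([float(np.mean(vs)) for vs in grouped])  # ~[0.82, 0.59]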
quacc/main.py
@@ -1,16 +1,39 @@
+import os
+import shutil
+from pathlib import Path
+
 import quacc.evaluation.comp as comp
 from quacc.dataset import Dataset
 from quacc.environ import env
 
 
+def create_out_dir(dir_name):
+    dir_path = Path(env.OUT_DIR_NAME) / dir_name
+    env.OUT_DIR = dir_path
+    shutil.rmtree(dir_path, ignore_errors=True)
+    os.mkdir(dir_path)
+    plot_dir_path = dir_path / "plot"
+    env.PLOT_OUT_DIR = plot_dir_path
+    os.mkdir(plot_dir_path)
+
+
 def estimate_comparison():
-    dataset = Dataset(
-        env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS
-    )
-    output_path = env.OUT_DIR / f"{dataset.name}.md"
-    with open(output_path, "w") as f:
-        dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
-        f.write(dr.to_md("acc"))
+    for conf in env:
+        create_out_dir(conf)
+        dataset = Dataset(
+            env.DATASET_NAME,
+            target=env.DATASET_TARGET,
+            n_prevalences=env.DATASET_N_PREVS,
+        )
+        output_path = env.OUT_DIR / f"{dataset.name}.md"
+        try:
+            dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
+            for m in env.METRICS:
+                output_path = env.OUT_DIR / f"{conf}_{m}.md"
+                with open(output_path, "w") as f:
+                    f.write(dr.to_md(m))
+        except Exception as e:
+            print(f"Configuration {conf} failed. {e}")
 
     # print(df.to_latex(float_format="{:.4f}".format))
     # print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
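Note: each configuration now gets its own directory under OUT_DIR_NAME, wiped and recreated per run, with a plot/ subdirectory and one markdown report per metric. Expected layout for one configuration (names illustrative):

from pathlib import Path

conf = "rcv1_CCAT_all_mul_vs_atc"
out_dir = Path("output") / conf
print(out_dir / f"{conf}_acc.md")  # output/rcv1_CCAT_all_mul_vs_atc/rcv1_CCAT_all_mul_vs_atc_acc.md
print(out_dir / f"{conf}_f1.md")
print(out_dir / "plot")            # delta/diagonal/shift PNGs land here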
quacc/plot.py
@@ -1,16 +1,191 @@
+from pathlib import Path
+
 import matplotlib.pyplot as plt
+import numpy as np
 
 from quacc.environ import env
 
 
-def plot_delta(base_prevs, dict_vals, metric, title):
-    fig, ax = plt.subplots()
-
-    base_prevs = [f for f, p in base_prevs]
-    for method, deltas in dict_vals.items():
-        ax.plot(
-            base_prevs,
-            deltas,
+def _get_markers(n: int):
+    return [
+        "o",
+        "v",
+        "x",
+        "+",
+        "s",
+        "D",
+        "p",
+        "h",
+        "*",
+        "^",
+    ][:n]
+
+
+def plot_delta(
+    base_prevs,
+    dict_vals,
+    *,
+    pos_class=1,
+    metric="acc",
+    name="default",
+    train_prev=None,
+    legend=True,
+) -> Path:
+    if train_prev is not None:
+        t_prev_pos = int(round(train_prev[pos_class] * 100))
+        title = f"delta_{name}_{t_prev_pos}_{metric}"
+    else:
+        title = f"delta_{name}_{metric}"
+
+    fig, ax = plt.subplots()
+    ax.set_aspect("auto")
+    ax.grid()
+
+    NUM_COLORS = len(dict_vals)
+    cm = plt.get_cmap("tab10")
+    if NUM_COLORS > 10:
+        cm = plt.get_cmap("tab20")
+    ax.set_prop_cycle(
+        color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
+    )
+
+    base_prevs = [bp[pos_class] for bp in base_prevs]
+    for method, deltas in dict_vals.items():
+        avg = np.array([np.mean(d, axis=-1) for d in deltas])
+        # std = np.array([np.std(d, axis=-1) for d in deltas])
+        ax.plot(
+            base_prevs,
+            avg,
+            label=method,
+            linestyle="-",
+            marker="o",
+            markersize=3,
+            zorder=2,
+        )
+        # ax.fill_between(base_prevs, avg - std, avg + std, alpha=0.25)
+
+    ax.set(xlabel="test prevalence", ylabel=metric, title=title)
+
+    if legend:
+        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+    output_path = env.PLOT_OUT_DIR / f"{title}.png"
+    fig.savefig(output_path, bbox_inches="tight")
+
+    return output_path
+
+
+def plot_diagonal(
+    reference,
+    dict_vals,
+    *,
+    pos_class=1,
+    metric="acc",
+    name="default",
+    train_prev=None,
+    legend=True,
+):
+    if train_prev is not None:
+        t_prev_pos = int(round(train_prev[pos_class] * 100))
+        title = f"diagonal_{name}_{t_prev_pos}_{metric}"
+    else:
+        title = f"diagonal_{name}_{metric}"
+
+    fig, ax = plt.subplots()
+    ax.set_aspect("auto")
+    ax.grid()
+
+    NUM_COLORS = len(dict_vals)
+    cm = plt.get_cmap("tab10")
+    ax.set_prop_cycle(
+        marker=_get_markers(NUM_COLORS) * 2,
+        color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)] * 2,
+    )
+
+    reference = np.array(reference)
+    x_ticks = np.unique(reference)
+    x_ticks.sort()
+
+    for _, deltas in dict_vals.items():
+        deltas = np.array(deltas)
+        ax.plot(
+            reference,
+            deltas,
+            linestyle="None",
+            markersize=3,
+            zorder=2,
+        )
+
+    for method, deltas in dict_vals.items():
+        deltas = np.array(deltas)
+        x_interp = x_ticks[[0, -1]]
+        y_interp = np.interp(x_interp, reference, deltas)
+        ax.plot(
+            x_interp,
+            y_interp,
+            label=method,
+            linestyle="-",
+            markersize="0",
+            zorder=1,
+        )
+
+    ax.set(xlabel="test prevalence", ylabel=metric, title=title)
+
+    if legend:
+        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+    output_path = env.PLOT_OUT_DIR / f"{title}.png"
+    fig.savefig(output_path, bbox_inches="tight")
+    return output_path
+
+
+def plot_shift(
+    base_prevs,
+    dict_vals,
+    *,
+    pos_class=1,
+    metric="acc",
+    name="default",
+    train_prev=None,
+    legend=True,
+) -> Path:
+    if train_prev is None:
+        raise AttributeError("train_prev cannot be None.")
+
+    train_prev = train_prev[pos_class]
+    t_prev_pos = int(round(train_prev * 100))
+    title = f"shift_{name}_{t_prev_pos}_{metric}"
+
+    fig, ax = plt.subplots()
+    ax.set_aspect("auto")
+    ax.grid()
+
+    NUM_COLORS = len(dict_vals)
+    cm = plt.get_cmap("tab10")
+    if NUM_COLORS > 10:
+        cm = plt.get_cmap("tab20")
+    ax.set_prop_cycle(
+        color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
+    )
+
+    base_prevs = np.around(
+        [abs(bp[pos_class] - train_prev) for bp in base_prevs], decimals=2
+    )
+    for method, deltas in dict_vals.items():
+        delta_bins = {}
+        for bp, delta in zip(base_prevs, deltas):
+            if bp not in delta_bins:
+                delta_bins[bp] = []
+            delta_bins[bp].append(delta)
+
+        bp_unique, delta_avg = zip(
+            *sorted(
+                {k: np.mean(v) for k, v in delta_bins.items()}.items(),
+                key=lambda db: db[0],
+            )
+        )
+
+        ax.plot(
+            bp_unique,
+            delta_avg,
             label=method,
             linestyle="-",
             marker="o",

@@ -19,8 +194,10 @@ def plot_delta(base_prevs, dict_vals, metric, title):
         )
 
     ax.set(xlabel="test prevalence", ylabel=metric, title=title)
-    # ax.set_ylim(0, 1)
-    # ax.set_xlim(0, 1)
-    ax.legend()
+
+    if legend:
+        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
     output_path = env.PLOT_OUT_DIR / f"{title}.png"
-    plt.savefig(output_path)
+    fig.savefig(output_path, bbox_inches="tight")
+
+    return output_path
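Note: plot_shift re-bins samples by the absolute distance between test and training prevalence of the positive class, rounded to two decimals, then averages each bin. The binning step in isolation:

train_prev = 0.5                   # positive-class training prevalence
base_prevs = [0.1, 0.3, 0.7, 0.9]  # positive-class test prevalences
deltas = [0.30, 0.10, 0.12, 0.28]  # per-sample metric values

shifts = [round(abs(bp - train_prev), 2) for bp in base_prevs]
bins = {}
for s, d in zip(shifts, deltas):
    bins.setdefault(s, []).append(d)

xs, ys = zip(*sorted((k, sum(v) / len(v)) for k, v in bins.items()))
print(xs, ys)  # (0.2, 0.4) (~0.11, ~0.29)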