plots, avg table, conf added; method updated

.gitignore
@ -11,4 +11,6 @@ lipton_bbse/__pycache__/*
elsahar19_rca/__pycache__/*
*.coverage
.coverage
scp_sync.py
out/*
output/*

@ -41,12 +41,12 @@
</head>
<body class="vscode-body vscode-light">
<ul class="contains-task-list">
<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> add table averages</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> plots: 3 types (notes + email + garg)</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> add table averages</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> plots: 3 types (notes + email + garg)</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> fix the kfcv baseline</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> add a method with CC besides SLD</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> add a method with CC besides SLD</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> take the most populous rcv1 class, remove negatives until reaching 50/50, then subsample with 9 training prevalences (from 0.1-0.9 to 0.9-0.1)</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" type="checkbox"> vary the recalibration parameter in SLD</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked="" type="checkbox"> vary the recalibration parameter in SLD</li>
</ul>

TODO.md (8 changed lines)
@ -1,6 +1,6 @@
- [ ] add table averages
- [ ] plots: 3 types (notes + email + garg)
- [x] add table averages
- [x] plots: 3 types (notes + email + garg)
- [ ] fix the kfcv baseline
- [ ] add a method with CC besides SLD
- [x] add a method with CC besides SLD
- [x] take the most populous rcv1 class, remove negatives until reaching 50/50, then subsample with 9 training prevalences (from 0.1-0.9 to 0.9-0.1)
- [ ] vary the recalibration parameter in SLD
- [x] vary the recalibration parameter in SLD

conf.yaml (new file)
@ -0,0 +1,71 @@
exec: []

commons:
  - DATASET_NAME: rcv1
    DATASET_TARGET: CCAT
    METRICS:
      - acc
      - f1
    DATASET_N_PREVS: 9
  - DATASET_NAME: imdb
    METRICS:
      - acc
      - f1
    DATASET_N_PREVS: 9

confs:

  all_mul_vs_atc:
    COMP_ESTIMATORS:
      - our_mul_SLD
      - our_mul_SLD_nbvs
      - our_mul_SLD_bcts
      - our_mul_SLD_ts
      - our_mul_SLD_vs
      - our_mul_CC
      - ref
      - atc_mc
      - atc_ne

  all_bin_vs_atc:
    COMP_ESTIMATORS:
      - our_bin_SLD
      - our_bin_SLD_nbvs
      - our_bin_SLD_bcts
      - our_bin_SLD_ts
      - our_bin_SLD_vs
      - our_bin_CC
      - ref
      - atc_mc
      - atc_ne

  best_our_vs_atc:
    COMP_ESTIMATORS:
      - our_bin_SLD
      - our_bin_SLD_bcts
      - our_bin_SLD_vs
      - our_bin_CC
      - our_mul_SLD
      - our_mul_SLD_bcts
      - our_mul_SLD_vs
      - our_mul_CC
      - ref
      - atc_mc
      - atc_ne

  best_our_vs_all:
    COMP_ESTIMATORS:
      - our_bin_SLD
      - our_bin_SLD_bcts
      - our_bin_SLD_vs
      - our_bin_CC
      - our_mul_SLD
      - our_mul_SLD_bcts
      - our_mul_SLD_vs
      - our_mul_CC
      - ref
      - kfcv
      - atc_mc
      - atc_ne
      - doc_feat
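
The expansion of this file into concrete runs is easiest to see with plain dictionaries. The sketch below is an editor's illustration rather than repository code; it mirrors the dict-union performed by Environ.load_conf further down, and the resulting run name rcv1_CCAT_all_mul_vs_atc is an assumption based on that code.

# Hedged sketch: one "commons" entry merged with one "confs" entry (Python 3.9+ dict union).
common = {
    "DATASET_NAME": "rcv1",
    "DATASET_TARGET": "CCAT",
    "METRICS": ["acc", "f1"],
    "DATASET_N_PREVS": 9,
}
conf = {"COMP_ESTIMATORS": ["our_mul_SLD", "our_mul_CC", "ref", "atc_mc", "atc_ne"]}

name = common["DATASET_NAME"] + "_" + common["DATASET_TARGET"] + "_all_mul_vs_atc"
merged = common | conf  # right-hand side wins on duplicate keys
print(name)             # rcv1_CCAT_all_mul_vs_atc
print(sorted(merged.keys()))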

(Nine plot images deleted in this commit; their previous sizes ranged from 175 KiB to 266 KiB.)

out/rcv1_CCAT.md (1955 changed lines; diff not shown)

poetry.lock
@ -956,6 +956,65 @@ files = [
    {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
]

[[package]]
name = "pyyaml"
version = "6.0.1"
description = "YAML parser and emitter for Python"
optional = false
python-versions = ">=3.6"
files = [
    {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
    {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
    {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
    {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
    {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
    {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
    {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
    {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
    {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
    {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
    {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
    {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
    {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
    {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
    {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
    {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
    {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
    {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
    {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
    {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
    {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
    {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
    {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
    {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
    {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
    {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
    {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
    {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
    {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
    {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
    {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
    {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
    {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
    {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
    {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
    {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
    {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
]

[[package]]
name = "quapy"
version = "0.1.7"

@ -1164,4 +1223,4 @@ test = ["pytest", "pytest-cov"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "72e3afd9a24b88fc8a8f5f55e1c408f65090fce9015a442f6f41638191276b6f"
content-hash = "0ce0e6b058900e7db2939e7eb047a1f868c88de67def370c1c1fa0ba532df0b0"

pyproject.toml
@ -10,6 +10,7 @@ python = "^3.11"
quapy = "^0.1.7"
pandas = "^2.0.3"
jinja2 = "^3.1.2"
pyyaml = "^6.0.1"

[tool.poetry.scripts]
main = "quacc.main:main"
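
Since [tool.poetry.scripts] maps main to quacc.main:main, the comparison can be launched with "poetry run main" (standard Poetry behaviour for script entries); pyyaml is the new dependency backing the conf.yaml loading implemented in quacc/environ.py below.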

quacc/environ.py
@ -1,21 +1,33 @@
from pathlib import Path

import yaml

defalut_env = {
    "DATASET_NAME": "rcv1",
    "DATASET_TARGET": "CCAT",
    "METRICS": ["acc", "f1"],
    "COMP_ESTIMATORS": [
        "OUR_BIN_SLD",
        "OUR_MUL_SLD",
        "KFCV",
        "ATC_MC",
        "ATC_NE",
        "DOC_FEAT",
        # "RCA",
        # "RCA_STAR",
        "our_bin_SLD",
        "our_bin_SLD_nbvs",
        "our_bin_SLD_bcts",
        "our_bin_SLD_ts",
        "our_bin_SLD_vs",
        "our_bin_CC",
        "our_mul_SLD",
        "our_mul_SLD_nbvs",
        "our_mul_SLD_bcts",
        "our_mul_SLD_ts",
        "our_mul_SLD_vs",
        "our_mul_CC",
        "ref",
        "kfcv",
        "atc_mc",
        "atc_ne",
        "doc_feat",
        "rca",
        "rca_star",
    ],
    "DATASET_N_PREVS": 9,
    "OUT_DIR": Path("out"),
    "PLOT_OUT_DIR": Path("out/plot"),
    "OUT_DIR_NAME": "output",
    "PLOT_DIR_NAME": "plot",
    "PROTOCOL_N_PREVS": 21,
    "PROTOCOL_REPEATS": 100,
    "SAMPLE_SIZE": 1000,

@ -24,8 +36,37 @@ defalut_env = {

class Environ:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
        self.exec = []
        self.confs = {}
        self.__setdict(kwargs)

    def __setdict(self, d):
        for k, v in d.items():
            self.__setattr__(k, v)

    def load_conf(self):
        with open("conf.yaml", "r") as f:
            confs = yaml.safe_load(f)

        for common in confs["commons"]:
            name = common["DATASET_NAME"]
            if "DATASET_TARGET" in common:
                name += "_" + common["DATASET_TARGET"]
            for k, d in confs["confs"].items():
                _k = f"{name}_{k}"
                self.confs[_k] = common | d
                self.exec.append(_k)

        if "exec" in confs:
            if len(confs["exec"]) > 0:
                self.exec = confs["exec"]

    def __iter__(self):
        self.load_conf()
        for _conf in self.exec:
            if _conf in self.confs:
                self.__setdict(self.confs[_conf])
            yield _conf


env = Environ(**defalut_env)
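
A usage sketch (editor's illustration, assuming the repository is installed and a conf.yaml like the one above sits in the working directory): iterating env yields one configuration name per run and rebinds env's attributes in place before each yield, so loop bodies can read plain attributes.

# Hedged sketch of driving Environ; mirrors the loop in the main module further down.
from quacc.environ import env

for conf in env:
    print(conf, env.DATASET_NAME, len(env.COMP_ESTIMATORS))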

@ -1,13 +1,15 @@
import quapy as qp


def from_name(err_name):
    if err_name == 'f1e':
    if err_name == "f1e":
        return f1e
    elif err_name == 'f1':
    elif err_name == "f1":
        return f1
    else:
        return qp.error.from_name(err_name)


# def f1(prev):
#     # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
#     if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:

@ -18,18 +20,21 @@ def from_name(err_name):
#         return float('NaN')
#     else:
#         recall = prev[0] / (prev[0] + prev[1])
#         precision = prev[0] / (prev[0] + prev[2])
#         return 2 * (precision * recall) / (precision + recall)


def f1(prev):
    den = (2*prev[3]) + prev[1] + prev[2]
    den = (2 * prev[3]) + prev[1] + prev[2]
    if den == 0:
        return 0.0
    else:
        return (2*prev[3])/den
        return (2 * prev[3]) / den


def f1e(prev):
    return 1 - f1(prev)


def acc(prev):
    return (prev[1] + prev[2]) / sum(prev)
    return (prev[0] + prev[3]) / sum(prev)
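
A quick numeric check (editor's sketch): the formulas above are consistent with a flattened two-class confusion-prevalence vector ordered [TN, FP, FN, TP]; that ordering is an assumption inferred from the code, not stated in the diff.

# Hedged check of f1 and acc on toy counts (out of 100 instances).
prev = [50, 10, 5, 35]  # assumed order: TN, FP, FN, TP
f1 = (2 * prev[3]) / ((2 * prev[3]) + prev[1] + prev[2])  # 70 / 85
acc = (prev[0] + prev[3]) / sum(prev)                     # 85 / 100
print(round(f1, 4), acc)                                  # 0.8235 0.85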

@ -1,9 +1,9 @@
from abc import abstractmethod
import math
from abc import abstractmethod

import numpy as np
from quapy.data import LabelledCollection
from quapy.method.aggregative import SLD
from quapy.method.aggregative import CC, SLD
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

@ -15,7 +15,7 @@ class AccuracyEstimator:
    def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
        if not pred_proba:
            pred_proba = self.c_model.predict_proba(base.X)
        return ExtendedCollection.extend_collection(base, pred_proba)
        return ExtendedCollection.extend_collection(base, pred_proba), pred_proba

    @abstractmethod
    def fit(self, train: LabelledCollection | ExtendedCollection):

@ -27,9 +27,15 @@ class AccuracyEstimator:


class MulticlassAccuracyEstimator(AccuracyEstimator):
    def __init__(self, c_model: BaseEstimator):
    def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
        self.c_model = c_model
        self.q_model = SLD(LogisticRegression())
        if q_model == "SLD":
            available_args = ["recalib"]
            sld_args = {k: v for k, v in kwargs.items() if k in available_args}
            self.q_model = SLD(LogisticRegression(), **sld_args)
        elif q_model == "CC":
            self.q_model = CC(LogisticRegression())

        self.e_train = None

    def fit(self, train: LabelledCollection | ExtendedCollection):

@ -67,10 +73,17 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):


class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
    def __init__(self, c_model: BaseEstimator):
    def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
        self.c_model = c_model
        self.q_model_0 = SLD(LogisticRegression())
        self.q_model_1 = SLD(LogisticRegression())
        if q_model == "SLD":
            available_args = ["recalib"]
            sld_args = {k: v for k, v in kwargs.items() if k in available_args}
            self.q_model_0 = SLD(LogisticRegression(), **sld_args)
            self.q_model_1 = SLD(LogisticRegression(), **sld_args)
        elif q_model == "CC":
            self.q_model_0 = CC(LogisticRegression())
            self.q_model_1 = CC(LogisticRegression())

        self.e_train = None

    def fit(self, train: LabelledCollection | ExtendedCollection):

@ -83,7 +96,7 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
            self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
        elif isinstance(train, ExtendedCollection):
            self.e_train = train

        self.n_classes = self.e_train.n_classes
        [e_train_0, e_train_1] = self.e_train.split_by_pred()
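
A construction sketch (editor's illustration): the q_model/recalib plumbing above ultimately builds quapy quantifiers like the ones below; only construction is shown, fitting and the surrounding estimator classes are omitted.

# Hedged sketch: the quantifiers the new constructors instantiate.
from quapy.method.aggregative import CC, SLD
from sklearn.linear_model import LogisticRegression

q_sld_plain = SLD(LogisticRegression())                 # default, no recalibration
q_sld_bcts = SLD(LogisticRegression(), recalib="bcts")  # one of nbvs / bcts / ts / vs
q_cc = CC(LogisticRegression())                         # plain Classify & Count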

@ -34,14 +34,14 @@ def kfcv(
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    report = EvaluationReport(prefix="kfcv")
    report = EvaluationReport(name="kfcv")
    for test in protocol():
        test_preds = c_model_predict(test.X)
        meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
        meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds))
        report.append_row(
            test.prevalence(),
            acc_score=(1.0 - acc_score),
            acc_score=acc_score,
            f1_score=f1_score,
            acc=meta_acc,
            f1=meta_f1,

@ -57,13 +57,13 @@ def reference(
):
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
    c_model_predict = getattr(c_model, "predict_proba")
    report = EvaluationReport(prefix="ref")
    report = EvaluationReport(name="ref")
    for test in protocol():
        test_probs = c_model_predict(test.X)
        test_preds = np.argmax(test_probs, axis=-1)
        report.append_row(
            test.prevalence(),
            acc_score=(1 - metrics.accuracy_score(test.y, test_preds)),
            acc_score=metrics.accuracy_score(test.y, test_preds),
            f1_score=metrics.f1_score(test.y, test_preds),
        )

@ -89,7 +89,7 @@ def atc_mc(
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    report = EvaluationReport(prefix="atc_mc")
    report = EvaluationReport(name="atc_mc")
    for test in protocol():
        ## Load OOD test data probs
        test_probs = c_model_predict(test.X)

@ -102,7 +102,7 @@ def atc_mc(
        report.append_row(
            test.prevalence(),
            acc=meta_acc,
            acc_score=1.0 - atc_accuracy,
            acc_score=atc_accuracy,
            f1_score=f1_score,
            f1=meta_f1,
        )

@ -129,7 +129,7 @@ def atc_ne(
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    report = EvaluationReport(prefix="atc_ne")
    report = EvaluationReport(name="atc_ne")
    for test in protocol():
        ## Load OOD test data probs
        test_probs = c_model_predict(test.X)

@ -142,7 +142,7 @@ def atc_ne(
        report.append_row(
            test.prevalence(),
            acc=meta_acc,
            acc_score=(1.0 - atc_accuracy),
            acc_score=atc_accuracy,
            f1_score=f1_score,
            f1=meta_f1,
        )

@ -182,14 +182,14 @@ def doc_feat(
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    report = EvaluationReport(prefix="doc_feat")
    report = EvaluationReport(name="doc_feat")
    for test in protocol():
        test_probs = c_model_predict(test.X)
        test_preds = np.argmax(test_probs, axis=-1)
        test_scores = np.max(test_probs, axis=-1)
        score = (v1acc + doc.get_doc(val_scores, test_scores)) / 100.0
        meta_acc = abs(score - metrics.accuracy_score(test.y, test_preds))
        report.append_row(test.prevalence(), acc=meta_acc, acc_score=(1.0 - score))
        report.append_row(test.prevalence(), acc=meta_acc, acc_score=score)

    return report

@ -206,17 +206,15 @@ def rca_score(
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    report = EvaluationReport(prefix="rca")
    report = EvaluationReport(name="rca")
    for test in protocol():
        try:
            test_pred = c_model_predict(test.X)
            c_model2 = rca.clone_fit(c_model, test.X, test_pred)
            c_model2_predict = getattr(c_model2, predict_method)
            val_pred2 = c_model2_predict(validation.X)
            rca_score = rca.get_score(val_pred1, val_pred2, validation.y)
            meta_score = abs(
                rca_score - (1 - metrics.accuracy_score(test.y, test_pred))
            )
            rca_score = 1.0 - rca.get_score(val_pred1, val_pred2, validation.y)
            meta_score = abs(rca_score - metrics.accuracy_score(test.y, test_pred))
            report.append_row(test.prevalence(), acc=meta_score, acc_score=rca_score)
        except ValueError:
            report.append_row(

@ -244,17 +242,15 @@ def rca_star_score(
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    report = EvaluationReport(prefix="rca_star")
    report = EvaluationReport(name="rca_star")
    for test in protocol():
        try:
            test_pred = c_model_predict(test.X)
            c_model2 = rca.clone_fit(c_model, test.X, test_pred)
            c_model2_predict = getattr(c_model2, predict_method)
            val2_pred2 = c_model2_predict(validation2.X)
            rca_star_score = rca.get_score(val2_pred1, val2_pred2, validation2.y)
            meta_score = abs(
                rca_star_score - (1 - metrics.accuracy_score(test.y, test_pred))
            )
            rca_star_score = 1.0 - rca.get_score(val2_pred1, val2_pred2, validation2.y)
            meta_score = abs(rca_star_score - metrics.accuracy_score(test.y, test_pred))
            report.append_row(
                test.prevalence(), acc=meta_score, acc_score=rca_star_score
            )
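
Common to every baseline above: after this change the report stores the estimated score directly (acc_score, f1_score) plus its absolute gap to the value measured on the test sample (acc, f1). A self-contained toy version of that gap, using only sklearn (editor's sketch, not repository code):

# Hedged sketch of the meta-error the baselines record.
import numpy as np
from sklearn import metrics

y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])
estimated_acc = 0.75                               # whatever the method predicted
true_acc = metrics.accuracy_score(y_true, y_pred)  # 0.8
print(round(abs(estimated_acc - true_acc), 3))     # 0.05 -> the "acc" column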

quacc/evaluation/comp.py
@ -1,5 +1,6 @@
import multiprocessing
import time
import traceback
from typing import List

import pandas as pd

@ -19,14 +20,25 @@ pd.set_option("display.float_format", "{:.4f}".format)

class CompEstimator:
    __dict = {
        "OUR_BIN_SLD": method.evaluate_bin_sld,
        "OUR_MUL_SLD": method.evaluate_mul_sld,
        "KFCV": baseline.kfcv,
        "ATC_MC": baseline.atc_mc,
        "ATC_NE": baseline.atc_ne,
        "DOC_FEAT": baseline.doc_feat,
        "RCA": baseline.rca_score,
        "RCA_STAR": baseline.rca_star_score,
        "our_bin_SLD": method.evaluate_bin_sld,
        "our_mul_SLD": method.evaluate_mul_sld,
        "our_bin_SLD_nbvs": method.evaluate_bin_sld_nbvs,
        "our_mul_SLD_nbvs": method.evaluate_mul_sld_nbvs,
        "our_bin_SLD_bcts": method.evaluate_bin_sld_bcts,
        "our_mul_SLD_bcts": method.evaluate_mul_sld_bcts,
        "our_bin_SLD_ts": method.evaluate_bin_sld_ts,
        "our_mul_SLD_ts": method.evaluate_mul_sld_ts,
        "our_bin_SLD_vs": method.evaluate_bin_sld_vs,
        "our_mul_SLD_vs": method.evaluate_mul_sld_vs,
        "our_bin_CC": method.evaluate_bin_cc,
        "our_mul_CC": method.evaluate_mul_cc,
        "ref": baseline.reference,
        "kfcv": baseline.kfcv,
        "atc_mc": baseline.atc_mc,
        "atc_ne": baseline.atc_ne,
        "doc_feat": baseline.doc_feat,
        "rca": baseline.rca_score,
        "rca_star": baseline.rca_star_score,
    }

    def __class_getitem__(cls, e: str | List[str]):

@ -55,7 +67,17 @@ def fit_and_estimate(_estimate, train, validation, test):
        test, n_prevalences=env.PROTOCOL_N_PREVS, repeats=env.PROTOCOL_REPEATS
    )
    start = time.time()
    result = _estimate(model, validation, protocol)
    try:
        result = _estimate(model, validation, protocol)
    except Exception as e:
        print(f"Method {_estimate.__name__} failed.")
        traceback(e)
        return {
            "name": _estimate.__name__,
            "result": None,
            "time": 0,
        }

    end = time.time()
    print(f"{_estimate.__name__}: {end-start:.2f}s")

@ -69,22 +91,33 @@ def fit_and_estimate(_estimate, train, validation, test):
def evaluate_comparison(
    dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
) -> EvaluationReport:
    with multiprocessing.Pool(8) as pool:
    with multiprocessing.Pool(len(estimators)) as pool:
        dr = DatasetReport(dataset.name)
        for d in dataset():
            print(f"train prev.: {d.train_prev}")
            start = time.time()
            tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]]
            results = [pool.apply_async(fit_and_estimate, t) for t in tasks]
            results = list(map(lambda r: r.get(), results))

            results_got = []
            for _r in results:
                try:
                    r = _r.get()
                    if r["result"] is not None:
                        results_got.append(r)
                except Exception as e:
                    print(e)

            er = EvaluationReport.combine_reports(
                *list(map(lambda r: r["result"], results)), name=dataset.name
                *[r["result"] for r in results_got],
                name=dataset.name,
                train_prev=d.train_prev,
                valid_prev=d.validation_prev,
            )
            times = {r["name"]: r["time"] for r in results}
            times = {r["name"]: r["time"] for r in results_got}
            end = time.time()
            times["tot"] = end - start
            er.times = times
            er.train_prevs = d.prevs
            dr.add(er)
            print()
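
The dispatch pattern introduced above (one async task per estimator, failures yielding a result of None and being filtered out) can be reproduced in isolation; the sketch below is an editor's illustration with a dummy worker, not repository code.

# Hedged sketch of apply_async plus filtering of failed tasks.
import multiprocessing

def work(x):
    return {"name": f"task{x}", "result": None if x == 2 else x * x, "time": 0.0}

if __name__ == "__main__":
    with multiprocessing.Pool(3) as pool:
        async_results = [pool.apply_async(work, (i,)) for i in range(3)]
        got = []
        for r in async_results:
            try:
                out = r.get()
                if out["result"] is not None:
                    got.append(out)
            except Exception as e:
                print(e)
    print([g["name"] for g in got])  # ['task0', 'task1']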

@ -1,3 +1,5 @@
import numpy as np
import sklearn.metrics as metrics
from quapy.data import LabelledCollection
from quapy.protocol import (
    AbstractStochasticSeededProtocol,

@ -22,15 +24,17 @@ def estimate(
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    base_prevs, true_prevs, estim_prevs = [], [], []
    base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
    for sample in protocol():
        e_sample = estimator.extend(sample)
        e_sample, pred_proba = estimator.extend(sample)
        estim_prev = estimator.estimate(e_sample.X, ext=True)
        base_prevs.append(sample.prevalence())
        true_prevs.append(e_sample.prevalence())
        estim_prevs.append(estim_prev)
        pred_probas.append(pred_proba)
        labels.append(sample.y)

    return base_prevs, true_prevs, estim_prevs
    return base_prevs, true_prevs, estim_prevs, pred_probas, labels


def evaluation_report(

@ -38,16 +42,21 @@ def evaluation_report(
    protocol: AbstractStochasticSeededProtocol,
    method: str,
) -> EvaluationReport:
    base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
    report = EvaluationReport(prefix=method)
    base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
        estimator, protocol
    )
    report = EvaluationReport(name=method)

    for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
    for base_prev, true_prev, estim_prev, pred_proba, label in zip(
        base_prevs, true_prevs, estim_prevs, pred_probas, labels
    ):
        pred = np.argmax(pred_proba, axis=-1)
        acc_score = error.acc(estim_prev)
        f1_score = error.f1(estim_prev)
        report.append_row(
            base_prev,
            acc_score=1.0 - acc_score,
            acc=abs(error.acc(true_prev) - acc_score),
            acc_score=acc_score,
            acc=abs(metrics.accuracy_score(label, pred) - acc_score),
            f1_score=f1_score,
            f1=abs(error.f1(true_prev) - f1_score),
        )

@ -60,13 +69,18 @@ def evaluate(
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    method: str,
    q_model: str,
    **kwargs,
):
    estimator: AccuracyEstimator = {
        "bin": BinaryQuantifierAccuracyEstimator,
        "mul": MulticlassAccuracyEstimator,
    }[method](c_model)
    }[method](c_model, q_model=q_model, **kwargs)
    estimator.fit(validation)
    return evaluation_report(estimator, protocol, method)
    _method = f"{method}_{q_model}"
    for k, v in kwargs.items():
        _method += f"_{v}"
    return evaluation_report(estimator, protocol, _method)


def evaluate_bin_sld(

@ -74,7 +88,7 @@ def evaluate_bin_sld(
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin")
    return evaluate(c_model, validation, protocol, "bin", "SLD")


def evaluate_mul_sld(

@ -82,4 +96,84 @@ def evaluate_mul_sld(
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul")
    return evaluate(c_model, validation, protocol, "mul", "SLD")


def evaluate_bin_sld_nbvs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="nbvs")


def evaluate_mul_sld_nbvs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="nbvs")


def evaluate_bin_sld_bcts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="bcts")


def evaluate_mul_sld_bcts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="bcts")


def evaluate_bin_sld_ts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="ts")


def evaluate_mul_sld_ts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="ts")


def evaluate_bin_sld_vs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="vs")


def evaluate_mul_sld_vs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="vs")


def evaluate_bin_cc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "CC")


def evaluate_mul_cc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "CC")
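
Design note: the evaluate_{bin,mul}_* wrappers above all forward fixed arguments to evaluate(). An equivalent family could be generated with functools.partial; the sketch below is an editor's illustration with a stand-in evaluate, not the repository's implementation.

# Hedged sketch: generating the wrapper family instead of writing it out.
from functools import partial

def evaluate(c_model, validation, protocol, method, q_model, **kwargs):
    return (method, q_model, kwargs)  # stand-in for the real evaluation logic

evaluate_bin_sld_bcts = partial(evaluate, method="bin", q_model="SLD", recalib="bcts")
print(evaluate_bin_sld_bcts(None, None, None))  # ('bin', 'SLD', {'recalib': 'bcts'})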

@ -1,22 +1,24 @@
import statistics as stats
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd

from quacc import plot
from quacc.environ import env
from quacc.utils import fmt_line_md


class EvaluationReport:
    def __init__(self, prefix=None):
    def __init__(self, name=None):
        self._prevs = []
        self._dict = {}
        self._g_prevs = None
        self._g_dict = None
        self.name = prefix if prefix is not None else "default"
        self.name = name if name is not None else "default"
        self.times = {}
        self.train_prevs = {}
        self.train_prev = None
        self.valid_prev = None
        self.target = "default"

    def append_row(self, base: np.ndarray | Tuple, **row):

@ -34,23 +36,40 @@ class EvaluationReport:
    def columns(self):
        return self._dict.keys()

    def groupby_prevs(self, metric: str = None):
    def group_by_prevs(self, metric: str = None):
        if self._g_dict is None:
            self._g_prevs = []
            self._g_dict = {k: [] for k in self._dict.keys()}

            last_end = 0
            for ind, bp in enumerate(self._prevs):
                if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
                    continue
            for col, vals in self._dict.items():
                col_grouped = {}
                for bp, v in zip(self._prevs, vals):
                    if bp not in col_grouped:
                        col_grouped[bp] = []
                    col_grouped[bp].append(v)

                self._g_prevs.append(bp)
                for col in self._dict.keys():
                    self._g_dict[col].append(
                        stats.mean(self._dict[col][last_end : ind + 1])
                    )
                self._g_dict[col] = [
                    vs
                    for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])
                ]

                last_end = ind + 1
            self._g_prevs = sorted(
                [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()],
                key=lambda bp: bp[1],
            )

            # last_end = 0
            # for ind, bp in enumerate(self._prevs):
            #     if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
            #         continue

            #     self._g_prevs.append(bp)
            #     for col in self._dict.keys():
            #         self._g_dict[col].append(
            #             stats.mean(self._dict[col][last_end : ind + 1])
            #         )

            #     last_end = ind + 1

        filtered_g_dict = self._g_dict
        if metric is not None:

@ -60,30 +79,83 @@ class EvaluationReport:

        return self._g_prevs, filtered_g_dict

    def avg_by_prevs(self, metric: str = None):
        g_prevs, g_dict = self.group_by_prevs(metric=metric)

        a_dict = {}
        for col, vals in g_dict.items():
            a_dict[col] = [np.mean(vs) for vs in vals]

        return g_prevs, a_dict

    def avg_all(self, metric: str = None):
        f_dict = self._dict
        if metric is not None:
            f_dict = {c1: ls for ((c0, c1), ls) in self._dict.items() if c0 == metric}

        a_dict = {}
        for col, vals in f_dict.items():
            a_dict[col] = [np.mean(vals)]

        return a_dict

    def get_dataframe(self, metric="acc"):
        g_prevs, g_dict = self.groupby_prevs(metric=metric)
        g_prevs, g_dict = self.avg_by_prevs(metric=metric)
        a_dict = self.avg_all(metric=metric)
        for col in g_dict.keys():
            g_dict[col].extend(a_dict[col])
        return pd.DataFrame(
            g_dict,
            index=g_prevs,
            index=g_prevs + ["tot"],
            columns=g_dict.keys(),
        )

    def get_plot(self, mode="delta", metric="acc"):
        g_prevs, g_dict = self.groupby_prevs(metric=metric)
        t_prev = int(round(self.train_prevs["train"][0] * 100))
        title = f"{self.name}_{t_prev}_{metric}"
        plot.plot_delta(g_prevs, g_dict, metric, title)
    def get_plot(self, mode="delta", metric="acc") -> Path:
        if mode == "delta":
            g_prevs, g_dict = self.group_by_prevs(metric=metric)
            return plot.plot_delta(
                g_prevs,
                g_dict,
                metric=metric,
                name=self.name,
                train_prev=self.train_prev,
            )
        elif mode == "diagonal":
            _, g_dict = self.avg_by_prevs(metric=metric + "_score")
            f_dict = {k: v for k, v in g_dict.items() if k != "ref"}
            referece = g_dict["ref"]
            return plot.plot_diagonal(
                referece,
                f_dict,
                metric=metric,
                name=self.name,
                train_prev=self.train_prev,
            )
        elif mode == "shift":
            g_prevs, g_dict = self.avg_by_prevs(metric=metric)
            return plot.plot_shift(
                g_prevs,
                g_dict,
                metric=metric,
                name=self.name,
                train_prev=self.train_prev,
            )

    def to_md(self, *metrics):
        res = ""
        for k, v in self.train_prevs.items():
            res += fmt_line_md(f"{k}: {str(v)}")
        res += fmt_line_md(f"train: {str(self.train_prev)}")
        res += fmt_line_md(f"validation: {str(self.valid_prev)}")
        for k, v in self.times.items():
            res += fmt_line_md(f"{k}: {v:.3f}s")
        res += "\n"
        for m in metrics:
            res += self.get_dataframe(metric=m).to_html() + "\n\n"
            self.get_plot(metric=m)
            op_delta = self.get_plot(mode="delta", metric=m)
            res += f"![plot_delta]({str(op_delta.relative_to(env.OUT_DIR))})\n"
            op_diag = self.get_plot(mode="diagonal", metric=m)
            res += f"![plot_diagonal]({str(op_diag.relative_to(env.OUT_DIR))})\n"
            op_shift = self.get_plot(mode="shift", metric=m)
            res += f"![plot_shift]({str(op_shift.relative_to(env.OUT_DIR))})\n"

        return res

@ -91,8 +163,9 @@ class EvaluationReport:
        if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
            raise ValueError("other has not same base prevalences of self")

        if len(set(self._dict.keys()).intersection(set(other._dict.keys()))) > 0:
            raise ValueError("self and other have matching keys")
        inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys()))
        if len(inters_keys) > 0:
            raise ValueError(f"self and other have matching keys {str(inters_keys)}.")

        report = EvaluationReport()
        report._prevs = self._prevs

@ -100,12 +173,14 @@ class EvaluationReport:
        return report

    @staticmethod
    def combine_reports(*args, name="default"):
    def combine_reports(*args, name="default", train_prev=None, valid_prev=None):
        er = args[0]
        for r in args[1:]:
            er = er.merge(r)

        er.name = name
        er.train_prev = train_prev
        er.valid_prev = valid_prev
        return er
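
The reworked group_by_prevs can be pictured with toy data: rows sharing the same base prevalence are bucketed, buckets are sorted by the positive-class prevalence, and avg_by_prevs then averages each bucket. The sketch below is an editor's illustration of that behaviour, not repository code.

# Hedged sketch of grouping rows by base prevalence and averaging per group.
import numpy as np

prevs = [(0.9, 0.1), (0.5, 0.5), (0.9, 0.1), (0.5, 0.5)]
vals = [0.80, 0.60, 0.84, 0.58]

buckets = {}
for bp, v in zip(prevs, vals):
    buckets.setdefault(bp, []).append(v)

g_prevs = sorted(buckets, key=lambda bp: bp[1])  # sort by positive-class prevalence
averaged = [round(float(np.mean(buckets[bp])), 4) for bp in g_prevs]
print(g_prevs, averaged)  # [(0.9, 0.1), (0.5, 0.5)] [0.82, 0.59]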

quacc/main.py
@ -1,16 +1,39 @@
import os
import shutil
from pathlib import Path

import quacc.evaluation.comp as comp
from quacc.dataset import Dataset
from quacc.environ import env


def create_out_dir(dir_name):
    dir_path = Path(env.OUT_DIR_NAME) / dir_name
    env.OUT_DIR = dir_path
    shutil.rmtree(dir_path, ignore_errors=True)
    os.mkdir(dir_path)
    plot_dir_path = dir_path / "plot"
    env.PLOT_OUT_DIR = plot_dir_path
    os.mkdir(plot_dir_path)


def estimate_comparison():
    dataset = Dataset(
        env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS
    )
    output_path = env.OUT_DIR / f"{dataset.name}.md"
    with open(output_path, "w") as f:
        dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
        f.write(dr.to_md("acc"))
    for conf in env:
        create_out_dir(conf)
        dataset = Dataset(
            env.DATASET_NAME,
            target=env.DATASET_TARGET,
            n_prevalences=env.DATASET_N_PREVS,
        )
        output_path = env.OUT_DIR / f"{dataset.name}.md"
        try:
            dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
            for m in env.METRICS:
                output_path = env.OUT_DIR / f"{conf}_{m}.md"
                with open(output_path, "w") as f:
                    f.write(dr.to_md(m))
        except Exception as e:
            print(f"Configuration {conf} failed. {e}")

    # print(df.to_latex(float_format="{:.4f}".format))
    # print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
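
The per-configuration output layout produced by create_out_dir can be sketched with pathlib alone (editor's illustration; the real function also rebinds env.OUT_DIR and env.PLOT_OUT_DIR, and env.OUT_DIR_NAME defaults to "output" in defalut_env above). The helper name make_out_dir below is hypothetical.

# Hedged sketch of the directory handling: wipe and recreate output/<conf>/plot.
import shutil
from pathlib import Path

def make_out_dir(dir_name, out_root="output"):
    dir_path = Path(out_root) / dir_name
    shutil.rmtree(dir_path, ignore_errors=True)  # start from a clean directory
    dir_path.mkdir(parents=True)
    (dir_path / "plot").mkdir()                  # plots go in a subdirectory
    return dir_path

print(make_out_dir("rcv1_CCAT_all_mul_vs_atc"))  # output/rcv1_CCAT_all_mul_vs_atc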

quacc/plot.py (191 changed lines)
@ -1,16 +1,191 @@
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from quacc.environ import env


def plot_delta(base_prevs, dict_vals, metric, title):
    fig, ax = plt.subplots()
def _get_markers(n: int):
    return [
        "o",
        "v",
        "x",
        "+",
        "s",
        "D",
        "p",
        "h",
        "*",
        "^",
    ][:n]

    base_prevs = [f for f, p in base_prevs]

def plot_delta(
    base_prevs,
    dict_vals,
    *,
    pos_class=1,
    metric="acc",
    name="default",
    train_prev=None,
    legend=True,
) -> Path:
    if train_prev is not None:
        t_prev_pos = int(round(train_prev[pos_class] * 100))
        title = f"delta_{name}_{t_prev_pos}_{metric}"
    else:
        title = f"delta_{name}_{metric}"

    fig, ax = plt.subplots()
    ax.set_aspect("auto")
    ax.grid()

    NUM_COLORS = len(dict_vals)
    cm = plt.get_cmap("tab10")
    if NUM_COLORS > 10:
        cm = plt.get_cmap("tab20")
    ax.set_prop_cycle(
        color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
    )

    base_prevs = [bp[pos_class] for bp in base_prevs]
    for method, deltas in dict_vals.items():
        avg = np.array([np.mean(d, axis=-1) for d in deltas])
        # std = np.array([np.std(d, axis=-1) for d in deltas])
        ax.plot(
            base_prevs,
            avg,
            label=method,
            linestyle="-",
            marker="o",
            markersize=3,
            zorder=2,
        )
        # ax.fill_between(base_prevs, avg - std, avg + std, alpha=0.25)

    ax.set(xlabel="test prevalence", ylabel=metric, title=title)

    if legend:
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    output_path = env.PLOT_OUT_DIR / f"{title}.png"
    fig.savefig(output_path, bbox_inches="tight")

    return output_path


def plot_diagonal(
    reference,
    dict_vals,
    *,
    pos_class=1,
    metric="acc",
    name="default",
    train_prev=None,
    legend=True,
):
    if train_prev is not None:
        t_prev_pos = int(round(train_prev[pos_class] * 100))
        title = f"diagonal_{name}_{t_prev_pos}_{metric}"
    else:
        title = f"diagonal_{name}_{metric}"

    fig, ax = plt.subplots()
    ax.set_aspect("auto")
    ax.grid()

    NUM_COLORS = len(dict_vals)
    cm = plt.get_cmap("tab10")
    ax.set_prop_cycle(
        marker=_get_markers(NUM_COLORS) * 2,
        color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)] * 2,
    )

    reference = np.array(reference)
    x_ticks = np.unique(reference)
    x_ticks.sort()

    for _, deltas in dict_vals.items():
        deltas = np.array(deltas)
        ax.plot(
            reference,
            deltas,
            linestyle="None",
            markersize=3,
            zorder=2,
        )

    for method, deltas in dict_vals.items():
        deltas = np.array(deltas)
        x_interp = x_ticks[[0, -1]]
        y_interp = np.interp(x_interp, reference, deltas)
        ax.plot(
            x_interp,
            y_interp,
            label=method,
            linestyle="-",
            markersize="0",
            zorder=1,
        )

    ax.set(xlabel="test prevalence", ylabel=metric, title=title)

    if legend:
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    output_path = env.PLOT_OUT_DIR / f"{title}.png"
    fig.savefig(output_path, bbox_inches="tight")
    return output_path


def plot_shift(
    base_prevs,
    dict_vals,
    *,
    pos_class=1,
    metric="acc",
    name="default",
    train_prev=None,
    legend=True,
) -> Path:
    if train_prev is None:
        raise AttributeError("train_prev cannot be None.")

    train_prev = train_prev[pos_class]
    t_prev_pos = int(round(train_prev * 100))
    title = f"shift_{name}_{t_prev_pos}_{metric}"

    fig, ax = plt.subplots()
    ax.set_aspect("auto")
    ax.grid()

    NUM_COLORS = len(dict_vals)
    cm = plt.get_cmap("tab10")
    if NUM_COLORS > 10:
        cm = plt.get_cmap("tab20")
    ax.set_prop_cycle(
        color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
    )

    base_prevs = np.around(
        [abs(bp[pos_class] - train_prev) for bp in base_prevs], decimals=2
    )
    for method, deltas in dict_vals.items():
        delta_bins = {}
        for bp, delta in zip(base_prevs, deltas):
            if bp not in delta_bins:
                delta_bins[bp] = []
            delta_bins[bp].append(delta)

        bp_unique, delta_avg = zip(
            *sorted(
                {k: np.mean(v) for k, v in delta_bins.items()}.items(),
                key=lambda db: db[0],
            )
        )

        ax.plot(
            bp_unique,
            delta_avg,
            label=method,
            linestyle="-",
            marker="o",

@ -19,8 +194,10 @@ def plot_delta(base_prevs, dict_vals, metric, title):
        )

    ax.set(xlabel="test prevalence", ylabel=metric, title=title)
    # ax.set_ylim(0, 1)
    # ax.set_xlim(0, 1)
    ax.legend()

    if legend:
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    output_path = env.PLOT_OUT_DIR / f"{title}.png"
    plt.savefig(output_path)
    fig.savefig(output_path, bbox_inches="tight")

    return output_path
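
The binning done by plot_shift can be reproduced in isolation: the x-axis becomes the absolute shift between test and training prevalence (rounded to two decimals) and values falling into the same bin are averaged. The sketch below is an editor's illustration, not repository code.

# Hedged sketch of the shift binning used for the "shift" plot.
import numpy as np

train_prev = 0.5
base_prevs = [(0.9, 0.1), (0.6, 0.4), (0.4, 0.6), (0.1, 0.9)]
deltas = [0.30, 0.10, 0.12, 0.28]

shifts = np.around([abs(bp[1] - train_prev) for bp in base_prevs], decimals=2)
bins = {}
for s, d in zip(shifts, deltas):
    bins.setdefault(float(s), []).append(d)

xs, ys = zip(*sorted((k, round(float(np.mean(v)), 4)) for k, v in bins.items()))
print(xs, ys)  # (0.1, 0.4) (0.11, 0.29)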