plots, avg table, conf added; method updated

This commit is contained in:
Lorenzo Volpi 2023-10-23 03:14:35 +02:00
parent 1055a6bdd4
commit 040578652e
25 changed files with 716 additions and 2081 deletions

2
.gitignore vendored
View File

@ -12,3 +12,5 @@ elsahar19_rca/__pycache__/*
*.coverage
.coverage
scp_sync.py
out/*
output/*

View File

@ -41,12 +41,12 @@
</head>
<body class="vscode-body vscode-light">
<ul class="contains-task-list">
<li class="task-list-item enabled"><input class="task-list-item-checkbox"type="checkbox"> aggiungere media tabelle</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox"type="checkbox"> plot; 3 tipi (appunti + email + garg)</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked=""type="checkbox"> aggiungere media tabelle</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked=""type="checkbox"> plot; 3 tipi (appunti + email + garg)</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox"type="checkbox"> sistemare kfcv baseline</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox"type="checkbox"> aggiungere metodo con CC oltre SLD</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked=""type="checkbox"> aggiungere metodo con CC oltre SLD</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked=""type="checkbox"> prendere classe più popolosa di rcv1, togliere negativi fino a raggiungere 50/50; poi fare subsampling con 9 training prvalences (da 0.1-0.9 a 0.9-0.1)</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox"type="checkbox"> variare parametro recalibration in SLD</li>
<li class="task-list-item enabled"><input class="task-list-item-checkbox" checked=""type="checkbox"> variare parametro recalibration in SLD</li>
</ul>

View File

@ -1,6 +1,6 @@
- [ ] aggiungere media tabelle
- [ ] plot; 3 tipi (appunti + email + garg)
- [x] aggiungere media tabelle
- [x] plot; 3 tipi (appunti + email + garg)
- [ ] sistemare kfcv baseline
- [ ] aggiungere metodo con CC oltre SLD
- [x] aggiungere metodo con CC oltre SLD
- [x] prendere classe più popolosa di rcv1, togliere negativi fino a raggiungere 50/50; poi fare subsampling con 9 training prvalences (da 0.1-0.9 a 0.9-0.1)
- [ ] variare parametro recalibration in SLD
- [x] variare parametro recalibration in SLD

71
conf.yaml Normal file
View File

@ -0,0 +1,71 @@
exec: []
commons:
- DATASET_NAME: rcv1
DATASET_TARGET: CCAT
METRICS:
- acc
- f1
DATASET_N_PREVS: 9
- DATASET_NAME: imdb
METRICS:
- acc
- f1
DATASET_N_PREVS: 9
confs:
all_mul_vs_atc:
COMP_ESTIMATORS:
- our_mul_SLD
- our_mul_SLD_nbvs
- our_mul_SLD_bcts
- our_mul_SLD_ts
- our_mul_SLD_vs
- our_mul_CC
- ref
- atc_mc
- atc_ne
all_bin_vs_atc:
COMP_ESTIMATORS:
- our_bin_SLD
- our_bin_SLD_nbvs
- our_bin_SLD_bcts
- our_bin_SLD_ts
- our_bin_SLD_vs
- our_bin_CC
- ref
- atc_mc
- atc_ne
best_our_vs_atc:
COMP_ESTIMATORS:
- our_bin_SLD
- our_bin_SLD_bcts
- our_bin_SLD_vs
- our_bin_CC
- our_mul_SLD
- our_mul_SLD_bcts
- our_mul_SLD_vs
- our_mul_CC
- ref
- atc_mc
- atc_ne
best_our_vs_all:
COMP_ESTIMATORS:
- our_bin_SLD
- our_bin_SLD_bcts
- our_bin_SLD_vs
- our_bin_CC
- our_mul_SLD
- our_mul_SLD_bcts
- our_mul_SLD_vs
- our_mul_CC
- ref
- kfcv
- atc_mc
- atc_ne
- doc_feat

Binary file not shown.

Before

Width:  |  Height:  |  Size: 188 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 198 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 225 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 244 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 266 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 231 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 200 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 192 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 175 KiB

File diff suppressed because it is too large Load Diff

61
poetry.lock generated
View File

@ -956,6 +956,65 @@ files = [
{file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
]
[[package]]
name = "pyyaml"
version = "6.0.1"
description = "YAML parser and emitter for Python"
optional = false
python-versions = ">=3.6"
files = [
{file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
{file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
{file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
{file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
{file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
{file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
{file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
{file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
{file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
{file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
{file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
{file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
{file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
]
[[package]]
name = "quapy"
version = "0.1.7"
@ -1164,4 +1223,4 @@ test = ["pytest", "pytest-cov"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "72e3afd9a24b88fc8a8f5f55e1c408f65090fce9015a442f6f41638191276b6f"
content-hash = "0ce0e6b058900e7db2939e7eb047a1f868c88de67def370c1c1fa0ba532df0b0"

View File

@ -10,6 +10,7 @@ python = "^3.11"
quapy = "^0.1.7"
pandas = "^2.0.3"
jinja2 = "^3.1.2"
pyyaml = "^6.0.1"
[tool.poetry.scripts]
main = "quacc.main:main"

View File

@ -1,21 +1,33 @@
from pathlib import Path
import yaml
defalut_env = {
"DATASET_NAME": "rcv1",
"DATASET_TARGET": "CCAT",
"METRICS": ["acc", "f1"],
"COMP_ESTIMATORS": [
"OUR_BIN_SLD",
"OUR_MUL_SLD",
"KFCV",
"ATC_MC",
"ATC_NE",
"DOC_FEAT",
# "RCA",
# "RCA_STAR",
"our_bin_SLD",
"our_bin_SLD_nbvs",
"our_bin_SLD_bcts",
"our_bin_SLD_ts",
"our_bin_SLD_vs",
"our_bin_CC",
"our_mul_SLD",
"our_mul_SLD_nbvs",
"our_mul_SLD_bcts",
"our_mul_SLD_ts",
"our_mul_SLD_vs",
"our_mul_CC",
"ref",
"kfcv",
"atc_mc",
"atc_ne",
"doc_feat",
"rca",
"rca_star",
],
"DATASET_N_PREVS": 9,
"OUT_DIR": Path("out"),
"PLOT_OUT_DIR": Path("out/plot"),
"OUT_DIR_NAME": "output",
"PLOT_DIR_NAME": "plot",
"PROTOCOL_N_PREVS": 21,
"PROTOCOL_REPEATS": 100,
"SAMPLE_SIZE": 1000,
@ -24,8 +36,37 @@ defalut_env = {
class Environ:
def __init__(self, **kwargs):
for k, v in kwargs.items():
self.exec = []
self.confs = {}
self.__setdict(kwargs)
def __setdict(self, d):
for k, v in d.items():
self.__setattr__(k, v)
def load_conf(self):
with open("conf.yaml", "r") as f:
confs = yaml.safe_load(f)
for common in confs["commons"]:
name = common["DATASET_NAME"]
if "DATASET_TARGET" in common:
name += "_" + common["DATASET_TARGET"]
for k, d in confs["confs"].items():
_k = f"{name}_{k}"
self.confs[_k] = common | d
self.exec.append(_k)
if "exec" in confs:
if len(confs["exec"]) > 0:
self.exec = confs["exec"]
def __iter__(self):
self.load_conf()
for _conf in self.exec:
if _conf in self.confs:
self.__setdict(self.confs[_conf])
yield _conf
env = Environ(**defalut_env)

View File

@ -1,13 +1,15 @@
import quapy as qp
def from_name(err_name):
if err_name == 'f1e':
if err_name == "f1e":
return f1e
elif err_name == 'f1':
elif err_name == "f1":
return f1
else:
return qp.error.from_name(err_name)
# def f1(prev):
# # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
# if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:
@ -21,6 +23,7 @@ def from_name(err_name):
# precision = prev[0] / (prev[0] + prev[2])
# return 2 * (precision * recall) / (precision + recall)
def f1(prev):
den = (2 * prev[3]) + prev[1] + prev[2]
if den == 0:
@ -28,8 +31,10 @@ def f1(prev):
else:
return (2 * prev[3]) / den
def f1e(prev):
return 1 - f1(prev)
def acc(prev):
return (prev[1] + prev[2]) / sum(prev)
return (prev[0] + prev[3]) / sum(prev)

View File

@ -1,9 +1,9 @@
from abc import abstractmethod
import math
from abc import abstractmethod
import numpy as np
from quapy.data import LabelledCollection
from quapy.method.aggregative import SLD
from quapy.method.aggregative import CC, SLD
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
@ -15,7 +15,7 @@ class AccuracyEstimator:
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
if not pred_proba:
pred_proba = self.c_model.predict_proba(base.X)
return ExtendedCollection.extend_collection(base, pred_proba)
return ExtendedCollection.extend_collection(base, pred_proba), pred_proba
@abstractmethod
def fit(self, train: LabelledCollection | ExtendedCollection):
@ -27,9 +27,15 @@ class AccuracyEstimator:
class MulticlassAccuracyEstimator(AccuracyEstimator):
def __init__(self, c_model: BaseEstimator):
def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
self.c_model = c_model
self.q_model = SLD(LogisticRegression())
if q_model == "SLD":
available_args = ["recalib"]
sld_args = {k: v for k, v in kwargs.items() if k in available_args}
self.q_model = SLD(LogisticRegression(), **sld_args)
elif q_model == "CC":
self.q_model = CC(LogisticRegression())
self.e_train = None
def fit(self, train: LabelledCollection | ExtendedCollection):
@ -67,10 +73,17 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
def __init__(self, c_model: BaseEstimator):
def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
self.c_model = c_model
self.q_model_0 = SLD(LogisticRegression())
self.q_model_1 = SLD(LogisticRegression())
if q_model == "SLD":
available_args = ["recalib"]
sld_args = {k: v for k, v in kwargs.items() if k in available_args}
self.q_model_0 = SLD(LogisticRegression(), **sld_args)
self.q_model_1 = SLD(LogisticRegression(), **sld_args)
elif q_model == "CC":
self.q_model_0 = CC(LogisticRegression())
self.q_model_1 = CC(LogisticRegression())
self.e_train = None
def fit(self, train: LabelledCollection | ExtendedCollection):

View File

@ -34,14 +34,14 @@ def kfcv(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
report = EvaluationReport(prefix="kfcv")
report = EvaluationReport(name="kfcv")
for test in protocol():
test_preds = c_model_predict(test.X)
meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds))
report.append_row(
test.prevalence(),
acc_score=(1.0 - acc_score),
acc_score=acc_score,
f1_score=f1_score,
acc=meta_acc,
f1=meta_f1,
@ -57,13 +57,13 @@ def reference(
):
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
c_model_predict = getattr(c_model, "predict_proba")
report = EvaluationReport(prefix="ref")
report = EvaluationReport(name="ref")
for test in protocol():
test_probs = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1)
report.append_row(
test.prevalence(),
acc_score=(1 - metrics.accuracy_score(test.y, test_preds)),
acc_score=metrics.accuracy_score(test.y, test_preds),
f1_score=metrics.f1_score(test.y, test_preds),
)
@ -89,7 +89,7 @@ def atc_mc(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
report = EvaluationReport(prefix="atc_mc")
report = EvaluationReport(name="atc_mc")
for test in protocol():
## Load OOD test data probs
test_probs = c_model_predict(test.X)
@ -102,7 +102,7 @@ def atc_mc(
report.append_row(
test.prevalence(),
acc=meta_acc,
acc_score=1.0 - atc_accuracy,
acc_score=atc_accuracy,
f1_score=f1_score,
f1=meta_f1,
)
@ -129,7 +129,7 @@ def atc_ne(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
report = EvaluationReport(prefix="atc_ne")
report = EvaluationReport(name="atc_ne")
for test in protocol():
## Load OOD test data probs
test_probs = c_model_predict(test.X)
@ -142,7 +142,7 @@ def atc_ne(
report.append_row(
test.prevalence(),
acc=meta_acc,
acc_score=(1.0 - atc_accuracy),
acc_score=atc_accuracy,
f1_score=f1_score,
f1=meta_f1,
)
@ -182,14 +182,14 @@ def doc_feat(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
report = EvaluationReport(prefix="doc_feat")
report = EvaluationReport(name="doc_feat")
for test in protocol():
test_probs = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1)
test_scores = np.max(test_probs, axis=-1)
score = (v1acc + doc.get_doc(val_scores, test_scores)) / 100.0
meta_acc = abs(score - metrics.accuracy_score(test.y, test_preds))
report.append_row(test.prevalence(), acc=meta_acc, acc_score=(1.0 - score))
report.append_row(test.prevalence(), acc=meta_acc, acc_score=score)
return report
@ -206,17 +206,15 @@ def rca_score(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
report = EvaluationReport(prefix="rca")
report = EvaluationReport(name="rca")
for test in protocol():
try:
test_pred = c_model_predict(test.X)
c_model2 = rca.clone_fit(c_model, test.X, test_pred)
c_model2_predict = getattr(c_model2, predict_method)
val_pred2 = c_model2_predict(validation.X)
rca_score = rca.get_score(val_pred1, val_pred2, validation.y)
meta_score = abs(
rca_score - (1 - metrics.accuracy_score(test.y, test_pred))
)
rca_score = 1.0 - rca.get_score(val_pred1, val_pred2, validation.y)
meta_score = abs(rca_score - metrics.accuracy_score(test.y, test_pred))
report.append_row(test.prevalence(), acc=meta_score, acc_score=rca_score)
except ValueError:
report.append_row(
@ -244,17 +242,15 @@ def rca_star_score(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
report = EvaluationReport(prefix="rca_star")
report = EvaluationReport(name="rca_star")
for test in protocol():
try:
test_pred = c_model_predict(test.X)
c_model2 = rca.clone_fit(c_model, test.X, test_pred)
c_model2_predict = getattr(c_model2, predict_method)
val2_pred2 = c_model2_predict(validation2.X)
rca_star_score = rca.get_score(val2_pred1, val2_pred2, validation2.y)
meta_score = abs(
rca_star_score - (1 - metrics.accuracy_score(test.y, test_pred))
)
rca_star_score = 1.0 - rca.get_score(val2_pred1, val2_pred2, validation2.y)
meta_score = abs(rca_star_score - metrics.accuracy_score(test.y, test_pred))
report.append_row(
test.prevalence(), acc=meta_score, acc_score=rca_star_score
)

View File

@ -1,5 +1,6 @@
import multiprocessing
import time
import traceback
from typing import List
import pandas as pd
@ -19,14 +20,25 @@ pd.set_option("display.float_format", "{:.4f}".format)
class CompEstimator:
__dict = {
"OUR_BIN_SLD": method.evaluate_bin_sld,
"OUR_MUL_SLD": method.evaluate_mul_sld,
"KFCV": baseline.kfcv,
"ATC_MC": baseline.atc_mc,
"ATC_NE": baseline.atc_ne,
"DOC_FEAT": baseline.doc_feat,
"RCA": baseline.rca_score,
"RCA_STAR": baseline.rca_star_score,
"our_bin_SLD": method.evaluate_bin_sld,
"our_mul_SLD": method.evaluate_mul_sld,
"our_bin_SLD_nbvs": method.evaluate_bin_sld_nbvs,
"our_mul_SLD_nbvs": method.evaluate_mul_sld_nbvs,
"our_bin_SLD_bcts": method.evaluate_bin_sld_bcts,
"our_mul_SLD_bcts": method.evaluate_mul_sld_bcts,
"our_bin_SLD_ts": method.evaluate_bin_sld_ts,
"our_mul_SLD_ts": method.evaluate_mul_sld_ts,
"our_bin_SLD_vs": method.evaluate_bin_sld_vs,
"our_mul_SLD_vs": method.evaluate_mul_sld_vs,
"our_bin_CC": method.evaluate_bin_cc,
"our_mul_CC": method.evaluate_mul_cc,
"ref": baseline.reference,
"kfcv": baseline.kfcv,
"atc_mc": baseline.atc_mc,
"atc_ne": baseline.atc_ne,
"doc_feat": baseline.doc_feat,
"rca": baseline.rca_score,
"rca_star": baseline.rca_star_score,
}
def __class_getitem__(cls, e: str | List[str]):
@ -55,7 +67,17 @@ def fit_and_estimate(_estimate, train, validation, test):
test, n_prevalences=env.PROTOCOL_N_PREVS, repeats=env.PROTOCOL_REPEATS
)
start = time.time()
try:
result = _estimate(model, validation, protocol)
except Exception as e:
print(f"Method {_estimate.__name__} failed.")
traceback(e)
return {
"name": _estimate.__name__,
"result": None,
"time": 0,
}
end = time.time()
print(f"{_estimate.__name__}: {end-start:.2f}s")
@ -69,22 +91,33 @@ def fit_and_estimate(_estimate, train, validation, test):
def evaluate_comparison(
dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
) -> EvaluationReport:
with multiprocessing.Pool(8) as pool:
with multiprocessing.Pool(len(estimators)) as pool:
dr = DatasetReport(dataset.name)
for d in dataset():
print(f"train prev.: {d.train_prev}")
start = time.time()
tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]]
results = [pool.apply_async(fit_and_estimate, t) for t in tasks]
results = list(map(lambda r: r.get(), results))
results_got = []
for _r in results:
try:
r = _r.get()
if r["result"] is not None:
results_got.append(r)
except Exception as e:
print(e)
er = EvaluationReport.combine_reports(
*list(map(lambda r: r["result"], results)), name=dataset.name
*[r["result"] for r in results_got],
name=dataset.name,
train_prev=d.train_prev,
valid_prev=d.validation_prev,
)
times = {r["name"]: r["time"] for r in results}
times = {r["name"]: r["time"] for r in results_got}
end = time.time()
times["tot"] = end - start
er.times = times
er.train_prevs = d.prevs
dr.add(er)
print()

View File

@ -1,3 +1,5 @@
import numpy as np
import sklearn.metrics as metrics
from quapy.data import LabelledCollection
from quapy.protocol import (
AbstractStochasticSeededProtocol,
@ -22,15 +24,17 @@ def estimate(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
base_prevs, true_prevs, estim_prevs = [], [], []
base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
for sample in protocol():
e_sample = estimator.extend(sample)
e_sample, pred_proba = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.X, ext=True)
base_prevs.append(sample.prevalence())
true_prevs.append(e_sample.prevalence())
estim_prevs.append(estim_prev)
pred_probas.append(pred_proba)
labels.append(sample.y)
return base_prevs, true_prevs, estim_prevs
return base_prevs, true_prevs, estim_prevs, pred_probas, labels
def evaluation_report(
@ -38,16 +42,21 @@ def evaluation_report(
protocol: AbstractStochasticSeededProtocol,
method: str,
) -> EvaluationReport:
base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
report = EvaluationReport(prefix=method)
base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
estimator, protocol
)
report = EvaluationReport(name=method)
for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
for base_prev, true_prev, estim_prev, pred_proba, label in zip(
base_prevs, true_prevs, estim_prevs, pred_probas, labels
):
pred = np.argmax(pred_proba, axis=-1)
acc_score = error.acc(estim_prev)
f1_score = error.f1(estim_prev)
report.append_row(
base_prev,
acc_score=1.0 - acc_score,
acc=abs(error.acc(true_prev) - acc_score),
acc_score=acc_score,
acc=abs(metrics.accuracy_score(label, pred) - acc_score),
f1_score=f1_score,
f1=abs(error.f1(true_prev) - f1_score),
)
@ -60,13 +69,18 @@ def evaluate(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
method: str,
q_model: str,
**kwargs,
):
estimator: AccuracyEstimator = {
"bin": BinaryQuantifierAccuracyEstimator,
"mul": MulticlassAccuracyEstimator,
}[method](c_model)
}[method](c_model, q_model=q_model, **kwargs)
estimator.fit(validation)
return evaluation_report(estimator, protocol, method)
_method = f"{method}_{q_model}"
for k, v in kwargs.items():
_method += f"_{v}"
return evaluation_report(estimator, protocol, _method)
def evaluate_bin_sld(
@ -74,7 +88,7 @@ def evaluate_bin_sld(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin")
return evaluate(c_model, validation, protocol, "bin", "SLD")
def evaluate_mul_sld(
@ -82,4 +96,84 @@ def evaluate_mul_sld(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul")
return evaluate(c_model, validation, protocol, "mul", "SLD")
def evaluate_bin_sld_nbvs(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="nbvs")
def evaluate_mul_sld_nbvs(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="nbvs")
def evaluate_bin_sld_bcts(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="bcts")
def evaluate_mul_sld_bcts(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="bcts")
def evaluate_bin_sld_ts(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="ts")
def evaluate_mul_sld_ts(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="ts")
def evaluate_bin_sld_vs(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="vs")
def evaluate_mul_sld_vs(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="vs")
def evaluate_bin_cc(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "bin", "CC")
def evaluate_mul_cc(
c_model: BaseEstimator,
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
return evaluate(c_model, validation, protocol, "mul", "CC")

View File

@ -1,22 +1,24 @@
import statistics as stats
from pathlib import Path
from typing import List, Tuple
import numpy as np
import pandas as pd
from quacc import plot
from quacc.environ import env
from quacc.utils import fmt_line_md
class EvaluationReport:
def __init__(self, prefix=None):
def __init__(self, name=None):
self._prevs = []
self._dict = {}
self._g_prevs = None
self._g_dict = None
self.name = prefix if prefix is not None else "default"
self.name = name if name is not None else "default"
self.times = {}
self.train_prevs = {}
self.train_prev = None
self.valid_prev = None
self.target = "default"
def append_row(self, base: np.ndarray | Tuple, **row):
@ -34,23 +36,40 @@ class EvaluationReport:
def columns(self):
return self._dict.keys()
def groupby_prevs(self, metric: str = None):
def group_by_prevs(self, metric: str = None):
if self._g_dict is None:
self._g_prevs = []
self._g_dict = {k: [] for k in self._dict.keys()}
last_end = 0
for ind, bp in enumerate(self._prevs):
if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
continue
for col, vals in self._dict.items():
col_grouped = {}
for bp, v in zip(self._prevs, vals):
if bp not in col_grouped:
col_grouped[bp] = []
col_grouped[bp].append(v)
self._g_prevs.append(bp)
for col in self._dict.keys():
self._g_dict[col].append(
stats.mean(self._dict[col][last_end : ind + 1])
self._g_dict[col] = [
vs
for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])
]
self._g_prevs = sorted(
[(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()],
key=lambda bp: bp[1],
)
last_end = ind + 1
# last_end = 0
# for ind, bp in enumerate(self._prevs):
# if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
# continue
# self._g_prevs.append(bp)
# for col in self._dict.keys():
# self._g_dict[col].append(
# stats.mean(self._dict[col][last_end : ind + 1])
# )
# last_end = ind + 1
filtered_g_dict = self._g_dict
if metric is not None:
@ -60,30 +79,83 @@ class EvaluationReport:
return self._g_prevs, filtered_g_dict
def avg_by_prevs(self, metric: str = None):
g_prevs, g_dict = self.group_by_prevs(metric=metric)
a_dict = {}
for col, vals in g_dict.items():
a_dict[col] = [np.mean(vs) for vs in vals]
return g_prevs, a_dict
def avg_all(self, metric: str = None):
f_dict = self._dict
if metric is not None:
f_dict = {c1: ls for ((c0, c1), ls) in self._dict.items() if c0 == metric}
a_dict = {}
for col, vals in f_dict.items():
a_dict[col] = [np.mean(vals)]
return a_dict
def get_dataframe(self, metric="acc"):
g_prevs, g_dict = self.groupby_prevs(metric=metric)
g_prevs, g_dict = self.avg_by_prevs(metric=metric)
a_dict = self.avg_all(metric=metric)
for col in g_dict.keys():
g_dict[col].extend(a_dict[col])
return pd.DataFrame(
g_dict,
index=g_prevs,
index=g_prevs + ["tot"],
columns=g_dict.keys(),
)
def get_plot(self, mode="delta", metric="acc"):
g_prevs, g_dict = self.groupby_prevs(metric=metric)
t_prev = int(round(self.train_prevs["train"][0] * 100))
title = f"{self.name}_{t_prev}_{metric}"
plot.plot_delta(g_prevs, g_dict, metric, title)
def get_plot(self, mode="delta", metric="acc") -> Path:
if mode == "delta":
g_prevs, g_dict = self.group_by_prevs(metric=metric)
return plot.plot_delta(
g_prevs,
g_dict,
metric=metric,
name=self.name,
train_prev=self.train_prev,
)
elif mode == "diagonal":
_, g_dict = self.avg_by_prevs(metric=metric + "_score")
f_dict = {k: v for k, v in g_dict.items() if k != "ref"}
referece = g_dict["ref"]
return plot.plot_diagonal(
referece,
f_dict,
metric=metric,
name=self.name,
train_prev=self.train_prev,
)
elif mode == "shift":
g_prevs, g_dict = self.avg_by_prevs(metric=metric)
return plot.plot_shift(
g_prevs,
g_dict,
metric=metric,
name=self.name,
train_prev=self.train_prev,
)
def to_md(self, *metrics):
res = ""
for k, v in self.train_prevs.items():
res += fmt_line_md(f"{k}: {str(v)}")
res += fmt_line_md(f"train: {str(self.train_prev)}")
res += fmt_line_md(f"validation: {str(self.valid_prev)}")
for k, v in self.times.items():
res += fmt_line_md(f"{k}: {v:.3f}s")
res += "\n"
for m in metrics:
res += self.get_dataframe(metric=m).to_html() + "\n\n"
self.get_plot(metric=m)
op_delta = self.get_plot(mode="delta", metric=m)
res += f"![plot_delta]({str(op_delta.relative_to(env.OUT_DIR))})\n"
op_diag = self.get_plot(mode="diagonal", metric=m)
res += f"![plot_diagonal]({str(op_diag.relative_to(env.OUT_DIR))})\n"
op_shift = self.get_plot(mode="shift", metric=m)
res += f"![plot_shift]({str(op_shift.relative_to(env.OUT_DIR))})\n"
return res
@ -91,8 +163,9 @@ class EvaluationReport:
if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
raise ValueError("other has not same base prevalences of self")
if len(set(self._dict.keys()).intersection(set(other._dict.keys()))) > 0:
raise ValueError("self and other have matching keys")
inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys()))
if len(inters_keys) > 0:
raise ValueError(f"self and other have matching keys {str(inters_keys)}.")
report = EvaluationReport()
report._prevs = self._prevs
@ -100,12 +173,14 @@ class EvaluationReport:
return report
@staticmethod
def combine_reports(*args, name="default"):
def combine_reports(*args, name="default", train_prev=None, valid_prev=None):
er = args[0]
for r in args[1:]:
er = er.merge(r)
er.name = name
er.train_prev = train_prev
er.valid_prev = valid_prev
return er

View File

@ -1,16 +1,39 @@
import os
import shutil
from pathlib import Path
import quacc.evaluation.comp as comp
from quacc.dataset import Dataset
from quacc.environ import env
def create_out_dir(dir_name):
dir_path = Path(env.OUT_DIR_NAME) / dir_name
env.OUT_DIR = dir_path
shutil.rmtree(dir_path, ignore_errors=True)
os.mkdir(dir_path)
plot_dir_path = dir_path / "plot"
env.PLOT_OUT_DIR = plot_dir_path
os.mkdir(plot_dir_path)
def estimate_comparison():
for conf in env:
create_out_dir(conf)
dataset = Dataset(
env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS
env.DATASET_NAME,
target=env.DATASET_TARGET,
n_prevalences=env.DATASET_N_PREVS,
)
output_path = env.OUT_DIR / f"{dataset.name}.md"
with open(output_path, "w") as f:
try:
dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
f.write(dr.to_md("acc"))
for m in env.METRICS:
output_path = env.OUT_DIR / f"{conf}_{m}.md"
with open(output_path, "w") as f:
f.write(dr.to_md(m))
except Exception as e:
print(f"Configuration {conf} failed. {e}")
# print(df.to_latex(float_format="{:.4f}".format))
# print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))

View File

@ -1,16 +1,191 @@
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from quacc.environ import env
def plot_delta(base_prevs, dict_vals, metric, title):
fig, ax = plt.subplots()
def _get_markers(n: int):
return [
"o",
"v",
"x",
"+",
"s",
"D",
"p",
"h",
"*",
"^",
][:n]
base_prevs = [f for f, p in base_prevs]
def plot_delta(
base_prevs,
dict_vals,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
) -> Path:
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"delta_{name}_{t_prev_pos}_{metric}"
else:
title = f"delta_{name}_{metric}"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(dict_vals)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
ax.set_prop_cycle(
color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
)
base_prevs = [bp[pos_class] for bp in base_prevs]
for method, deltas in dict_vals.items():
avg = np.array([np.mean(d, axis=-1) for d in deltas])
# std = np.array([np.std(d, axis=-1) for d in deltas])
ax.plot(
base_prevs,
avg,
label=method,
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
# ax.fill_between(base_prevs, avg - std, avg + std, alpha=0.25)
ax.set(xlabel="test prevalence", ylabel=metric, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
output_path = env.PLOT_OUT_DIR / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_diagonal(
reference,
dict_vals,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
):
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
else:
title = f"diagonal_{name}_{metric}"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(dict_vals)
cm = plt.get_cmap("tab10")
ax.set_prop_cycle(
marker=_get_markers(NUM_COLORS) * 2,
color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)] * 2,
)
reference = np.array(reference)
x_ticks = np.unique(reference)
x_ticks.sort()
for _, deltas in dict_vals.items():
deltas = np.array(deltas)
ax.plot(
reference,
deltas,
linestyle="None",
markersize=3,
zorder=2,
)
for method, deltas in dict_vals.items():
deltas = np.array(deltas)
x_interp = x_ticks[[0, -1]]
y_interp = np.interp(x_interp, reference, deltas)
ax.plot(
x_interp,
y_interp,
label=method,
linestyle="-",
markersize="0",
zorder=1,
)
ax.set(xlabel="test prevalence", ylabel=metric, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
output_path = env.PLOT_OUT_DIR / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_shift(
base_prevs,
dict_vals,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
) -> Path:
if train_prev is None:
raise AttributeError("train_prev cannot be None.")
train_prev = train_prev[pos_class]
t_prev_pos = int(round(train_prev * 100))
title = f"shift_{name}_{t_prev_pos}_{metric}"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(dict_vals)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
ax.set_prop_cycle(
color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
)
base_prevs = np.around(
[abs(bp[pos_class] - train_prev) for bp in base_prevs], decimals=2
)
for method, deltas in dict_vals.items():
delta_bins = {}
for bp, delta in zip(base_prevs, deltas):
if bp not in delta_bins:
delta_bins[bp] = []
delta_bins[bp].append(delta)
bp_unique, delta_avg = zip(
*sorted(
{k: np.mean(v) for k, v in delta_bins.items()}.items(),
key=lambda db: db[0],
)
)
ax.plot(
bp_unique,
delta_avg,
label=method,
linestyle="-",
marker="o",
@ -19,8 +194,10 @@ def plot_delta(base_prevs, dict_vals, metric, title):
)
ax.set(xlabel="test prevalence", ylabel=metric, title=title)
# ax.set_ylim(0, 1)
# ax.set_xlim(0, 1)
ax.legend()
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
output_path = env.PLOT_OUT_DIR / f"{title}.png"
plt.savefig(output_path)
fig.savefig(output_path, bbox_inches="tight")
return output_path