Merge pull request #4 from lorenzovolpi/grid_search

Grid search implemented for AccuracyEstimator

This commit is contained in: commit c52dc6498f
@@ -1,17 +1,20 @@
 *.code-workspace
 quavenv/*
 *.pdf
 
+__pycache__/*
+baselines/__pycache__/*
+baselines/densratio/__pycache__/*
 quacc/__pycache__/*
 quacc/evaluation/__pycache__/*
+quacc/method/__pycache__/*
 tests/__pycache__/*
-garg22_ATC/__pycache__/*
-guillory21_doc/__pycache__/*
-jiang18_trustscore/__pycache__/*
-lipton_bbse/__pycache__/*
-elsahar19_rca/__pycache__/*
 *.coverage
 .coverage
 
 scp_sync.py
 
 out/*
 output/*
-*.log
+!output/main/
@@ -5,7 +5,6 @@
     "version": "0.2.0",
     "configurations": [
-
 
         {
             "name": "main",
            "type": "python",
@@ -15,12 +14,12 @@
            "justMyCode": true
        },
        {
-           "name": "models",
+           "name": "main_test",
            "type": "python",
            "request": "launch",
-           "program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\baselines\\models.py",
+           "program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main_test.py",
            "console": "integratedTerminal",
-           "justMyCode": true
-       }
+           "justMyCode": false
+       },
     ]
 }
@@ -1,14 +1,5 @@
 {
     "todo": [
-        {
-            "assignedTo": {
-                "name": "Lorenzo Volpi"
-            },
-            "creation_time": "2023-10-28T14:34:46.226Z",
-            "id": "4",
-            "references": [],
-            "title": "Aggingere estimator basati su PACC (quantificatore)"
-        },
         {
             "assignedTo": {
                 "name": "Lorenzo Volpi"
@@ -18,15 +9,6 @@
             "references": [],
             "title": "Creare plot avg con training prevalence sull'asse x e media rispetto a test prevalence"
         },
-        {
-            "assignedTo": {
-                "name": "Lorenzo Volpi"
-            },
-            "creation_time": "2023-10-28T14:34:23.217Z",
-            "id": "3",
-            "references": [],
-            "title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
-        },
         {
             "assignedTo": {
                 "name": "Lorenzo Volpi"
@@ -38,6 +20,27 @@
         }
     ],
     "in-progress": [
+        {
+            "assignedTo": {
+                "name": "Lorenzo Volpi"
+            },
+            "creation_time": "2023-10-28T14:34:23.217Z",
+            "id": "3",
+            "references": [],
+            "title": "Relaizzare grid search per task specifico partedno da GridSearchQ"
+        },
+        {
+            "assignedTo": {
+                "name": "Lorenzo Volpi"
+            },
+            "creation_time": "2023-10-28T14:34:46.226Z",
+            "id": "4",
+            "references": [],
+            "title": "Aggingere estimator basati su PACC (quantificatore)"
+        }
+    ],
+    "testing": [],
+    "done": [
         {
             "assignedTo": {
                 "name": "Lorenzo Volpi"
@@ -47,7 +50,5 @@
             "references": [],
             "title": "Rework rappresentazione dati di report"
         }
-    ],
-    "testing": [],
-    "done": []
+    ]
 }
TODO.html (36 lines changed)
@@ -103,17 +103,35 @@ verbose=True).fit(V_tr)</li>
 <p><input class="task-list-item-checkbox" checked=""type="checkbox"> import baselines</p>
 </li>
 <li class="task-list-item enabled">
+<p><input class="task-list-item-checkbox"type="checkbox"> importare mandoline</p>
+<ul>
+<li>mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc</li>
+</ul>
+</li>
+<li class="task-list-item enabled">
+<p><input class="task-list-item-checkbox"type="checkbox"> sistemare vecchie iw baselines</p>
+<ul>
+<li>non possono essere fixate perché dipendono da numpy</li>
+</ul>
+</li>
+<li class="task-list-item enabled">
+<p><input class="task-list-item-checkbox" checked=""type="checkbox"> plot avg con train prevalence sull'asse x e media su test prevalecne</p>
+</li>
+<li class="task-list-item enabled">
+<p><input class="task-list-item-checkbox" checked=""type="checkbox"> realizzare grid search per task specifico partendo da GridSearchQ</p>
+</li>
+<li class="task-list-item enabled">
+<p><input class="task-list-item-checkbox" checked=""type="checkbox"> provare PACC come quantificatore</p>
+</li>
+<li class="task-list-item enabled">
+<p><input class="task-list-item-checkbox" checked=""type="checkbox"> aggiungere etichette in shift plot</p>
+</li>
+<li class="task-list-item enabled">
+<p><input class="task-list-item-checkbox" checked=""type="checkbox"> sistemare exact_train quapy</p>
+</li>
+<li class="task-list-item enabled">
 <p><input class="task-list-item-checkbox"type="checkbox"> testare anche su imbd</p>
 </li>
-<li class="task-list-item enabled">
-<p><input class="task-list-item-checkbox"type="checkbox"> plot avg con train prevalence sull'asse x e media su test prevalecne</p>
-</li>
-<li class="task-list-item enabled">
-<p><input class="task-list-item-checkbox"type="checkbox"> realizzare grid search per task specifico partendo da GridSearchQ</p>
-</li>
-<li class="task-list-item enabled">
-<p><input class="task-list-item-checkbox"type="checkbox"> provare PACC come quantificatore</p>
-</li>
 </ul>
TODO.md (12 lines changed)
@@ -30,7 +30,13 @@
 - nel caso di bin fare media dei due best score
 - [x] import baselines
 
+- [ ] importare mandoline
+  - mandoline può essere importato, ma richiedere uno slicing delle features a priori che devere essere realizzato ad hoc
+- [ ] sistemare vecchie iw baselines
+  - non possono essere fixate perché dipendono da numpy
+- [x] plot avg con train prevalence sull'asse x e media su test prevalecne
+- [x] realizzare grid search per task specifico partendo da GridSearchQ
+- [x] provare PACC come quantificatore
+- [x] aggiungere etichette in shift plot
+- [x] sistemare exact_train quapy
 - [ ] testare anche su imbd
-- [ ] plot avg con train prevalence sull'asse x e media su test prevalecne
-- [ ] realizzare grid search per task specifico partendo da GridSearchQ
-- [ ] provare PACC come quantificatore
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,19 +1,22 @@
 import numpy as np
 from sklearn.metrics import f1_score
 
 
 def get_entropy(probs):
-    return np.sum( np.multiply(probs, np.log(probs + 1e-20)) , axis=1)
+    return np.sum(np.multiply(probs, np.log(probs + 1e-20)), axis=1)
 
 
 def get_max_conf(probs):
     return np.max(probs, axis=-1)
 
 
 def find_ATC_threshold(scores, labels):
     sorted_idx = np.argsort(scores)
 
     sorted_scores = scores[sorted_idx]
     sorted_labels = labels[sorted_idx]
 
-    fp = np.sum(labels==0)
+    fp = np.sum(labels == 0)
     fn = 0.0
 
     min_fp_fn = np.abs(fp - fn)
@@ -32,10 +35,10 @@ def find_ATC_threshold(scores, labels):
 
 
 def get_ATC_acc(thres, scores):
-    return np.mean(scores>=thres)
+    return np.mean(scores >= thres)
 
 
 def get_ATC_f1(thres, scores, probs):
     preds = np.argmax(probs, axis=-1)
-    estim_y = abs(1 - (scores>=thres)^preds)
+    estim_y = np.abs(1 - (scores >= thres) ^ preds)
     return f1_score(estim_y, preds)
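Note: these helpers implement ATC (Average Thresholded Confidence). A minimal usage sketch with synthetic data, assuming, as in the garg22 reference implementation, that find_ATC_threshold returns a (score, threshold) pair; the variable names below are illustrative, not part of the diff:

```python
import numpy as np

# Hypothetical end-to-end sketch: calibrate the ATC threshold on held-out
# validation data, then estimate accuracy on an unlabeled test sample.
rng = np.random.default_rng(0)
val_probs = rng.dirichlet([1, 1], size=500)    # (n, 2) posterior probabilities
val_labels = rng.integers(0, 2, size=500)
test_probs = rng.dirichlet([1, 1], size=500)

# 0/1 correctness of the classifier on the validation set
val_correct = (np.argmax(val_probs, axis=-1) == val_labels).astype(int)
_, thres = find_ATC_threshold(get_max_conf(val_probs), val_correct)
estim_test_acc = get_ATC_acc(thres, get_max_conf(test_probs))
```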
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,6 +4,20 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import GridSearchCV
 from sklearn.neighbors import KernelDensity
 
+from baselines import densratio
+from baselines.pykliep import DensityRatioEstimator
+
+
+def kliep(Xtr, ytr, Xte):
+    kliep = DensityRatioEstimator()
+    kliep.fit(Xtr, Xte)
+    return kliep.predict(Xtr)
+
+
+def usilf(Xtr, ytr, Xte, alpha=0.0):
+    dense_ratio_obj = densratio(Xtr, Xte, alpha=alpha, verbose=False)
+    return dense_ratio_obj.compute_density_ratio(Xtr)
+
+
 def logreg(Xtr, ytr, Xte):
     # check "Direct Density Ratio Estimation for
@@ -123,7 +123,7 @@ if __name__ == "__main__":
 
     results = []
     for sample in protocol():
-        wx = iw.logreg(d.validation.X, d.validation.y, sample.X)
+        wx = iw.kliep(d.validation.X, d.validation.y, sample.X)
         test_preds = lr.predict(sample.X)
         estim_acc = np.sum((1.0 * (val_preds == d.validation.y)) * wx) / np.sum(wx)
         true_acc = metrics.accuracy_score(sample.y, test_preds)
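Note: the loop in the second hunk computes the standard importance-weighted accuracy estimate (our gloss; the hunk itself only switches the weight source from logreg to kliep). With weights approximating the test/validation density ratio:

```latex
\widehat{\mathrm{acc}} \;=\; \frac{\sum_i w(x_i)\,\mathbb{1}\!\left[h(x_i)=y_i\right]}{\sum_i w(x_i)},
\qquad w(x) \;\approx\; \frac{p_{\mathrm{test}}(x)}{p_{\mathrm{val}}(x)},
```

where h is the classifier, the sum runs over validation points, and w comes from kliep, usilf, or logreg.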
@@ -74,7 +74,9 @@ class DensityRatioEstimator:
         # X_test_shuffled = X_test.copy()
         X_test_shuffled = X_test.copy()
 
-        np.random.shuffle(X_test_shuffled)
+        X_test_index = np.arange(X_test_shuffled.shape[0])
+        np.random.shuffle(X_test_index)
+        X_test_shuffled = X_test_shuffled[X_test_index, :]
 
         j_scores = {}
 
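Note: a plausible motive for this change (not stated in the diff) is that np.random.shuffle mutates dense arrays in place but does not work on scipy.sparse matrices, whereas fancy-indexing with a shuffled index array permutes the rows of either. A small self-contained check:

```python
import numpy as np
import scipy.sparse as sp

X = sp.csr_matrix(np.eye(4))        # also works unchanged for np.ndarray
idx = np.arange(X.shape[0])
np.random.shuffle(idx)              # shuffle indices, not the matrix
X_shuffled = X[idx, :]              # row-permuted copy, dense or sparse
```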
conf.yaml (187 lines changed)
@@ -4,56 +4,46 @@ debug_conf: &debug_conf
       - acc
     DATASET_N_PREVS: 5
     DATASET_PREVS:
+      # - 0.2
       - 0.5
-      - 0.1
+      # - 0.8
 
   confs:
-    - DATASET_NAME: imdb
+    - DATASET_NAME: rcv1
+      DATASET_TARGET: CCAT
 
   plot_confs:
     debug:
       PLOT_ESTIMATORS:
-        - ref
+        - mulmc_sld
         - atc_mc
-        - atc_ne
       PLOT_STDEV: true
-    debug_plus:
-      PLOT_ESTIMATORS:
-        - mul_sld_bcts
-        - mul_sld
-        - ref
-        - atc_mc
-        - atc_ne
 
-test_conf: &test_conf
+mc_conf: &mc_conf
   global:
     METRICS:
       - acc
-      - f1
-    DATASET_N_PREVS: 2
-    DATASET_PREVS:
-      - 0.5
-      - 0.1
+    DATASET_N_PREVS: 9
+    DATASET_DIR_UPDATE: true
 
   confs:
-    # - DATASET_NAME: rcv1
-    #   DATASET_TARGET: CCAT
-    - DATASET_NAME: imdb
+    - DATASET_NAME: rcv1
+      DATASET_TARGET: CCAT
+    # - DATASET_NAME: imdb
 
   plot_confs:
-    best_vs_atc:
+    debug3:
       PLOT_ESTIMATORS:
-        - bin_sld
-        - bin_sld_bcts
+        - binmc_sld
+        - mulmc_sld
+        - binne_sld
+        - mulne_sld
         - bin_sld_gs
-        - mul_sld
-        - mul_sld_bcts
         - mul_sld_gs
-        - ref
         - atc_mc
-        - atc_ne
+      PLOT_STDEV: true
 
-main_conf: &main_conf
+test_conf: &test_conf
   global:
     METRICS:
       - acc
@@ -63,29 +53,160 @@ main_conf: &main_conf
   confs:
     - DATASET_NAME: rcv1
       DATASET_TARGET: CCAT
-  confs_bck:
+    # - DATASET_NAME: imdb
 
+  plot_confs:
+    gs_vs_gsq:
+      PLOT_ESTIMATORS:
+        - bin_sld
+        - bin_sld_gs
+        - bin_sld_gsq
+        - mul_sld
+        - mul_sld_gs
+        - mul_sld_gsq
+    gs_vs_atc:
+      PLOT_ESTIMATORS:
+        - bin_sld
+        - bin_sld_gs
+        - mul_sld
+        - mul_sld_gs
+        - atc_mc
+        - atc_ne
+    sld_vs_pacc:
+      PLOT_ESTIMATORS:
+        - bin_sld
+        - bin_sld_gs
+        - mul_sld
+        - mul_sld_gs
+        - bin_pacc
+        - bin_pacc_gs
+        - mul_pacc
+        - mul_pacc_gs
+        - atc_mc
+        - atc_ne
+    pacc_vs_atc:
+      PLOT_ESTIMATORS:
+        - bin_pacc
+        - bin_pacc_gs
+        - mul_pacc
+        - mul_pacc_gs
+        - atc_mc
+        - atc_ne
+
+main_conf: &main_conf
+
+  global:
+    METRICS:
+      - acc
+      - f1
+    DATASET_N_PREVS: 9
+    DATASET_DIR_UPDATE: true
+
+  confs:
+    - DATASET_NAME: rcv1
+      DATASET_TARGET: CCAT
     - DATASET_NAME: imdb
+  confs_next:
     - DATASET_NAME: rcv1
       DATASET_TARGET: GCAT
     - DATASET_NAME: rcv1
       DATASET_TARGET: MCAT
 
   plot_confs:
+    gs_vs_qgs:
+      PLOT_ESTIMATORS:
+        - mul_sld_gs
+        - bin_sld_gs
+        - mul_sld_gsq
+        - bin_sld_gsq
+        - atc_mc
+        - atc_ne
+      PLOT_STDEV: true
+  plot_confs_completed:
+    max_conf_vs_atc_pacc:
+      PLOT_ESTIMATORS:
+        - bin_pacc
+        - binmc_pacc
+        - mul_pacc
+        - mulmc_pacc
+        - atc_mc
+      PLOT_STDEV: true
+    max_conf_vs_entropy_pacc:
+      PLOT_ESTIMATORS:
+        - binmc_pacc
+        - binne_pacc
+        - mulmc_pacc
+        - mulne_pacc
+        - atc_mc
+      PLOT_STDEV: true
     gs_vs_atc:
       PLOT_ESTIMATORS:
         - mul_sld_gs
         - bin_sld_gs
-        - ref
+        - mul_pacc_gs
+        - bin_pacc_gs
         - atc_mc
         - atc_ne
       PLOT_STDEV: true
+    gs_vs_all:
+      PLOT_ESTIMATORS:
+        - mul_sld_gs
+        - bin_sld_gs
+        - mul_pacc_gs
+        - bin_pacc_gs
+        - atc_mc
+        - doc_feat
+        - kfcv
+      PLOT_STDEV: true
+    gs_vs_qgs:
+      PLOT_ESTIMATORS:
+        - mul_sld_gs
+        - bin_sld_gs
+        - mul_sld_gsq
+        - bin_sld_gsq
+        - atc_mc
+        - atc_ne
+      PLOT_STDEV: true
+    cc_vs_other:
+      PLOT_ESTIMATORS:
+        - mul_cc
+        - bin_cc
+        - mul_sld
+        - bin_sld
+        - mul_pacc
+        - bin_pacc
+      PLOT_STDEV: true
+    max_conf_vs_atc:
+      PLOT_ESTIMATORS:
+        - bin_sld
+        - binmc_sld
+        - mul_sld
+        - mulmc_sld
+        - atc_mc
+      PLOT_STDEV: true
+    max_conf_vs_entropy:
+      PLOT_ESTIMATORS:
+        - binmc_sld
+        - binne_sld
+        - mulmc_sld
+        - mulne_sld
+        - atc_mc
+      PLOT_STDEV: true
+    sld_vs_pacc:
+      PLOT_ESTIMATORS:
+        - bin_sld
+        - mul_sld
+        - bin_pacc
+        - mul_pacc
+        - atc_mc
+      PLOT_STDEV: true
+  plot_confs_other:
     best_vs_atc:
       PLOT_ESTIMATORS:
         - mul_sld_bcts
         - mul_sld_gs
         - bin_sld_bcts
         - bin_sld_gs
-        - ref
         - atc_mc
         - atc_ne
     all_vs_atc:
@@ -96,7 +217,6 @@ main_conf: &main_conf
         - mul_sld
         - mul_sld_bcts
         - mul_sld_gs
-        - ref
         - atc_mc
         - atc_ne
     best_vs_all:
@@ -105,10 +225,9 @@ main_conf: &main_conf
         - bin_sld_gs
         - mul_sld_bcts
         - mul_sld_gs
-        - ref
         - kfcv
         - atc_mc
         - atc_ne
         - doc_feat
 
-exec: *debug_conf
+exec: *main_conf
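Note: the exec switch works through standard YAML anchors and aliases; `&name` defines a reusable node and `*name` references it, so `exec: *main_conf` points the runner at the main_conf block without copying it. A minimal self-contained illustration, assuming PyYAML is available:

```python
import yaml  # PyYAML, an assumed dependency for this illustration

doc = yaml.safe_load(
    """
    main_conf: &main_conf
      METRICS: [acc, f1]
    exec: *main_conf
    """
)
# The alias resolves to the very same mapping object as the anchor.
assert doc["exec"] is doc["main_conf"]
```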
@@ -15,6 +15,104 @@ numpy = ">=1.9"
 scikit-learn = ">=0.20.0"
 scipy = ">=1.1.0"
 
+[[package]]
+name = "bcrypt"
+version = "4.0.1"
+description = "Modern password hashing for your software and your servers"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "bcrypt-4.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:b1023030aec778185a6c16cf70f359cbb6e0c289fd564a7cfa29e727a1c38f8f"},
+    {file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:08d2947c490093a11416df18043c27abe3921558d2c03e2076ccb28a116cb6d0"},
+    {file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0eaa47d4661c326bfc9d08d16debbc4edf78778e6aaba29c1bc7ce67214d4410"},
+    {file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae88eca3024bb34bb3430f964beab71226e761f51b912de5133470b649d82344"},
+    {file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:a522427293d77e1c29e303fc282e2d71864579527a04ddcfda6d4f8396c6c36a"},
+    {file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:fbdaec13c5105f0c4e5c52614d04f0bca5f5af007910daa8b6b12095edaa67b3"},
+    {file = "bcrypt-4.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:ca3204d00d3cb2dfed07f2d74a25f12fc12f73e606fcaa6975d1f7ae69cacbb2"},
+    {file = "bcrypt-4.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:089098effa1bc35dc055366740a067a2fc76987e8ec75349eb9484061c54f535"},
+    {file = "bcrypt-4.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:e9a51bbfe7e9802b5f3508687758b564069ba937748ad7b9e890086290d2f79e"},
+    {file = "bcrypt-4.0.1-cp36-abi3-win32.whl", hash = "sha256:2caffdae059e06ac23fce178d31b4a702f2a3264c20bfb5ff541b338194d8fab"},
+    {file = "bcrypt-4.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:8a68f4341daf7522fe8d73874de8906f3a339048ba406be6ddc1b3ccb16fc0d9"},
+    {file = "bcrypt-4.0.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf4fa8b2ca74381bb5442c089350f09a3f17797829d958fad058d6e44d9eb83c"},
+    {file = "bcrypt-4.0.1-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:67a97e1c405b24f19d08890e7ae0c4f7ce1e56a712a016746c8b2d7732d65d4b"},
+    {file = "bcrypt-4.0.1-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b3b85202d95dd568efcb35b53936c5e3b3600c7cdcc6115ba461df3a8e89f38d"},
+    {file = "bcrypt-4.0.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbb03eec97496166b704ed663a53680ab57c5084b2fc98ef23291987b525cb7d"},
+    {file = "bcrypt-4.0.1-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:5ad4d32a28b80c5fa6671ccfb43676e8c1cc232887759d1cd7b6f56ea4355215"},
+    {file = "bcrypt-4.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b57adba8a1444faf784394de3436233728a1ecaeb6e07e8c22c8848f179b893c"},
+    {file = "bcrypt-4.0.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:705b2cea8a9ed3d55b4491887ceadb0106acf7c6387699fca771af56b1cdeeda"},
+    {file = "bcrypt-4.0.1-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:2b3ac11cf45161628f1f3733263e63194f22664bf4d0c0f3ab34099c02134665"},
+    {file = "bcrypt-4.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3100851841186c25f127731b9fa11909ab7b1df6fc4b9f8353f4f1fd952fbf71"},
+    {file = "bcrypt-4.0.1.tar.gz", hash = "sha256:27d375903ac8261cfe4047f6709d16f7d18d39b1ec92aaf72af989552a650ebd"},
+]
+
+[package.extras]
+tests = ["pytest (>=3.2.1,!=3.3.0)"]
+typecheck = ["mypy"]
+
+[[package]]
+name = "cffi"
+version = "1.16.0"
+description = "Foreign Function Interface for Python calling C code."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"},
+    {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"},
+    {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"},
+    {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"},
+    {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"},
+    {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"},
+    {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"},
+    {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"},
+    {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"},
+    {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"},
+    {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"},
+    {file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"},
+    {file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"},
+    {file = "cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"},
+    {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"},
+    {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"},
+    {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"},
+    {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"},
+    {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"},
+    {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"},
+    {file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"},
+    {file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"},
+    {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"},
+    {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"},
+    {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"},
+    {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"},
+    {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"},
+    {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"},
+    {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"},
+    {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"},
+    {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"},
+    {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"},
+    {file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"},
+    {file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"},
+    {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"},
+    {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"},
+    {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"},
+    {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"},
+    {file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"},
+    {file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"},
+    {file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"},
+    {file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"},
+    {file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"},
+    {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"},
+    {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"},
+    {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"},
+    {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"},
+    {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"},
+    {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"},
+    {file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"},
+    {file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"},
+    {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"},
+]
+
+[package.dependencies]
+pycparser = "*"
+
 [[package]]
 name = "colorama"
 version = "0.4.6"
@@ -223,6 +321,51 @@ files = [
 [package.extras]
 toml = ["tomli"]
+
+[[package]]
+name = "cryptography"
+version = "41.0.5"
+description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "cryptography-41.0.5-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:da6a0ff8f1016ccc7477e6339e1d50ce5f59b88905585f77193ebd5068f1e797"},
+    {file = "cryptography-41.0.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b948e09fe5fb18517d99994184854ebd50b57248736fd4c720ad540560174ec5"},
+    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e6031e113b7421db1de0c1b1f7739564a88f1684c6b89234fbf6c11b75147"},
+    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e270c04f4d9b5671ebcc792b3ba5d4488bf7c42c3c241a3748e2599776f29696"},
+    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ec3b055ff8f1dce8e6ef28f626e0972981475173d7973d63f271b29c8a2897da"},
+    {file = "cryptography-41.0.5-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:7d208c21e47940369accfc9e85f0de7693d9a5d843c2509b3846b2db170dfd20"},
+    {file = "cryptography-41.0.5-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:8254962e6ba1f4d2090c44daf50a547cd5f0bf446dc658a8e5f8156cae0d8548"},
+    {file = "cryptography-41.0.5-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:a48e74dad1fb349f3dc1d449ed88e0017d792997a7ad2ec9587ed17405667e6d"},
+    {file = "cryptography-41.0.5-cp37-abi3-win32.whl", hash = "sha256:d3977f0e276f6f5bf245c403156673db103283266601405376f075c849a0b936"},
+    {file = "cryptography-41.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:73801ac9736741f220e20435f84ecec75ed70eda90f781a148f1bad546963d81"},
+    {file = "cryptography-41.0.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3be3ca726e1572517d2bef99a818378bbcf7d7799d5372a46c79c29eb8d166c1"},
+    {file = "cryptography-41.0.5-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:e886098619d3815e0ad5790c973afeee2c0e6e04b4da90b88e6bd06e2a0b1b72"},
+    {file = "cryptography-41.0.5-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:573eb7128cbca75f9157dcde974781209463ce56b5804983e11a1c462f0f4e88"},
+    {file = "cryptography-41.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0c327cac00f082013c7c9fb6c46b7cc9fa3c288ca702c74773968173bda421bf"},
+    {file = "cryptography-41.0.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:227ec057cd32a41c6651701abc0328135e472ed450f47c2766f23267b792a88e"},
+    {file = "cryptography-41.0.5-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:22892cc830d8b2c89ea60148227631bb96a7da0c1b722f2aac8824b1b7c0b6b8"},
+    {file = "cryptography-41.0.5-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:5a70187954ba7292c7876734183e810b728b4f3965fbe571421cb2434d279179"},
+    {file = "cryptography-41.0.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:88417bff20162f635f24f849ab182b092697922088b477a7abd6664ddd82291d"},
+    {file = "cryptography-41.0.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c707f7afd813478e2019ae32a7c49cd932dd60ab2d2a93e796f68236b7e1fbf1"},
+    {file = "cryptography-41.0.5-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:580afc7b7216deeb87a098ef0674d6ee34ab55993140838b14c9b83312b37b86"},
+    {file = "cryptography-41.0.5-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fba1e91467c65fe64a82c689dc6cf58151158993b13eb7a7f3f4b7f395636723"},
+    {file = "cryptography-41.0.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0d2a6a598847c46e3e321a7aef8af1436f11c27f1254933746304ff014664d84"},
+    {file = "cryptography-41.0.5.tar.gz", hash = "sha256:392cb88b597247177172e02da6b7a63deeff1937fa6fec3bbf902ebd75d97ec7"},
+]
+
+[package.dependencies]
+cffi = ">=1.12"
+
+[package.extras]
+docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
+docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
+nox = ["nox"]
+pep8test = ["black", "check-sdist", "mypy", "ruff"]
+sdist = ["build"]
+ssh = ["bcrypt (>=3.1.5)"]
+test = ["pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
+test-randomorder = ["pytest-randomly"]
+
 [[package]]
 name = "cycler"
 version = "0.11.0"
@@ -728,6 +871,27 @@ sql-other = ["SQLAlchemy (>=1.4.36)"]
 test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
 xml = ["lxml (>=4.8.0)"]
+
+[[package]]
+name = "paramiko"
+version = "3.3.1"
+description = "SSH2 protocol library"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "paramiko-3.3.1-py3-none-any.whl", hash = "sha256:b7bc5340a43de4287bbe22fe6de728aa2c22468b2a849615498dd944c2f275eb"},
+    {file = "paramiko-3.3.1.tar.gz", hash = "sha256:6a3777a961ac86dbef375c5f5b8d50014a1a96d0fd7f054a43bc880134b0ff77"},
+]
+
+[package.dependencies]
+bcrypt = ">=3.2"
+cryptography = ">=3.3"
+pynacl = ">=1.5"
+
+[package.extras]
+all = ["gssapi (>=1.4.1)", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"]
+gssapi = ["gssapi (>=1.4.1)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"]
+invoke = ["invoke (>=2.0)"]
+
 [[package]]
 name = "pillow"
 version = "10.0.1"
@@ -851,6 +1015,17 @@ files = [
 [package.dependencies]
 numpy = ">=1.16.6"
+
+[[package]]
+name = "pycparser"
+version = "2.21"
+description = "C parser in Python"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+    {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
+    {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+]
+
 [[package]]
 name = "pylance"
 version = "0.5.10"
@@ -872,6 +1047,32 @@ pyarrow = ">=10"
 [package.extras]
 tests = ["duckdb", "ml_dtypes", "pandas (>=1.4)", "polars[pandas,pyarrow]", "pytest", "tensorflow"]
+
+[[package]]
+name = "pynacl"
+version = "1.5.0"
+description = "Python binding to the Networking and Cryptography (NaCl) library"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"},
+    {file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"},
+    {file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"},
+]
+
+[package.dependencies]
+cffi = ">=1.4.1"
+
+[package.extras]
+docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"]
+tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
+
 [[package]]
 name = "pyparsing"
 version = "3.1.1"
@@ -1172,6 +1373,20 @@ files = [
 {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 ]
+
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+    {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
 [[package]]
 name = "threadpoolctl"
 version = "3.2.0"
@@ -1271,4 +1486,4 @@ test = ["pytest", "pytest-cov"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "54d9922f6d48a46f554a6b350ce09d668a88755efb1fbf295f8f8a0a411bdef2"
+content-hash = "8ce63f546a9858b6678a3eb3925d6f629cbf0c95e4ee8bdeeb3415ce184ffbc9"
@@ -25,6 +25,8 @@ pylance = "^0.5.9"
 pytest-mock = "^3.11.1"
 pytest-cov = "^4.1.0"
 win11toast = "^0.32"
+tabulate = "^0.9.0"
+paramiko = "^3.3.1"
 
 [tool.pytest.ini_options]
 addopts = "--cov=quacc --capture=tee-sys"
@@ -1,7 +1,7 @@
+import math
 from typing import List, Optional
 
 import numpy as np
-import math
 import scipy.sparse as sp
 from quapy.data import LabelledCollection
@@ -128,7 +128,9 @@ class ExtendedCollection(LabelledCollection):
 
     @classmethod
     def extend_collection(
-        cls, base: LabelledCollection, pred_proba: np.ndarray
+        cls,
+        base: LabelledCollection,
+        pred_proba: np.ndarray,
     ):
         n_classes = base.n_classes
 
@@ -136,13 +138,13 @@ class ExtendedCollection(LabelledCollection):
         n_x = cls.extend_instances(base.X, pred_proba)
 
         # n_y = (exptected y, predicted y)
-        pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
+        pred_proba = pred_proba[:, -n_classes:]
+        preds = np.argmax(pred_proba, axis=-1)
         n_y = np.asarray(
             [
                 ExClassManager.get_ex(n_classes, true_class, pred_class)
-                for (true_class, pred_class) in zip(base.y, pred)
+                for (true_class, pred_class) in zip(base.y, preds)
             ]
         )
 
         return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
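Note: our reading of the extended-label encoding (the actual helper is ExClassManager.get_ex in quacc.data; the body below is a hypothetical stand-in): each (true class, predicted class) pair is folded into a single index in [0, n_classes**2).

```python
import numpy as np

# Hypothetical sketch of the encoding, assuming true * n_classes + pred.
def get_ex(n_classes: int, true_class: int, pred_class: int) -> int:
    return true_class * n_classes + pred_class

pred_proba = np.array([[0.2, 0.8], [0.9, 0.1]])
preds = np.argmax(pred_proba, axis=-1)                    # [1, 0]
y = np.array([1, 0])
n_y = np.array([get_ex(2, t, p) for t, p in zip(y, preds)])  # [3, 0]
```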
@@ -71,15 +71,13 @@ class Dataset:
 
         return all_train, test
 
-    def get_raw(self, validation=True) -> DatasetSample:
+    def get_raw(self) -> DatasetSample:
         all_train, test = {
             "spambase": self.__spambase,
             "imdb": self.__imdb,
             "rcv1": self.__rcv1,
         }[self._name]()
 
-        train, val = all_train, None
-        if validation:
-            train, val = all_train.split_stratified(
+        train, val = all_train.split_stratified(
             train_prop=TRAIN_VAL_PROP, random_state=0
         )
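Note: a short self-contained sketch of what the now-unconditional split does, using quapy's API; the 0.66 proportion below is a placeholder for the repo's TRAIN_VAL_PROP constant:

```python
import numpy as np
from quapy.data import LabelledCollection

X = np.random.rand(100, 2)
y = np.repeat([0, 1], 50)
all_train = LabelledCollection(X, y)

# Stratified split: both sides keep (approximately) the same class prevalence.
train, val = all_train.split_stratified(train_prop=0.66, random_state=0)
print(train.prevalence(), val.prevalence())
```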
@@ -1,13 +1,10 @@
-import quapy as qp
+import numpy as np
 
 
 def from_name(err_name):
-    if err_name == "f1e":
-        return f1e
-    elif err_name == "f1":
-        return f1
-    else:
-        return qp.error.from_name(err_name)
+    assert err_name in ERROR_NAMES, f"unknown error {err_name}"
+    callable_error = globals()[err_name]
+    return callable_error
 
 
 # def f1(prev):
@@ -36,5 +33,23 @@ def f1e(prev):
     return 1 - f1(prev)
 
 
-def acc(prev):
-    return (prev[0] + prev[3]) / sum(prev)
+def acc(prev: np.ndarray) -> float:
+    return (prev[0] + prev[3]) / np.sum(prev)
+
+
+def accd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> np.ndarray:
+    vacc = np.vectorize(acc, signature="(m)->()")
+    a_tp = vacc(true_prevs)
+    a_ep = vacc(estim_prevs)
+    return np.abs(a_tp - a_ep)
+
+
+def maccd(true_prevs: np.ndarray, estim_prevs: np.ndarray) -> float:
+    return accd(true_prevs, estim_prevs).mean()
+
+
+ACCURACY_ERROR = {maccd}
+ACCURACY_ERROR_SINGLE = {accd}
+ACCURACY_ERROR_NAMES = {func.__name__ for func in ACCURACY_ERROR}
+ACCURACY_ERROR_SINGLE_NAMES = {func.__name__ for func in ACCURACY_ERROR_SINGLE}
+ERROR_NAMES = ACCURACY_ERROR_NAMES | ACCURACY_ERROR_SINGLE_NAMES
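Note: a small self-contained check of the new vectorized accuracy error, assuming the binary extended-prevalence layout where indices 0 and 3 hold the correctly classified mass, that is, (true=0, pred=0) and (true=1, pred=1):

```python
import numpy as np

# accd computes |acc(true) - acc(estimated)| per sample; maccd is its mean.
true_prevs = np.array([[0.4, 0.1, 0.1, 0.4], [0.3, 0.2, 0.2, 0.3]])
estim_prevs = np.array([[0.5, 0.0, 0.1, 0.4], [0.25, 0.25, 0.25, 0.25]])

vacc = np.vectorize(lambda p: (p[0] + p[3]) / np.sum(p), signature="(m)->()")
per_sample = np.abs(vacc(true_prevs) - vacc(estim_prevs))   # accd
print(per_sample, per_sample.mean())                        # maccd
```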
@@ -1,192 +0,0 @@
-import math
-from abc import abstractmethod
-
-import numpy as np
-import quapy as qp
-from quapy.data import LabelledCollection
-from quapy.method.aggregative import CC, SLD
-from quapy.model_selection import GridSearchQ
-from quapy.protocol import UPP
-from sklearn.base import BaseEstimator
-from sklearn.linear_model import LogisticRegression
-from sklearn.model_selection import cross_val_predict
-
-from quacc.data import ExtendedCollection
-
-
-class AccuracyEstimator:
-    def __init__(self):
-        self.fit_score = None
-
-    def _gs_params(self, t_val: LabelledCollection):
-        return {
-            "param_grid": {
-                "classifier__C": np.logspace(-3, 3, 7),
-                "classifier__class_weight": [None, "balanced"],
-                "recalib": [None, "bcts"],
-            },
-            "protocol": UPP(t_val, repeats=1000),
-            "error": qp.error.mae,
-            "refit": False,
-            "timeout": -1,
-            "n_jobs": None,
-            "verbose": True,
-        }
-
-    def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
-        if not pred_proba:
-            pred_proba = self.c_model.predict_proba(base.X)
-        return ExtendedCollection.extend_collection(base, pred_proba), pred_proba
-
-    @abstractmethod
-    def fit(self, train: LabelledCollection | ExtendedCollection):
-        ...
-
-    @abstractmethod
-    def estimate(self, instances, ext=False):
-        ...
-
-
-AE = AccuracyEstimator
-
-
-class MulticlassAccuracyEstimator(AccuracyEstimator):
-    def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
-        super().__init__()
-        self.c_model = c_model
-        self._q_model_name = q_model.upper()
-        self.e_train = None
-        self.gs = gs
-        self.recalib = recalib
-
-    def fit(self, train: LabelledCollection | ExtendedCollection):
-        # check if model is fit
-        # self.model.fit(*train.Xy)
-        if isinstance(train, LabelledCollection):
-            pred_prob_train = cross_val_predict(
-                self.c_model, *train.Xy, method="predict_proba"
-            )
-            self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
-        else:
-            self.e_train = train
-
-        if self._q_model_name == "SLD":
-            if self.gs:
-                t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
-                gs_params = self._gs_params(t_val)
-                self.q_model = GridSearchQ(
-                    SLD(LogisticRegression()),
-                    **gs_params,
-                )
-                self.q_model.fit(t_train)
-                self.fit_score = self.q_model.best_score_
-            else:
-                self.q_model = SLD(LogisticRegression(), recalib=self.recalib)
-                self.q_model.fit(self.e_train)
-        elif self._q_model_name == "CC":
-            self.q_model = CC(LogisticRegression())
-            self.q_model.fit(self.e_train)
-
-    def estimate(self, instances, ext=False):
-        if not ext:
-            pred_prob = self.c_model.predict_proba(instances)
-            e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
-        else:
-            e_inst = instances
-
-        estim_prev = self.q_model.quantify(e_inst)
-
-        return self._check_prevalence_classes(
-            self.e_train.classes_, self.q_model, estim_prev
-        )
-
-    def _check_prevalence_classes(self, true_classes, q_model, estim_prev):
-        if isinstance(q_model, GridSearchQ):
-            estim_classes = q_model.best_model().classes_
-        else:
-            estim_classes = q_model.classes_
-        for _cls in true_classes:
-            if _cls not in estim_classes:
-                estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
-        return estim_prev
-
-
-class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
-    def __init__(self, c_model: BaseEstimator, q_model="SLD", gs=False, recalib=None):
-        super().__init__()
-        self.c_model = c_model
-        self._q_model_name = q_model.upper()
-        self.q_models = []
-        self.gs = gs
-        self.recalib = recalib
-        self.e_train = None
-
-    def fit(self, train: LabelledCollection | ExtendedCollection):
-        # check if model is fit
-        # self.model.fit(*train.Xy)
-        if isinstance(train, LabelledCollection):
-            pred_prob_train = cross_val_predict(
-                self.c_model, *train.Xy, method="predict_proba"
-            )
-
-            self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
-        elif isinstance(train, ExtendedCollection):
-            self.e_train = train
-
-        self.n_classes = self.e_train.n_classes
-        e_trains = self.e_train.split_by_pred()
-
-        if self._q_model_name == "SLD":
-            fit_scores = []
-            for e_train in e_trains:
-                if self.gs:
-                    t_train, t_val = e_train.split_stratified(0.6, random_state=0)
-                    gs_params = self._gs_params(t_val)
-                    q_model = GridSearchQ(
-                        SLD(LogisticRegression()),
-                        **gs_params,
-                    )
-                    q_model.fit(t_train)
-                    fit_scores.append(q_model.best_score_)
-                    self.q_models.append(q_model)
-                else:
-                    q_model = SLD(LogisticRegression(), recalib=self.recalib)
-                    q_model.fit(e_train)
-                    self.q_models.append(q_model)
-
-            if self.gs:
-                self.fit_score = np.mean(fit_scores)
-
-        elif self._q_model_name == "CC":
-            for e_train in e_trains:
-                q_model = CC(LogisticRegression())
-                q_model.fit(e_train)
-                self.q_models.append(q_model)
-
-    def estimate(self, instances, ext=False):
-        # TODO: test
-        if not ext:
-            pred_prob = self.c_model.predict_proba(instances)
-            e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
-        else:
-            e_inst = instances
-
-        _ncl = int(math.sqrt(self.n_classes))
-        s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
-        estim_prevs = [
-            self._quantify_helper(inst, norm, q_model)
-            for (inst, norm, q_model) in zip(s_inst, norms, self.q_models)
-        ]
-
-        estim_prev = []
-        for prev_row in zip(*estim_prevs):
-            for prev in prev_row:
-                estim_prev.append(prev)
-
-        return np.asarray(estim_prev)
-
-    def _quantify_helper(self, inst, norm, q_model):
-        if inst.shape[0] > 0:
-            return np.asarray(list(map(lambda p: p * norm, q_model.quantify(inst))))
-        else:
-            return np.asarray([0.0, 0.0])
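Note: this file is removed, and per the PR title the grid-search wiring re-lands under quacc/method/. For reference, the tuned pattern, condensed into a sketch that mirrors the removed _gs_params (quapy's GridSearchQ over an SLD quantifier, scored by MAE under a UPP protocol):

```python
import numpy as np
import quapy as qp
from quapy.method.aggregative import SLD
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP
from sklearn.linear_model import LogisticRegression

# t_train and t_val are assumed to be a stratified split of the extended
# training collection, as in the removed fit() methods above.
def fit_gs_quantifier(t_train, t_val):
    gs = GridSearchQ(
        SLD(LogisticRegression()),
        param_grid={
            "classifier__C": np.logspace(-3, 3, 7),
            "classifier__class_weight": [None, "balanced"],
            "recalib": [None, "bcts"],
        },
        protocol=UPP(t_val, repeats=1000),  # artificial-prevalence evaluation
        error=qp.error.mae,
        refit=False,
        verbose=True,
    )
    return gs.fit(t_train)  # exposes best_model() and best_score_
```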
@@ -0,0 +1,34 @@
+from typing import Callable, Union
+
+import numpy as np
+from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
+
+import quacc as qc
+
+from ..method.base import BaseAccuracyEstimator
+
+
+def evaluate(
+    estimator: BaseAccuracyEstimator,
+    protocol: AbstractProtocol,
+    error_metric: Union[Callable | str],
+) -> float:
+    if isinstance(error_metric, str):
+        error_metric = qc.error.from_name(error_metric)
+
+    collator_bck_ = protocol.collator
+    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
+
+    estim_prevs, true_prevs = [], []
+    for sample in protocol():
+        e_sample = estimator.extend(sample)
+        estim_prev = estimator.estimate(e_sample.X, ext=True)
+        estim_prevs.append(estim_prev)
+        true_prevs.append(e_sample.prevalence())
+
+    protocol.collator = collator_bck_
+
+    true_prevs = np.array(true_prevs)
+    estim_prevs = np.array(estim_prevs)
+
+    return error_metric(true_prevs, estim_prevs)
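Note: a hedged sketch of calling the new entry point; `estimator` and `test` below are placeholders (BaseAccuracyEstimator implementations live under quacc.method per the import above), and "maccd" is resolved through quacc.error.from_name:

```python
from quapy.protocol import APP

# Hypothetical usage: `estimator` is a fitted BaseAccuracyEstimator and
# `test` a quapy LabelledCollection held out for evaluation.
protocol = APP(test, sample_size=100, repeats=10)
score = evaluate(estimator, protocol=protocol, error_metric="maccd")
print(f"mean accuracy-estimation error: {score:.4f}")
```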
@@ -65,11 +65,10 @@ def ref(
     validation: LabelledCollection,
     protocol: AbstractStochasticSeededProtocol,
 ):
-    c_model_predict = getattr(c_model, "predict_proba")
+    c_model_predict = getattr(c_model, "predict")
     report = EvaluationReport(name="ref")
     for test in protocol():
-        test_probs = c_model_predict(test.X)
-        test_preds = np.argmax(test_probs, axis=-1)
+        test_preds = c_model_predict(test.X)
         report.append_row(
             test.prevalence(),
             acc_score=metrics.accuracy_score(test.y, test_preds),
@@ -3,6 +3,7 @@ import time
 from traceback import print_exception as traceback
 from typing import List

+import numpy as np
 import pandas as pd
 import quapy as qp

@@ -17,31 +18,63 @@ pd.set_option("display.float_format", "{:.4f}".format)
 qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE


+class CompEstimatorName_:
+    def __init__(self, ce):
+        self.ce = ce
+
+    def __getitem__(self, e: str | List[str]):
+        if isinstance(e, str):
+            return self.ce._CompEstimator__get(e)[0]
+        elif isinstance(e, list):
+            return list(self.ce._CompEstimator__get(e).keys())
+
+
+class CompEstimatorFunc_:
+    def __init__(self, ce):
+        self.ce = ce
+
+    def __getitem__(self, e: str | List[str]):
+        if isinstance(e, str):
+            return self.ce._CompEstimator__get(e)[1]
+        elif isinstance(e, list):
+            return list(self.ce._CompEstimator__get(e).values())
+
+
 class CompEstimator:
     __dict = method._methods | baseline._baselines

-    def __class_getitem__(cls, e: str | List[str]):
+    def __get(cls, e: str | List[str]):
         if isinstance(e, str):
             try:
-                return cls.__dict[e]
+                return (e, cls.__dict[e])
             except KeyError:
                 raise KeyError(f"Invalid estimator: estimator {e} does not exist")
         elif isinstance(e, list):
-            _subtr = [k for k in e if k not in cls.__dict]
+            _subtr = np.setdiff1d(e, list(cls.__dict.keys()))
             if len(_subtr) > 0:
                 raise KeyError(
                     f"Invalid estimator: estimator {_subtr[0]} does not exist"
                 )

-            return [fun for k, fun in cls.__dict.items() if k in e]
+            e_fun = {k: fun for k, fun in cls.__dict.items() if k in e}
+            if "ref" not in e:
+                e_fun["ref"] = cls.__dict["ref"]
+
+            return e_fun
+
+    @property
+    def name(self):
+        return CompEstimatorName_(self)
+
+    @property
+    def func(self):
+        return CompEstimatorFunc_(self)


-CE = CompEstimator
+CE = CompEstimator()


-def evaluate_comparison(
-    dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
-) -> EvaluationReport:
+def evaluate_comparison(dataset: Dataset, estimators=None) -> EvaluationReport:
     log = Logger.logger()
     # with multiprocessing.Pool(1) as pool:
     with multiprocessing.Pool(len(estimators)) as pool:
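The net effect of the `CompEstimator` rework is that estimator lookup now goes through an instance with two indexers instead of class subscription. Roughly, with estimator names that are registered via the `@method` decorator later in this diff:

CE = CompEstimator()
CE.name["bin_sld"]               # -> "bin_sld"
CE.func["bin_sld"]               # -> the bin_sld method function
CE.name[["bin_sld", "mul_sld"]]  # -> names, with "ref" appended when missing
CE.func[["bin_sld", "mul_sld"]]  # -> the matching list of callables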
@@ -52,7 +85,9 @@ def evaluate_comparison(
             f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
         )
         tstart = time.time()
-        tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]]
+        tasks = [
+            (estim, d.train, d.validation, d.test) for estim in CE.func[estimators]
+        ]
         results = [
             pool.apply_async(estimate_worker, t, {"_env": env, "q": Logger.queue()})
             for t in tasks
@@ -1,23 +1,30 @@
+import inspect
 from functools import wraps

 import numpy as np
-import sklearn.metrics as metrics
-from quapy.data import LabelledCollection
-from quapy.protocol import AbstractStochasticSeededProtocol
-from sklearn.base import BaseEstimator
+from quapy.method.aggregative import PACC, SLD, CC
+from quapy.protocol import UPP, AbstractProtocol
+from sklearn.linear_model import LogisticRegression

-import quacc.error as error
+import quacc as qc
 from quacc.evaluation.report import EvaluationReport
+from quacc.method.model_selection import BQAEgsq, GridSearchAE, MCAEgsq

-from ..estimator import (
-    AccuracyEstimator,
-    BinaryQuantifierAccuracyEstimator,
-    MulticlassAccuracyEstimator,
-)
+from ..method.base import BQAE, MCAE, BaseAccuracyEstimator

 _methods = {}
+_sld_param_grid = {
+    "q__classifier__C": np.logspace(-3, 3, 7),
+    "q__classifier__class_weight": [None, "balanced"],
+    "q__recalib": [None, "bcts"],
+    "q__exact_train_prev": [True],
+    "confidence": [None, "max_conf", "entropy"],
+}
+_pacc_param_grid = {
+    "q__classifier__C": np.logspace(-3, 3, 7),
+    "q__classifier__class_weight": [None, "balanced"],
+    "confidence": [None, "max_conf", "entropy"],
+}

 def method(func):
     @wraps(func)
     def wrapper(c_model, validation, protocol):
@@ -28,108 +35,271 @@ def method(func):
     return wrapper


-def estimate(
-    estimator: AccuracyEstimator,
-    protocol: AbstractStochasticSeededProtocol,
-):
-    base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
-    for sample in protocol():
-        e_sample, pred_proba = estimator.extend(sample)
-        estim_prev = estimator.estimate(e_sample.X, ext=True)
-        base_prevs.append(sample.prevalence())
-        true_prevs.append(e_sample.prevalence())
-        estim_prevs.append(estim_prev)
-        pred_probas.append(pred_proba)
-        labels.append(sample.y)
-
-    return base_prevs, true_prevs, estim_prevs, pred_probas, labels
-
-
 def evaluation_report(
-    estimator: AccuracyEstimator,
-    protocol: AbstractStochasticSeededProtocol,
-    method: str,
+    estimator: BaseAccuracyEstimator,
+    protocol: AbstractProtocol,
 ) -> EvaluationReport:
-    base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
-        estimator, protocol
-    )
-    report = EvaluationReport(name=method)
-    for base_prev, true_prev, estim_prev, pred_proba, label in zip(
-        base_prevs, true_prevs, estim_prevs, pred_probas, labels
-    ):
-        pred = np.argmax(pred_proba, axis=-1)
-        acc_score = error.acc(estim_prev)
-        f1_score = error.f1(estim_prev)
+    method_name = inspect.stack()[1].function
+    report = EvaluationReport(name=method_name)
+    for sample in protocol():
+        e_sample = estimator.extend(sample)
+        estim_prev = estimator.estimate(e_sample.X, ext=True)
+        acc_score = qc.error.acc(estim_prev)
+        f1_score = qc.error.f1(estim_prev)
         report.append_row(
-            base_prev,
+            sample.prevalence(),
             acc_score=acc_score,
-            acc=abs(metrics.accuracy_score(label, pred) - acc_score),
+            acc=abs(qc.error.acc(e_sample.prevalence()) - acc_score),
             f1_score=f1_score,
-            f1=abs(error.f1(true_prev) - f1_score),
+            f1=abs(qc.error.f1(e_sample.prevalence()) - f1_score),
         )

-    report.fit_score = estimator.fit_score
-
     return report


-def evaluate(
-    c_model: BaseEstimator,
-    validation: LabelledCollection,
-    protocol: AbstractStochasticSeededProtocol,
-    method: str,
-    q_model: str,
-    **kwargs,
-):
-    estimator: AccuracyEstimator = {
-        "bin": BinaryQuantifierAccuracyEstimator,
-        "mul": MulticlassAccuracyEstimator,
-    }[method](c_model, q_model=q_model.upper(), **kwargs)
-    estimator.fit(validation)
-    _method = f"{method}_{q_model}"
-    if "recalib" in kwargs:
-        _method += f"_{kwargs['recalib']}"
-    if ("gs", True) in kwargs.items():
-        _method += "_gs"
-    return evaluation_report(estimator, protocol, _method)
-
-
 @method
 def bin_sld(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "bin", "sld")
+    est = BQAE(c_model, SLD(LogisticRegression())).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )


 @method
 def mul_sld(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "mul", "sld")
+    est = MCAE(c_model, SLD(LogisticRegression())).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )


 @method
-def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "bin", "sld", recalib="bcts")
+def binmc_sld(c_model, validation, protocol) -> EvaluationReport:
+    est = BQAE(
+        c_model,
+        SLD(LogisticRegression()),
+        confidence="max_conf",
+    ).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )


 @method
-def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "mul", "sld", recalib="bcts")
+def mulmc_sld(c_model, validation, protocol) -> EvaluationReport:
+    est = MCAE(
+        c_model,
+        SLD(LogisticRegression()),
+        confidence="max_conf",
+    ).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )


+@method
+def binne_sld(c_model, validation, protocol) -> EvaluationReport:
+    est = BQAE(
+        c_model,
+        SLD(LogisticRegression()),
+        confidence="entropy",
+    ).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def mulne_sld(c_model, validation, protocol) -> EvaluationReport:
+    est = MCAE(
+        c_model,
+        SLD(LogisticRegression()),
+        confidence="entropy",
+    ).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
 @method
 def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "bin", "sld", gs=True)
+    v_train, v_val = validation.split_stratified(0.6, random_state=0)
+    model = BQAE(c_model, SLD(LogisticRegression()))
+    est = GridSearchAE(
+        model=model,
+        param_grid=_sld_param_grid,
+        refit=False,
+        protocol=UPP(v_val, repeats=100),
+        verbose=True,
+    ).fit(v_train)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )


 @method
 def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "mul", "sld", gs=True)
+    v_train, v_val = validation.split_stratified(0.6, random_state=0)
+    model = MCAE(c_model, SLD(LogisticRegression()))
+    est = GridSearchAE(
+        model=model,
+        param_grid=_sld_param_grid,
+        refit=False,
+        protocol=UPP(v_val, repeats=100),
+        verbose=True,
+    ).fit(v_train)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )


+@method
+def bin_sld_gsq(c_model, validation, protocol) -> EvaluationReport:
+    est = BQAEgsq(
+        c_model,
+        SLD(LogisticRegression()),
+        param_grid={
+            "classifier__C": np.logspace(-3, 3, 7),
+            "classifier__class_weight": [None, "balanced"],
+            "recalib": [None, "bcts", "vs"],
+        },
+        refit=False,
+        verbose=False,
+    ).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def mul_sld_gsq(c_model, validation, protocol) -> EvaluationReport:
+    est = MCAEgsq(
+        c_model,
+        SLD(LogisticRegression()),
+        param_grid={
+            "classifier__C": np.logspace(-3, 3, 7),
+            "classifier__class_weight": [None, "balanced"],
+            "recalib": [None, "bcts", "vs"],
+        },
+        refit=False,
+        verbose=False,
+    ).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def bin_pacc(c_model, validation, protocol) -> EvaluationReport:
+    est = BQAE(c_model, PACC(LogisticRegression())).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def mul_pacc(c_model, validation, protocol) -> EvaluationReport:
+    est = MCAE(c_model, PACC(LogisticRegression())).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def binmc_pacc(c_model, validation, protocol) -> EvaluationReport:
+    est = BQAE(c_model, PACC(LogisticRegression()), confidence="max_conf").fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def mulmc_pacc(c_model, validation, protocol) -> EvaluationReport:
+    est = MCAE(c_model, PACC(LogisticRegression()), confidence="max_conf").fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def binne_pacc(c_model, validation, protocol) -> EvaluationReport:
+    est = BQAE(c_model, PACC(LogisticRegression()), confidence="entropy").fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def mulne_pacc(c_model, validation, protocol) -> EvaluationReport:
+    est = MCAE(c_model, PACC(LogisticRegression()), confidence="entropy").fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def bin_pacc_gs(c_model, validation, protocol) -> EvaluationReport:
+    v_train, v_val = validation.split_stratified(0.6, random_state=0)
+    model = BQAE(c_model, PACC(LogisticRegression()))
+    est = GridSearchAE(
+        model=model,
+        param_grid=_pacc_param_grid,
+        refit=False,
+        protocol=UPP(v_val, repeats=100),
+        verbose=False,
+    ).fit(v_train)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
+@method
+def mul_pacc_gs(c_model, validation, protocol) -> EvaluationReport:
+    v_train, v_val = validation.split_stratified(0.6, random_state=0)
+    model = MCAE(c_model, PACC(LogisticRegression()))
+    est = GridSearchAE(
+        model=model,
+        param_grid=_pacc_param_grid,
+        refit=False,
+        protocol=UPP(v_val, repeats=100),
+        verbose=False,
+    ).fit(v_train)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
+
+
 @method
 def bin_cc(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "bin", "cc")
+    est = BQAE(c_model, CC(LogisticRegression())).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )


 @method
 def mul_cc(c_model, validation, protocol) -> EvaluationReport:
-    return evaluate(c_model, validation, protocol, "mul", "cc")
+    est = MCAE(c_model, CC(LogisticRegression())).fit(validation)
+    return evaluation_report(
+        estimator=est,
+        protocol=protocol,
+    )
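Note on the `_sld_param_grid`/`_pacc_param_grid` keys used above: the `q__` prefix is not itself a quantifier attribute; `GridSearchAE` normalizes it to `quantifier__` before expanding the grid, so `q__classifier__C` ultimately reaches the `C` of the `LogisticRegression` inside the wrapped quantifier. A sketch of that remapping, equivalent in spirit to `GridSearchAE.__normalize_params` further down in this diff:

# Sketch only: reproduces the q__ -> quantifier__ key rewrite.
params = {"q__classifier__C": 1.0, "confidence": "entropy"}
remapped = {
    (f"quantifier__{k.partition('__')[2]}" if k.startswith("q__") else k): v
    for k, v in params.items()
}
# remapped == {"quantifier__classifier__C": 1.0, "confidence": "entropy"}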
@@ -115,7 +115,7 @@ class CompReport:

         shift_data = self._data.copy()
         shift_data.index = pd.MultiIndex.from_arrays([shift_idx_0, shift_idx_1])
-        shift_data.sort_index(axis=0, level=0)
+        shift_data = shift_data.sort_index(axis=0, level=0)

         _metric = _get_metric(metric)
         _estimators = _get_estimators(estimators, shift_data.columns.unique(1))
@@ -182,22 +182,21 @@ class CompReport:
                 train_prev=self.train_prev,
             )
         elif mode == "shift":
-            shift_data = (
-                self.shift_data(metric=metric, estimators=estimators)
-                .groupby(level=0)
-                .mean()
-            )
+            _shift_data = self.shift_data(metric=metric, estimators=estimators)
+            shift_avg = _shift_data.groupby(level=0).mean()
+            shift_counts = _shift_data.groupby(level=0).count()
             shift_prevs = np.around(
-                [(1.0 - p, p) for p in np.sort(shift_data.index.unique(0))],
+                [(1.0 - p, p) for p in np.sort(shift_avg.index.unique(0))],
                 decimals=2,
             )
             return plot.plot_shift(
                 shift_prevs=shift_prevs,
-                columns=shift_data.columns.to_numpy(),
-                data=shift_data.T.to_numpy(),
+                columns=shift_avg.columns.to_numpy(),
+                data=shift_avg.T.to_numpy(),
                 metric=metric,
                 name=conf,
                 train_prev=self.train_prev,
+                counts=shift_counts.T.to_numpy(),
             )

     def to_md(self, conf="default", metric="acc", estimators=None, stdev=False) -> str:
@@ -246,7 +245,7 @@ class DatasetReport:
         )
         _crs_train, _crs_data = zip(*_crs_sorted)

-        _data = pd.concat(_crs_data, axis=0, keys=_crs_train)
+        _data = pd.concat(_crs_data, axis=0, keys=np.around(_crs_train, decimals=2))
         _data = _data.sort_index(axis=0, level=0)
         return _data

@@ -296,47 +295,95 @@ class DatasetReport:
         _data = self.data(metric=metric, estimators=estimators)
         _shift_data = self.shift_data(metric=metric, estimators=estimators)

-        avg_x_test = _data.groupby(level=1).mean()
-        prevs_x_test = np.sort(avg_x_test.index.unique(0))
-        stdev_x_test = _data.groupby(level=1).std() if stdev else None
-        avg_x_test_tbl = _data.groupby(level=1).mean()
-        avg_x_test_tbl.loc["avg", :] = _data.mean()
-
-        avg_x_shift = _shift_data.groupby(level=0).mean()
-        prevs_x_shift = np.sort(avg_x_shift.index.unique(0))
-
         res += "## avg\n"
-        res += avg_x_test_tbl.to_html() + "\n\n"
+
+        ######################## avg on train ########################
+        res += "### avg on train\n"
+
+        avg_on_train = _data.groupby(level=1).mean()
+        prevs_on_train = np.sort(avg_on_train.index.unique(0))
+        stdev_on_train = _data.groupby(level=1).std() if stdev else None
+        avg_on_train_tbl = _data.groupby(level=1).mean()
+        avg_on_train_tbl.loc["avg", :] = _data.mean()
+
+        res += avg_on_train_tbl.to_html() + "\n\n"
+
         delta_op = plot.plot_delta(
-            base_prevs=np.around([(1.0 - p, p) for p in prevs_x_test], decimals=2),
-            columns=avg_x_test.columns.to_numpy(),
-            data=avg_x_test.T.to_numpy(),
+            base_prevs=np.around([(1.0 - p, p) for p in prevs_on_train], decimals=2),
+            columns=avg_on_train.columns.to_numpy(),
+            data=avg_on_train.T.to_numpy(),
             metric=metric,
             name=conf,
             train_prev=None,
+            avg="train",
         )
         res += f".as_posix()})\n"

         if stdev:
             delta_stdev_op = plot.plot_delta(
-                base_prevs=np.around([(1.0 - p, p) for p in prevs_x_test], decimals=2),
-                columns=avg_x_test.columns.to_numpy(),
-                data=avg_x_test.T.to_numpy(),
+                base_prevs=np.around(
+                    [(1.0 - p, p) for p in prevs_on_train], decimals=2
+                ),
+                columns=avg_on_train.columns.to_numpy(),
+                data=avg_on_train.T.to_numpy(),
                 metric=metric,
                 name=conf,
                 train_prev=None,
-                stdevs=stdev_x_test.T.to_numpy(),
+                stdevs=stdev_on_train.T.to_numpy(),
+                avg="train",
             )
             res += f".as_posix()})\n"

-        shift_op = plot.plot_shift(
-            shift_prevs=np.around([(1.0 - p, p) for p in prevs_x_shift], decimals=2),
-            columns=avg_x_shift.columns.to_numpy(),
-            data=avg_x_shift.T.to_numpy(),
+        ######################## avg on test ########################
+        res += "### avg on test\n"
+
+        avg_on_test = _data.groupby(level=0).mean()
+        prevs_on_test = np.sort(avg_on_test.index.unique(0))
+        stdev_on_test = _data.groupby(level=0).std() if stdev else None
+        avg_on_test_tbl = _data.groupby(level=0).mean()
+        avg_on_test_tbl.loc["avg", :] = _data.mean()
+
+        res += avg_on_test_tbl.to_html() + "\n\n"
+
+        delta_op = plot.plot_delta(
+            base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
+            columns=avg_on_test.columns.to_numpy(),
+            data=avg_on_test.T.to_numpy(),
             metric=metric,
             name=conf,
             train_prev=None,
+            avg="test",
+        )
+        res += f".as_posix()})\n"
+
+        if stdev:
+            delta_stdev_op = plot.plot_delta(
+                base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
+                columns=avg_on_test.columns.to_numpy(),
+                data=avg_on_test.T.to_numpy(),
+                metric=metric,
+                name=conf,
+                train_prev=None,
+                stdevs=stdev_on_test.T.to_numpy(),
+                avg="test",
+            )
+            res += f".as_posix()})\n"
+
+        ######################## avg shift ########################
+        res += "### avg dataset shift\n"
+
+        avg_shift = _shift_data.groupby(level=0).mean()
+        count_shift = _shift_data.groupby(level=0).count()
+        prevs_shift = np.sort(avg_shift.index.unique(0))
+
+        shift_op = plot.plot_shift(
+            shift_prevs=np.around([(1.0 - p, p) for p in prevs_shift], decimals=2),
+            columns=avg_shift.columns.to_numpy(),
+            data=avg_shift.T.to_numpy(),
+            metric=metric,
+            name=conf,
+            train_prev=None,
+            counts=count_shift.T.to_numpy(),
         )
         res += f".as_posix()})\n"

@@ -27,7 +27,7 @@ def estimate_worker(_estimate, train, validation, test, _env=None, q=None):
         result = _estimate(model, validation, protocol)
     except Exception as e:
         log.warning(f"Method {_estimate.__name__} failed. Exception: {e}")
-        # traceback(e)
+        traceback(e)
         return {
             "name": _estimate.__name__,
             "result": None,
@@ -2,6 +2,7 @@ import logging
 import logging.handlers
 import multiprocessing
 import threading
+from pathlib import Path


 class Logger:
@@ -11,6 +12,7 @@ class Logger:
     __queue = None
     __thread = None
     __setup = False
+    __handlers = []

     @classmethod
     def __logger_listener(cls, q):
@@ -32,7 +34,6 @@ class Logger:
         rh = logging.FileHandler(cls.__logger_file, mode="a")
         rh.setLevel(logging.DEBUG)
         root.addHandler(rh)
-        root.info("-" * 100)

         # setup logger
         if cls.__manager is None:
@@ -62,6 +63,21 @@ class Logger:

         cls.__setup = True

+    @classmethod
+    def add_handler(cls, path: Path):
+        root = logging.getLogger("listener")
+        rh = logging.FileHandler(path, mode="a")
+        rh.setLevel(logging.DEBUG)
+        cls.__handlers.append(rh)
+        root.addHandler(rh)
+
+    @classmethod
+    def clear_handlers(cls):
+        root = logging.getLogger("listener")
+        for h in cls.__handlers:
+            root.removeHandler(h)
+        cls.__handlers.clear()
+
     @classmethod
     def queue(cls):
         if not cls.__setup:
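These two classmethods let a caller attach a dedicated file handler for one unit of work and detach it afterwards; the driver script changed later in this diff uses exactly this around each dataset run. A sketch of the pattern, with an illustrative path:

# Sketch only: the path and the try/finally framing are illustrative.
from pathlib import Path

Logger.add_handler(Path("output") / "imdb.log")  # per-dataset log file
try:
    ...  # run the evaluation for this dataset
finally:
    Logger.clear_handlers()  # detach so the next run logs elsewhere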
@@ -79,6 +95,8 @@ class Logger:
     @classmethod
     def close(cls):
         if cls.__setup and cls.__thread is not None:
+            root = logging.getLogger("listener")
+            root.info("-" * 100)
             cls.__queue.put(None)
             cls.__thread.join()
             # cls.__manager.close()
@@ -102,7 +120,7 @@ class SubLogger:
         rh.setLevel(logging.DEBUG)
         rh.setFormatter(
             logging.Formatter(
-                fmt="%(asctime)s| %(levelname)-12s\t%(message)s",
+                fmt="%(asctime)s| %(levelname)-12s%(message)s",
                 datefmt="%d/%m/%y %H:%M:%S",
             )
         )
@@ -7,6 +7,8 @@ from quacc.environment import env
 from quacc.logger import Logger
 from quacc.utils import create_dataser_dir

+CE = comp.CompEstimator()
+

 def toast():
     if platform == "win32":
@@ -25,8 +27,12 @@ def estimate_comparison():
         prevs=env.DATASET_PREVS,
     )
     create_dataser_dir(dataset.name, update=env.DATASET_DIR_UPDATE)
+    Logger.add_handler(env.OUT_DIR / f"{dataset.name}.log")
     try:
-        dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
+        dr = comp.evaluate_comparison(
+            dataset,
+            estimators=CE.name[env.COMP_ESTIMATORS],
+        )
     except Exception as e:
         log.error(f"Evaluation over {dataset.name} failed. Exception: {e}")
         traceback(e)
@@ -37,7 +43,7 @@ def estimate_comparison():
             _repr = dr.to_md(
                 conf=plot_conf,
                 metric=m,
-                estimators=env.PLOT_ESTIMATORS,
+                estimators=CE.name[env.PLOT_ESTIMATORS],
                 stdev=env.PLOT_STDEV,
             )
             with open(output_path, "w") as f:
@@ -47,6 +53,7 @@ def estimate_comparison():
                 f"Failed while saving configuration {plot_conf} of {dataset.name}. Exception: {e}"
             )
             traceback(e)
+    Logger.clear_handlers()

     # print(df.to_latex(float_format="{:.4f}".format))
     # print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
@@ -0,0 +1,120 @@
+from copy import deepcopy
+from time import time
+
+import numpy as np
+import win11toast
+from quapy.method.aggregative import SLD
+from quapy.protocol import APP, UPP
+from sklearn.linear_model import LogisticRegression
+
+import quacc as qc
+from quacc.dataset import Dataset
+from quacc.error import acc
+from quacc.evaluation.baseline import ref
+from quacc.evaluation.method import mulmc_sld
+from quacc.evaluation.report import CompReport, EvaluationReport
+from quacc.method.base import MCAE, BinaryQuantifierAccuracyEstimator
+from quacc.method.model_selection import GridSearchAE
+
+
+def test_gs():
+    d = Dataset(name="rcv1", target="CCAT", n_prevalences=1).get_raw()
+
+    classifier = LogisticRegression()
+    classifier.fit(*d.train.Xy)
+
+    quantifier = SLD(LogisticRegression())
+    # estimator = MultiClassAccuracyEstimator(classifier, quantifier)
+    estimator = BinaryQuantifierAccuracyEstimator(classifier, quantifier)
+
+    v_train, v_val = d.validation.split_stratified(0.6, random_state=0)
+    gs_protocol = UPP(v_val, sample_size=1000, repeats=100)
+    gs_estimator = GridSearchAE(
+        model=deepcopy(estimator),
+        param_grid={
+            "q__classifier__C": np.logspace(-3, 3, 7),
+            "q__classifier__class_weight": [None, "balanced"],
+            "q__recalib": [None, "bcts", "ts"],
+        },
+        refit=False,
+        protocol=gs_protocol,
+        verbose=True,
+    ).fit(v_train)
+
+    estimator.fit(d.validation)
+
+    tstart = time()
+    erb, ergs = EvaluationReport("base"), EvaluationReport("gs")
+    protocol = APP(
+        d.test,
+        sample_size=1000,
+        n_prevalences=21,
+        repeats=100,
+        return_type="labelled_collection",
+    )
+    for sample in protocol():
+        e_sample = gs_estimator.extend(sample)
+        estim_prev_b = estimator.estimate(e_sample.X, ext=True)
+        estim_prev_gs = gs_estimator.estimate(e_sample.X, ext=True)
+        erb.append_row(
+            sample.prevalence(),
+            acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_b)),
+        )
+        ergs.append_row(
+            sample.prevalence(),
+            acc=abs(acc(e_sample.prevalence()) - acc(estim_prev_gs)),
+        )
+
+    cr = CompReport(
+        [erb, ergs],
+        "test",
+        train_prev=d.train_prev,
+        valid_prev=d.validation_prev,
+    )
+
+    print(cr.table())
+    print(f"[took {time() - tstart:.3f}s]")
+    win11toast.notify("Test", "completed")
+
+
+def test_mc():
+    d = Dataset(name="rcv1", target="CCAT", prevs=[0.9]).get()[0]
+    classifier = LogisticRegression().fit(*d.train.Xy)
+    protocol = APP(
+        d.test,
+        sample_size=1000,
+        repeats=100,
+        n_prevalences=21,
+        return_type="labelled_collection",
+    )
+
+    ref_er = ref(classifier, d.validation, protocol)
+    mulmc_er = mulmc_sld(classifier, d.validation, protocol)
+
+    cr = CompReport(
+        [mulmc_er, ref_er],
+        name="test_mc",
+        train_prev=d.train_prev,
+        valid_prev=d.validation_prev,
+    )
+
+    with open("test_mc.md", "w") as f:
+        f.write(cr.data().to_markdown())
+
+
+def test_et():
+    d = Dataset(name="imdb", prevs=[0.5]).get()[0]
+    classifier = LogisticRegression().fit(*d.train.Xy)
+    estimator = MCAE(
+        classifier,
+        SLD(LogisticRegression(), exact_train_prev=False),
+        confidence="max_conf",
+    ).fit(d.validation)
+    e_test = estimator.extend(d.test)
+    ep = estimator.estimate(e_test.X, ext=True)
+    print(f"{qc.error.acc(ep) = }")
+    print(f"{qc.error.acc(e_test.prevalence()) = }")
+
+
+if __name__ == "__main__":
+    test_et()
@@ -0,0 +1,177 @@
+import math
+from abc import abstractmethod
+from copy import deepcopy
+from typing import List
+
+import numpy as np
+from quapy.data import LabelledCollection
+from quapy.method.aggregative import BaseQuantifier
+from scipy.sparse import csr_matrix
+from sklearn.base import BaseEstimator
+
+from quacc.data import ExtendedCollection
+
+
+class BaseAccuracyEstimator(BaseQuantifier):
+    def __init__(
+        self,
+        classifier: BaseEstimator,
+        quantifier: BaseQuantifier,
+        confidence=None,
+    ):
+        self.__check_classifier(classifier)
+        self.quantifier = quantifier
+        self.confidence = confidence
+
+    def __check_classifier(self, classifier):
+        if not hasattr(classifier, "predict_proba"):
+            raise ValueError(
+                f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
+            )
+        self.classifier = classifier
+
+    def __get_confidence(self):
+        def max_conf(probas):
+            _mc = np.max(probas, axis=-1)
+            _min = 1.0 / probas.shape[1]
+            _norm_mc = (_mc - _min) / (1.0 - _min)
+            return _norm_mc
+
+        def entropy(probas):
+            _ent = np.sum(np.multiply(probas, np.log(probas + 1e-20)), axis=1)
+            return _ent
+
+        if self.confidence is None:
+            return None
+
+        __confs = {
+            "max_conf": max_conf,
+            "entropy": entropy,
+        }
+        return __confs.get(self.confidence, None)
+
+    def __get_ext(self, pred_proba):
+        _ext = pred_proba
+        _f_conf = self.__get_confidence()
+        if _f_conf is not None:
+            _confs = _f_conf(pred_proba).reshape((len(pred_proba), 1))
+            _ext = np.concatenate((_confs, pred_proba), axis=1)
+
+        return _ext
+
+    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
+        if pred_proba is None:
+            pred_proba = self.classifier.predict_proba(coll.X)
+
+        _ext = self.__get_ext(pred_proba)
+        return ExtendedCollection.extend_collection(coll, pred_proba=_ext)
+
+    def _extend_instances(self, instances: np.ndarray | csr_matrix, pred_proba=None):
+        if pred_proba is None:
+            pred_proba = self.classifier.predict_proba(instances)
+
+        _ext = self.__get_ext(pred_proba)
+        return ExtendedCollection.extend_instances(instances, _ext)
+
+    @abstractmethod
+    def fit(self, train: LabelledCollection | ExtendedCollection):
+        ...
+
+    @abstractmethod
+    def estimate(self, instances, ext=False) -> np.ndarray:
+        ...
+
+
+class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
+    def __init__(
+        self,
+        classifier: BaseEstimator,
+        quantifier: BaseQuantifier,
+        confidence: str = None,
+    ):
+        super().__init__(
+            classifier=classifier,
+            quantifier=quantifier,
+            confidence=confidence,
+        )
+        self.e_train = None
+
+    def fit(self, train: LabelledCollection):
+        self.e_train = self.extend(train)
+
+        self.quantifier.fit(self.e_train)
+
+        return self
+
+    def estimate(self, instances, ext=False) -> np.ndarray:
+        e_inst = instances if ext else self._extend_instances(instances)
+
+        estim_prev = self.quantifier.quantify(e_inst)
+        return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)
+
+    def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
+        true_classes = self.e_train.classes_
+        for _cls in true_classes:
+            if _cls not in estim_classes:
+                estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
+        return estim_prev
+
+
+class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
+    def __init__(
+        self,
+        classifier: BaseEstimator,
+        quantifier: BaseAccuracyEstimator,
+        confidence: str = None,
+    ):
+        super().__init__(
+            classifier=classifier,
+            quantifier=quantifier,
+            confidence=confidence,
+        )
+        self.quantifiers = []
+        self.e_trains = []
+
+    def fit(self, train: LabelledCollection | ExtendedCollection):
+        self.e_train = self.extend(train)
+
+        self.n_classes = self.e_train.n_classes
+        self.e_trains = self.e_train.split_by_pred()
+
+        self.quantifiers = []
+        for train in self.e_trains:
+            quant = deepcopy(self.quantifier)
+            quant.fit(train)
+            self.quantifiers.append(quant)
+
+        return self
+
+    def estimate(self, instances, ext=False):
+        # TODO: test
+        e_inst = instances if ext else self._extend_instances(instances)
+
+        _ncl = int(math.sqrt(self.n_classes))
+        s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
+        estim_prevs = self._quantify_helper(s_inst, norms)
+
+        estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
+        return estim_prev
+
+    def _quantify_helper(
+        self,
+        s_inst: List[np.ndarray | csr_matrix],
+        norms: List[float],
+    ):
+        estim_prevs = []
+        for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
+            if inst.shape[0] > 0:
+                estim_prevs.append(quant.quantify(inst) * norm)
+            else:
+                estim_prevs.append(np.asarray([0.0, 0.0]))
+
+        return estim_prevs
+
+
+BAE = BaseAccuracyEstimator
+MCAE = MultiClassAccuracyEstimator
+BQAE = BinaryQuantifierAccuracyEstimator
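For context on the `int(math.sqrt(self.n_classes))` in `BinaryQuantifierAccuracyEstimator.estimate`: the extended collection encodes each (true, predicted) pair as a single class, `true * n_classes + pred` (the encoding is documented in the legacy module removed near the end of this diff), so a binary task yields four extended classes and the square root recovers the original class count. A worked sketch:

# Binary case: 4 extended classes, 0 ~ true 0, 1 ~ false 1, 2 ~ false 0,
# 3 ~ true 1 (labels taken from the removed module's comment).
n_classes = 2
for true in range(n_classes):
    for pred in range(n_classes):
        print(f"(true={true}, pred={pred}) -> extended class {true * n_classes + pred}")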
@@ -0,0 +1,307 @@
+import itertools
+from copy import deepcopy
+from time import time
+from typing import Callable, Union
+
+import numpy as np
+import quapy as qp
+from quapy.data import LabelledCollection
+from quapy.model_selection import GridSearchQ
+from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
+from sklearn.base import BaseEstimator
+
+import quacc as qc
+import quacc.error
+from quacc.data import ExtendedCollection
+from quacc.evaluation import evaluate
+from quacc.logger import SubLogger
+from quacc.method.base import (
+    BaseAccuracyEstimator,
+    BinaryQuantifierAccuracyEstimator,
+    MultiClassAccuracyEstimator,
+)
+
+
+class GridSearchAE(BaseAccuracyEstimator):
+    def __init__(
+        self,
+        model: BaseAccuracyEstimator,
+        param_grid: dict,
+        protocol: AbstractProtocol,
+        error: Union[Callable, str] = qc.error.maccd,
+        refit=True,
+        # timeout=-1,
+        # n_jobs=None,
+        verbose=False,
+    ):
+        self.model = model
+        self.param_grid = self.__normalize_params(param_grid)
+        self.protocol = protocol
+        self.refit = refit
+        # self.timeout = timeout
+        # self.n_jobs = qp._get_njobs(n_jobs)
+        self.verbose = verbose
+        self.__check_error(error)
+        assert isinstance(protocol, AbstractProtocol), "unknown protocol"
+
+    def _sout(self, msg):
+        if self.verbose:
+            print(f"[{self.__class__.__name__}]: {msg}")
+
+    def __normalize_params(self, params):
+        __remap = {}
+        for key in params.keys():
+            k, delim, sub_key = key.partition("__")
+            if delim and k == "q":
+                __remap[key] = f"quantifier__{sub_key}"
+
+        return {(__remap[k] if k in __remap else k): v for k, v in params.items()}
+
+    def __check_error(self, error):
+        if error in qc.error.ACCURACY_ERROR:
+            self.error = error
+        elif isinstance(error, str):
+            self.error = qc.error.from_name(error)
+        elif hasattr(error, "__call__"):
+            self.error = error
+        else:
+            raise ValueError(
+                f"unexpected error type; must either be a callable function or a str representing\n"
+                f"the name of an error function in {qc.error.ACCURACY_ERROR_NAMES}"
+            )
+
+    def fit(self, training: LabelledCollection):
+        """Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
+        the error metric.
+
+        :param training: the training set on which to optimize the hyperparameters
+        :return: self
+        """
+        params_keys = list(self.param_grid.keys())
+        params_values = list(self.param_grid.values())
+
+        protocol = self.protocol
+
+        self.param_scores_ = {}
+        self.best_score_ = None
+
+        tinit = time()
+
+        hyper = [
+            dict(zip(params_keys, val)) for val in itertools.product(*params_values)
+        ]
+
+        # self._sout(f"starting model selection with {self.n_jobs =}")
+        self._sout("starting model selection")
+
+        scores = [self.__params_eval(params, training) for params in hyper]
+
+        for params, score, model in scores:
+            if score is not None:
+                if self.best_score_ is None or score < self.best_score_:
+                    self.best_score_ = score
+                    self.best_params_ = params
+                    self.best_model_ = model
+                self.param_scores_[str(params)] = score
+            else:
+                self.param_scores_[str(params)] = "timeout"
+
+        tend = time() - tinit
+
+        if self.best_score_ is None:
+            raise TimeoutError("no combination of hyperparameters seem to work")
+
+        self._sout(
+            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
+            f"[took {tend:.4f}s]"
+        )
+        log = SubLogger.logger()
+        log.debug(
+            f"[{self.model.__class__.__name__}] "
+            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
+            f"[took {tend:.4f}s]"
+        )
+
+        if self.refit:
+            if isinstance(protocol, OnLabelledCollectionProtocol):
+                self._sout("refitting on the whole development set")
+                self.best_model_.fit(training + protocol.get_labelled_collection())
+            else:
+                raise RuntimeWarning(
+                    f'"refit" was requested, but the protocol does not '
+                    f"implement the {OnLabelledCollectionProtocol.__name__} interface"
+                )
+
+        return self
+
+    def __params_eval(self, params, training):
+        protocol = self.protocol
+        error = self.error
+
+        # if self.timeout > 0:
+
+        #     def handler(signum, frame):
+        #         raise TimeoutError()
+
+        #     signal.signal(signal.SIGALRM, handler)
+
+        tinit = time()
+
+        # if self.timeout > 0:
+        #     signal.alarm(self.timeout)
+
+        try:
+            model = deepcopy(self.model)
+            # overrides default parameters with the parameters being explored at this iteration
+            model.set_params(**params)
+            # print({k: v for k, v in model.get_params().items() if k in params})
+            model.fit(training)
+            score = evaluate(model, protocol=protocol, error_metric=error)
+
+            ttime = time() - tinit
+            self._sout(
+                f"hyperparams={params}\t got score {score:.5f} [took {ttime:.4f}s]"
+            )
+
+            # if self.timeout > 0:
+            #     signal.alarm(0)
+        # except TimeoutError:
+        #     self._sout(f"timeout ({self.timeout}s) reached for config {params}")
+        #     score = None
+        except ValueError as e:
+            self._sout(f"the combination of hyperparameters {params} is invalid")
+            raise e
+        except Exception as e:
+            self._sout(f"something went wrong for config {params}; skipping:")
+            self._sout(f"\tException: {e}")
+            score = None
+
+        return params, score, model
+
+    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
+        assert hasattr(self, "best_model_"), "quantify called before fit"
+        return self.best_model().extend(coll, pred_proba=pred_proba)
+
+    def estimate(self, instances, ext=False):
+        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
+
+        :param instances: sample contanining the instances
+        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
+            by the model selection process.
+        """
+        assert hasattr(self, "best_model_"), "estimate called before fit"
+        return self.best_model().estimate(instances, ext=ext)
+
+    def set_params(self, **parameters):
+        """Sets the hyper-parameters to explore.
+
+        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
+        """
+        self.param_grid = parameters
+
+    def get_params(self, deep=True):
+        """Returns the dictionary of hyper-parameters to explore (`param_grid`)
+
+        :param deep: Unused
+        :return: the dictionary `param_grid`
+        """
+        return self.param_grid
+
+    def best_model(self):
+        """
+        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
+        of hyper-parameters that minimized the error function.
+
+        :return: a trained quantifier
+        """
+        if hasattr(self, "best_model_"):
+            return self.best_model_
+        raise ValueError("best_model called before fit")
+
+
+class MCAEgsq(MultiClassAccuracyEstimator):
+    def __init__(
+        self,
+        classifier: BaseEstimator,
+        quantifier: BaseAccuracyEstimator,
+        param_grid: dict,
+        error: Union[Callable, str] = qp.error.mae,
+        refit=True,
+        timeout=-1,
+        n_jobs=None,
+        verbose=False,
+    ):
+        self.param_grid = param_grid
+        self.refit = refit
+        self.timeout = timeout
+        self.n_jobs = n_jobs
+        self.verbose = verbose
+        self.error = error
+        super().__init__(classifier, quantifier)
+
+    def fit(self, train: LabelledCollection):
+        self.e_train = self.extend(train)
+        t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
+        self.quantifier = GridSearchQ(
+            deepcopy(self.quantifier),
+            param_grid=self.param_grid,
+            protocol=UPP(t_val, repeats=100),
+            error=self.error,
+            refit=self.refit,
+            timeout=self.timeout,
+            n_jobs=self.n_jobs,
+            verbose=self.verbose,
+        ).fit(self.e_train)
+
+        return self
+
+    def estimate(self, instances, ext=False) -> np.ndarray:
+        e_inst = instances if ext else self._extend_instances(instances)
+        estim_prev = self.quantifier.quantify(e_inst)
+        return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_)
+
+
+class BQAEgsq(BinaryQuantifierAccuracyEstimator):
+    def __init__(
+        self,
+        classifier: BaseEstimator,
+        quantifier: BaseAccuracyEstimator,
+        param_grid: dict,
+        error: Union[Callable, str] = qp.error.mae,
+        refit=True,
+        timeout=-1,
+        n_jobs=None,
+        verbose=False,
+    ):
+        self.param_grid = param_grid
+        self.refit = refit
+        self.timeout = timeout
+        self.n_jobs = n_jobs
+        self.verbose = verbose
+        self.error = error
+        super().__init__(classifier=classifier, quantifier=quantifier)
+
+    def fit(self, train: LabelledCollection):
+        self.e_train = self.extend(train)
+
+        self.n_classes = self.e_train.n_classes
+        self.e_trains = self.e_train.split_by_pred()
+
+        self.quantifiers = []
+        for e_train in self.e_trains:
+            t_train, t_val = e_train.split_stratified(0.6, random_state=0)
+            quantifier = GridSearchQ(
+                model=deepcopy(self.quantifier),
+                param_grid=self.param_grid,
+                protocol=UPP(t_val, repeats=100),
+                error=self.error,
+                refit=self.refit,
+                timeout=self.timeout,
+                n_jobs=self.n_jobs,
+                verbose=self.verbose,
+            ).fit(t_train)
+            self.quantifiers.append(quantifier)
+
+        return self
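The two selection routes added in this file differ in what gets searched: `GridSearchAE` explores whole accuracy-estimator configurations (the confidence feature included) and scores them with an accuracy-estimation error via `evaluate`, while `MCAEgsq`/`BQAEgsq` hand the inner quantifier to quapy's `GridSearchQ` and optimize a quantification error. A condensed comparison, assuming hypothetical `classifier`, `v_train`, `v_val`, and `validation` objects:

# Sketch only; the data objects are placeholders, the module paths follow
# the imports in this diff.
import numpy as np
import quapy as qp
from quapy.method.aggregative import SLD
from quapy.protocol import UPP
from sklearn.linear_model import LogisticRegression

from quacc.method.base import MCAE
from quacc.method.model_selection import GridSearchAE, MCAEgsq

# Route 1: search over the full estimator with GridSearchAE.
ae = GridSearchAE(
    model=MCAE(classifier, SLD(LogisticRegression())),
    param_grid={
        "q__classifier__C": np.logspace(-3, 3, 7),
        "confidence": [None, "entropy"],
    },
    protocol=UPP(v_val, repeats=100),
    refit=False,
).fit(v_train)

# Route 2: search over the wrapped quantifier only, via GridSearchQ.
aq = MCAEgsq(
    classifier,
    SLD(LogisticRegression()),
    param_grid={"classifier__C": np.logspace(-3, 3, 7)},
    error=qp.error.mae,
    refit=False,
).fit(validation)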
@ -1,138 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
import scipy as sp
|
|
||||||
import quapy as qp
|
|
||||||
from quapy.data import LabelledCollection
|
|
||||||
from quapy.method.aggregative import SLD
|
|
||||||
from quapy.protocol import APP, AbstractStochasticSeededProtocol
|
|
||||||
from sklearn.linear_model import LogisticRegression
|
|
||||||
from sklearn.model_selection import cross_val_predict
|
|
||||||
|
|
||||||
from .data import get_dataset
|
|
||||||
|
|
||||||
# Extended classes
|
|
||||||
#
|
|
||||||
# 0 ~ True 0
|
|
||||||
# 1 ~ False 1
|
|
||||||
# 2 ~ False 0
|
|
||||||
# 3 ~ True 1
|
|
||||||
# _____________________
|
|
||||||
# | | |
|
|
||||||
# | True 0 | False 1 |
|
|
||||||
# |__________|__________|
|
|
||||||
# | | |
|
|
||||||
# | False 0 | True 1 |
|
|
||||||
# |__________|__________|
|
|
||||||
#
|
|
||||||
def get_ex_class(classes, true_class, pred_class):
|
|
||||||
return true_class * classes + pred_class
|
|
||||||
|
|
||||||
|
|
||||||
def extend_collection(coll, pred_prob):
|
|
||||||
n_classes = coll.n_classes
|
|
||||||
|
|
||||||
# n_X = [ X | predicted probs. ]
|
|
||||||
if isinstance(coll.X, sp.csr_matrix):
|
|
||||||
pred_prob_csr = sp.csr_matrix(pred_prob)
|
|
||||||
n_x = sp.hstack([coll.X, pred_prob_csr])
|
|
||||||
elif isinstance(coll.X, np.ndarray):
|
|
||||||
n_x = np.concatenate((coll.X, pred_prob), axis=1)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported matrix format")
|
|
||||||
|
|
||||||
# n_y = (exptected y, predicted y)
|
|
||||||
n_y = []
|
|
||||||
for i, true_class in enumerate(coll.y):
|
|
||||||
pred_class = pred_prob[i].argmax(axis=0)
|
|
||||||
n_y.append(get_ex_class(n_classes, true_class, pred_class))
|
|
||||||
|
|
||||||
return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)])
|
|
||||||
|
|
||||||
|
|
||||||
def qf1e_binary(prev):
|
|
||||||
recall = prev[0] / (prev[0] + prev[1])
|
|
||||||
precision = prev[0] / (prev[0] + prev[2])
|
|
||||||
|
|
||||||
return 1 - 2 * (precision * recall) / (precision + recall)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_errors(true_prev, estim_prev, n_instances):
|
|
||||||
errors = {}
|
|
||||||
_eps = 1 / (2 * n_instances)
|
|
||||||
errors = {
|
|
||||||
"mae": qp.error.mae(true_prev, estim_prev),
|
|
||||||
"rae": qp.error.rae(true_prev, estim_prev, eps=_eps),
|
|
||||||
"mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps),
|
|
||||||
"kld": qp.error.kld(true_prev, estim_prev, eps=_eps),
|
|
||||||
"nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps),
|
|
||||||
"true_f1e": qf1e_binary(true_prev),
|
|
||||||
"estim_f1e": qf1e_binary(estim_prev),
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def extend_and_quantify(
|
|
||||||
model,
|
|
||||||
q_model,
|
|
||||||
train,
|
|
||||||
test: LabelledCollection | AbstractStochasticSeededProtocol,
|
|
||||||
):
|
|
||||||
model.fit(*train.Xy)
|
|
||||||
|
|
||||||
pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba")
|
|
||||||
_train = extend_collection(train, pred_prob_train)
|
|
||||||
|
|
||||||
q_model.fit(_train)
|
|
||||||
|
|
||||||
def quantify_extended(test):
|
|
||||||
pred_prob_test = model.predict_proba(test.X)
|
|
||||||
_test = extend_collection(test, pred_prob_test)
|
|
||||||
_estim_prev = q_model.quantify(_test.instances)
|
|
||||||
# check that _estim_prev has all the classes and eventually fill the missing
|
|
||||||
# ones with 0
|
|
||||||
for _cls in _test.classes_:
|
|
||||||
if _cls not in q_model.classes_:
|
|
||||||
_estim_prev = np.insert(_estim_prev, _cls, [0.0], axis=0)
|
|
||||||
print(_estim_prev)
|
|
||||||
return _test.prevalence(), _estim_prev
|
|
||||||
|
|
||||||
if isinstance(test, LabelledCollection):
|
|
||||||
_true_prev, _estim_prev = quantify_extended(test)
|
|
||||||
_errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0])
|
|
||||||
return ([test.prevalence()], [_true_prev], [_estim_prev], [_errors])
|
|
||||||
|
|
||||||
elif isinstance(test, AbstractStochasticSeededProtocol):
|
|
||||||
orig_prevs, true_prevs, estim_prevs, errors = [], [], [], []
|
|
||||||
for index in test.samples_parameters():
|
|
||||||
sample = test.sample(index)
|
|
||||||
_true_prev, _estim_prev = quantify_extended(sample)
|
|
||||||
|
|
||||||
orig_prevs.append(sample.prevalence())
|
|
||||||
true_prevs.append(_true_prev)
|
|
||||||
estim_prevs.append(_estim_prev)
|
|
||||||
errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0]))
|
|
||||||
|
|
||||||
return orig_prevs, true_prevs, estim_prevs, errors
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_1(dataset_name):
|
|
||||||
train, test = get_dataset(dataset_name)
|
|
||||||
|
|
||||||
orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify(
|
|
||||||
LogisticRegression(),
|
|
||||||
SLD(LogisticRegression()),
|
|
||||||
train,
|
|
||||||
APP(test, n_prevalences=11, repeats=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
for orig_prev, true_prev, estim_prev, _errors in zip(
|
|
||||||
orig_prevs, true_prevs, estim_prevs, errors
|
|
||||||
):
|
|
||||||
print(f"original prevalence:\t{orig_prev}")
|
|
||||||
print(f"true prevalence:\t{true_prev}")
|
|
||||||
print(f"estimated prevalence:\t{estim_prev}")
|
|
||||||
for name, err in _errors.items():
|
|
||||||
print(f"{name}={err:.3f}")
|
|
||||||
print()
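To make the extended-class encoding above concrete, a small worked example; get_ex_class and the F1-error arithmetic are copied from the deleted module, while the prevalence values are invented for illustration.

def get_ex_class(classes, true_class, pred_class):
    return true_class * classes + pred_class

classes = 2
assert get_ex_class(classes, 0, 0) == 0  # True 0
assert get_ex_class(classes, 0, 1) == 1  # False 1
assert get_ex_class(classes, 1, 0) == 2  # False 0
assert get_ex_class(classes, 1, 1) == 3  # True 1

# A prevalence over the 4 extended classes is a flattened, normalized
# confusion matrix, so classifier metrics fall out directly.
prev = [0.45, 0.05, 0.10, 0.40]  # [True 0, False 1, False 0, True 1] (invented)
recall = prev[0] / (prev[0] + prev[1])     # 0.45 / 0.50 = 0.9
precision = prev[0] / (prev[0] + prev[2])  # 0.45 / 0.55 ~= 0.818
f1e = 1 - 2 * (precision * recall) / (precision + recall)  # ~= 0.143, i.e. qf1e_binary(prev)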
@ -27,15 +27,15 @@ def plot_delta(
     metric="acc",
     name="default",
     train_prev=None,
-    fit_scores=None,
     legend=True,
+    avg=None,
 ) -> Path:
     _base_title = "delta_stdev" if stdevs is not None else "delta"
     if train_prev is not None:
         t_prev_pos = int(round(train_prev[pos_class] * 100))
         title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
     else:
-        title = f"{_base_title}_{name}_avg_{metric}"
+        title = f"{_base_title}_{name}_avg_{avg}_{metric}"

     fig, ax = plt.subplots()
     ax.set_aspect("auto")
@ -74,16 +74,13 @@ def plot_delta(
             color=_cy["color"],
             alpha=0.25,
         )
-        if fit_scores is not None and method in fit_scores:
-            ax.plot(
-                base_prevs,
-                np.repeat(fit_scores[method], base_prevs.shape[0]),
-                color=_cy["color"],
-                linestyle="--",
-                markersize=0,
-            )

-    ax.set(xlabel="test prevalence", ylabel=metric, title=title)
+    x_label = "test" if avg is None or avg == "train" else "train"
+    ax.set(
+        xlabel=f"{x_label} prevalence",
+        ylabel=metric,
+        title=title,
+    )

     if legend:
         ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
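The net effect of replacing fit_scores with avg can be seen in isolation; the helper below reproduces the title and x-label logic from the hunks above (the function name _delta_labels is hypothetical, introduced only for this sketch).

def _delta_labels(name, metric, avg=None, train_prev=None, pos_class=1, stdevs=None):
    # same branching as plot_delta above
    _base_title = "delta_stdev" if stdevs is not None else "delta"
    if train_prev is not None:
        t_prev_pos = int(round(train_prev[pos_class] * 100))
        title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
    else:
        title = f"{_base_title}_{name}_avg_{avg}_{metric}"
    x_label = "test" if avg is None or avg == "train" else "train"
    return title, f"{x_label} prevalence"

print(_delta_labels("quant", "acc", avg="train"))
# -> ('delta_quant_avg_train_acc', 'test prevalence')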
@ -182,11 +179,11 @@ def plot_shift(
     columns,
     data,
     *,
+    counts=None,
     pos_class=1,
     metric="acc",
     name="default",
     train_prev=None,
-    fit_scores=None,
     legend=True,
 ) -> Path:
     if train_prev is not None:
@ -217,14 +214,19 @@ def plot_shift(
             markersize=3,
             zorder=2,
         )
-        if fit_scores is not None and method in fit_scores:
-            ax.plot(
-                shift_prevs,
-                np.repeat(fit_scores[method], shift_prevs.shape[0]),
-                color=_cy["color"],
-                linestyle="--",
-                markersize=0,
-            )
+        if counts is not None:
+            _col_idx = np.where(columns == method)[0]
+            count = counts[_col_idx].flatten()
+            for prev, shift, cnt in zip(shift_prevs, shifts, count):
+                label = f"{cnt}"
+                plt.annotate(
+                    label,
+                    (prev, shift),
+                    textcoords="offset points",
+                    xytext=(0, 10),
+                    ha="center",
+                    color=_cy["color"],
+                    fontsize=12.0,
+                )

     ax.set(xlabel="dataset shift", ylabel=metric, title=title)
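The counts annotation is plain matplotlib; a minimal stand-alone sketch of the same pattern, with invented data:

import matplotlib.pyplot as plt
import numpy as np

# invented data: error per shift bin, and how many samples fall in each bin
shift_prevs = np.linspace(0.0, 1.0, 5)
shifts = np.array([0.02, 0.05, 0.04, 0.09, 0.12])
counts = np.array([40, 35, 22, 10, 3])

fig, ax = plt.subplots()
ax.plot(shift_prevs, shifts, marker="o", markersize=3)
for prev, shift, cnt in zip(shift_prevs, shifts, counts):
    # draw each bin's sample count 10 points above its marker
    ax.annotate(f"{cnt}", (prev, shift), textcoords="offset points",
                xytext=(0, 10), ha="center", fontsize=12.0)
ax.set(xlabel="dataset shift", ylabel="acc")
plt.show()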
File diff suppressed because it is too large
@ -1,20 +0,0 @@
from sklearn.linear_model import LogisticRegression

from quacc.evaluation.baseline import kfcv, trust_score
from quacc.dataset import get_spambase


class TestBaseline:
    def test_kfcv(self):
        train, validation, _ = get_spambase()
        c_model = LogisticRegression()
        c_model.fit(train.X, train.y)
        assert "f1_score" in kfcv(c_model, validation)

    def test_trust_score(self):
        train, validation, test = get_spambase()
        c_model = LogisticRegression()
        c_model.fit(train.X, train.y)
        trustscore = trust_score(c_model, train, test)
        assert len(trustscore) == len(test.y)
Binary file not shown.
@ -0,0 +1,12 @@
from sklearn.linear_model import LogisticRegression

from quacc.dataset import Dataset
from quacc.evaluation.baseline import kfcv


class TestBaseline:
    def test_kfcv(self):
        spambase = Dataset("spambase", n_prevalences=1).get_raw()
        c_model = LogisticRegression()
        c_model.fit(spambase.train.X, spambase.train.y)
        assert "f1_score" in kfcv(c_model, spambase.validation)
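Outside pytest, the same fixture can be exercised directly; a brief sketch assuming only what the test above shows, namely that get_raw() exposes train and validation splits and that kfcv returns a dict-like report (the print is illustrative):

from sklearn.linear_model import LogisticRegression

from quacc.dataset import Dataset
from quacc.evaluation.baseline import kfcv

spambase = Dataset("spambase", n_prevalences=1).get_raw()
model = LogisticRegression().fit(spambase.train.X, spambase.train.y)
report = kfcv(model, spambase.validation)
print(sorted(report))  # expected to contain "f1_score"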
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,12 +1,12 @@
-import pytest
 import numpy as np
+import pytest
 import scipy.sparse as sp
 from sklearn.linear_model import LogisticRegression

-from quacc.estimator import BinaryQuantifierAccuracyEstimator
+from quacc.method.base import BinaryQuantifierAccuracyEstimator


-class TestBinaryQuantifierAccuracyEstimator:
+class TestBQAE:
     @pytest.mark.parametrize(
         "instances,preds0,preds1,result",
         [
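For reference, the parametrize pattern used by TestBQAE in miniature; the values and the assertion below are invented placeholders, not the PR's actual test cases.

import numpy as np
import pytest


@pytest.mark.parametrize(
    "instances,preds0,preds1,result",
    [
        (np.asarray([[0.3, 0.7]]), [0.5, 0.5], [0.5, 0.5], 4),
        (np.asarray([[0.9, 0.1]]), [1.0, 0.0], [0.0, 1.0], 4),
    ],
)
def test_placeholder(instances, preds0, preds1, result):
    # each tuple above becomes one independent test case
    assert len(preds0) + len(preds1) == result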
@ -0,0 +1,2 @@
class TestMCAE:
    pass