diff --git a/.gitignore b/.gitignore
index 1ae9719..b199a8a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,6 @@ lipton_bbse/__pycache__/*
elsahar19_rca/__pycache__/*
*.coverage
.coverage
-scp_sync.py
\ No newline at end of file
+scp_sync.py
+out/*
+output/*
\ No newline at end of file
diff --git a/TODO.html b/TODO.html
index ddfdc17..35dd2f8 100644
--- a/TODO.html
+++ b/TODO.html
@@ -41,12 +41,12 @@
diff --git a/TODO.md b/TODO.md
index 2f6d846..d7012a1 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,6 +1,6 @@
-- [ ] aggiungere media tabelle
-- [ ] plot; 3 tipi (appunti + email + garg)
+- [x] aggiungere media tabelle
+- [x] plot; 3 tipi (appunti + email + garg)
- [ ] sistemare kfcv baseline
-- [ ] aggiungere metodo con CC oltre SLD
+- [x] aggiungere metodo con CC oltre SLD
- [x] prendere classe più popolosa di rcv1, togliere negativi fino a raggiungere 50/50; poi fare subsampling con 9 training prvalences (da 0.1-0.9 a 0.9-0.1)
-- [ ] variare parametro recalibration in SLD
\ No newline at end of file
+- [x] variare parametro recalibration in SLD
\ No newline at end of file
diff --git a/conf.yaml b/conf.yaml
new file mode 100644
index 0000000..50a5dd0
--- /dev/null
+++ b/conf.yaml
@@ -0,0 +1,71 @@
+
+exec: []
+
+commons:
+ - DATASET_NAME: rcv1
+ DATASET_TARGET: CCAT
+ METRICS:
+ - acc
+ - f1
+ DATASET_N_PREVS: 9
+ - DATASET_NAME: imdb
+ METRICS:
+ - acc
+ - f1
+ DATASET_N_PREVS: 9
+
+confs:
+
+ all_mul_vs_atc:
+ COMP_ESTIMATORS:
+ - our_mul_SLD
+ - our_mul_SLD_nbvs
+ - our_mul_SLD_bcts
+ - our_mul_SLD_ts
+ - our_mul_SLD_vs
+ - our_mul_CC
+ - ref
+ - atc_mc
+ - atc_ne
+
+ all_bin_vs_atc:
+ COMP_ESTIMATORS:
+ - our_bin_SLD
+ - our_bin_SLD_nbvs
+ - our_bin_SLD_bcts
+ - our_bin_SLD_ts
+ - our_bin_SLD_vs
+ - our_bin_CC
+ - ref
+ - atc_mc
+ - atc_ne
+
+ best_our_vs_atc:
+ COMP_ESTIMATORS:
+ - our_bin_SLD
+ - our_bin_SLD_bcts
+ - our_bin_SLD_vs
+ - our_bin_CC
+ - our_mul_SLD
+ - our_mul_SLD_bcts
+ - our_mul_SLD_vs
+ - our_mul_CC
+ - ref
+ - atc_mc
+ - atc_ne
+
+ best_our_vs_all:
+ COMP_ESTIMATORS:
+ - our_bin_SLD
+ - our_bin_SLD_bcts
+ - our_bin_SLD_vs
+ - our_bin_CC
+ - our_mul_SLD
+ - our_mul_SLD_bcts
+ - our_mul_SLD_vs
+ - our_mul_CC
+ - ref
+ - kfcv
+ - atc_mc
+ - atc_ne
+ - doc_feat
diff --git a/out/plot/rcv1_CCAT_10_acc.png b/out/plot/rcv1_CCAT_10_acc.png
deleted file mode 100644
index 2994b60..0000000
Binary files a/out/plot/rcv1_CCAT_10_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_20_acc.png b/out/plot/rcv1_CCAT_20_acc.png
deleted file mode 100644
index 83a7991..0000000
Binary files a/out/plot/rcv1_CCAT_20_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_30_acc.png b/out/plot/rcv1_CCAT_30_acc.png
deleted file mode 100644
index 2e34308..0000000
Binary files a/out/plot/rcv1_CCAT_30_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_40_acc.png b/out/plot/rcv1_CCAT_40_acc.png
deleted file mode 100644
index 031feda..0000000
Binary files a/out/plot/rcv1_CCAT_40_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_50_acc.png b/out/plot/rcv1_CCAT_50_acc.png
deleted file mode 100644
index 86d23e7..0000000
Binary files a/out/plot/rcv1_CCAT_50_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_60_acc.png b/out/plot/rcv1_CCAT_60_acc.png
deleted file mode 100644
index 374cb70..0000000
Binary files a/out/plot/rcv1_CCAT_60_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_70_acc.png b/out/plot/rcv1_CCAT_70_acc.png
deleted file mode 100644
index 3af314e..0000000
Binary files a/out/plot/rcv1_CCAT_70_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_80_acc.png b/out/plot/rcv1_CCAT_80_acc.png
deleted file mode 100644
index 2f525d0..0000000
Binary files a/out/plot/rcv1_CCAT_80_acc.png and /dev/null differ
diff --git a/out/plot/rcv1_CCAT_90_acc.png b/out/plot/rcv1_CCAT_90_acc.png
deleted file mode 100644
index 0150b84..0000000
Binary files a/out/plot/rcv1_CCAT_90_acc.png and /dev/null differ
diff --git a/out/rcv1_CCAT.md b/out/rcv1_CCAT.md
deleted file mode 100644
index 1eff4e7..0000000
--- a/out/rcv1_CCAT.md
+++ /dev/null
@@ -1,1955 +0,0 @@
-rcv1_CCAT
-
-> train: [0.09996662 0.90003338]
-> validation: [0.09996662 0.90003338]
-> evaluate_bin_sld: 198.301s
-> evaluate_mul_sld: 53.156s
-> kfcv: 41.095s
-> atc_mc: 42.167s
-> atc_ne: 41.909s
-> doc_feat: 35.796s
-> tot: 202.108s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0048 |
- 0.0040 |
- 0.0866 |
- 0.0243 |
- 0.0243 |
- 0.0830 |
-
-
- (0.05, 0.95) |
- 0.0060 |
- 0.0072 |
- 0.0441 |
- 0.0134 |
- 0.0134 |
- 0.0407 |
-
-
- (0.1, 0.9) |
- 0.0084 |
- 0.0103 |
- 0.0032 |
- 0.0070 |
- 0.0070 |
- 0.0036 |
-
-
- (0.15, 0.85) |
- 0.0127 |
- 0.0172 |
- 0.0418 |
- 0.0090 |
- 0.0090 |
- 0.0450 |
-
-
- (0.2, 0.8) |
- 0.0184 |
- 0.0246 |
- 0.0841 |
- 0.0168 |
- 0.0168 |
- 0.0872 |
-
-
- (0.25, 0.75) |
- 0.0231 |
- 0.0318 |
- 0.1246 |
- 0.0239 |
- 0.0239 |
- 0.1276 |
-
-
- (0.3, 0.7) |
- 0.0313 |
- 0.0426 |
- 0.1678 |
- 0.0334 |
- 0.0334 |
- 0.1706 |
-
-
- (0.35, 0.65) |
- 0.0392 |
- 0.0536 |
- 0.2110 |
- 0.0422 |
- 0.0422 |
- 0.2137 |
-
-
- (0.4, 0.6) |
- 0.0418 |
- 0.0563 |
- 0.2528 |
- 0.0541 |
- 0.0541 |
- 0.2555 |
-
-
- (0.45, 0.55) |
- 0.0527 |
- 0.0715 |
- 0.2966 |
- 0.0622 |
- 0.0622 |
- 0.2991 |
-
-
- (0.5, 0.5) |
- 0.0569 |
- 0.0771 |
- 0.3383 |
- 0.0749 |
- 0.0749 |
- 0.3407 |
-
-
- (0.55, 0.45) |
- 0.0637 |
- 0.0867 |
- 0.3817 |
- 0.0847 |
- 0.0847 |
- 0.3840 |
-
-
- (0.6, 0.4) |
- 0.0727 |
- 0.0999 |
- 0.4250 |
- 0.0943 |
- 0.0943 |
- 0.4272 |
-
-
- (0.65, 0.35) |
- 0.0778 |
- 0.1062 |
- 0.4662 |
- 0.1040 |
- 0.1040 |
- 0.4683 |
-
-
- (0.7, 0.3) |
- 0.0825 |
- 0.1118 |
- 0.5099 |
- 0.1131 |
- 0.1131 |
- 0.5119 |
-
-
- (0.75, 0.25) |
- 0.0879 |
- 0.1197 |
- 0.5519 |
- 0.1217 |
- 0.1217 |
- 0.5537 |
-
-
- (0.8, 0.2) |
- 0.0887 |
- 0.1192 |
- 0.5945 |
- 0.1334 |
- 0.1334 |
- 0.5963 |
-
-
- (0.85, 0.15) |
- 0.0926 |
- 0.1269 |
- 0.6368 |
- 0.1426 |
- 0.1426 |
- 0.6384 |
-
-
- (0.9, 0.1) |
- 0.0887 |
- 0.1250 |
- 0.6791 |
- 0.1528 |
- 0.1528 |
- 0.6806 |
-
-
- (0.95, 0.05) |
- 0.0501 |
- 0.0961 |
- 0.7227 |
- 0.1614 |
- 0.1614 |
- 0.7241 |
-
-
- (1.0, 0.0) |
- 0.0004 |
- 0.0358 |
- 0.7631 |
- 0.1704 |
- 0.1704 |
- 0.7643 |
-
-
-
-
-
-
-> train: [0.19993324 0.80006676]
-> validation: [0.20010013 0.79989987]
-> evaluate_bin_sld: 199.250s
-> evaluate_mul_sld: 55.414s
-> kfcv: 41.131s
-> atc_mc: 42.125s
-> atc_ne: 41.892s
-> doc_feat: 35.279s
-> tot: 202.707s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0055 |
- 0.0058 |
- 0.0915 |
- 0.0147 |
- 0.0147 |
- 0.0775 |
-
-
- (0.05, 0.95) |
- 0.0157 |
- 0.0084 |
- 0.0719 |
- 0.0130 |
- 0.0130 |
- 0.0581 |
-
-
- (0.1, 0.9) |
- 0.0154 |
- 0.0099 |
- 0.0503 |
- 0.0108 |
- 0.0108 |
- 0.0365 |
-
-
- (0.15, 0.85) |
- 0.0141 |
- 0.0111 |
- 0.0292 |
- 0.0104 |
- 0.0104 |
- 0.0158 |
-
-
- (0.2, 0.8) |
- 0.0120 |
- 0.0116 |
- 0.0103 |
- 0.0100 |
- 0.0100 |
- 0.0068 |
-
-
- (0.25, 0.75) |
- 0.0098 |
- 0.0124 |
- 0.0115 |
- 0.0091 |
- 0.0091 |
- 0.0243 |
-
-
- (0.3, 0.7) |
- 0.0079 |
- 0.0131 |
- 0.0312 |
- 0.0106 |
- 0.0106 |
- 0.0445 |
-
-
- (0.35, 0.65) |
- 0.0087 |
- 0.0154 |
- 0.0529 |
- 0.0097 |
- 0.0097 |
- 0.0660 |
-
-
- (0.4, 0.6) |
- 0.0074 |
- 0.0143 |
- 0.0729 |
- 0.0110 |
- 0.0110 |
- 0.0859 |
-
-
- (0.45, 0.55) |
- 0.0082 |
- 0.0148 |
- 0.0933 |
- 0.0111 |
- 0.0111 |
- 0.1062 |
-
-
- (0.5, 0.5) |
- 0.0081 |
- 0.0152 |
- 0.1152 |
- 0.0136 |
- 0.0136 |
- 0.1280 |
-
-
- (0.55, 0.45) |
- 0.0104 |
- 0.0164 |
- 0.1384 |
- 0.0147 |
- 0.0147 |
- 0.1511 |
-
-
- (0.6, 0.4) |
- 0.0108 |
- 0.0193 |
- 0.1567 |
- 0.0168 |
- 0.0168 |
- 0.1692 |
-
-
- (0.65, 0.35) |
- 0.0129 |
- 0.0212 |
- 0.1806 |
- 0.0196 |
- 0.0196 |
- 0.1930 |
-
-
- (0.7, 0.3) |
- 0.0134 |
- 0.0242 |
- 0.2005 |
- 0.0178 |
- 0.0178 |
- 0.2128 |
-
-
- (0.75, 0.25) |
- 0.0162 |
- 0.0238 |
- 0.2196 |
- 0.0201 |
- 0.0201 |
- 0.2318 |
-
-
- (0.8, 0.2) |
- 0.0161 |
- 0.0248 |
- 0.2425 |
- 0.0214 |
- 0.0214 |
- 0.2546 |
-
-
- (0.85, 0.15) |
- 0.0207 |
- 0.0320 |
- 0.2620 |
- 0.0227 |
- 0.0227 |
- 0.2740 |
-
-
- (0.9, 0.1) |
- 0.0233 |
- 0.0340 |
- 0.2841 |
- 0.0267 |
- 0.0267 |
- 0.2960 |
-
-
- (0.95, 0.05) |
- 0.0261 |
- 0.0393 |
- 0.3034 |
- 0.0274 |
- 0.0274 |
- 0.3151 |
-
-
- (1.0, 0.0) |
- 0.0019 |
- 0.0162 |
- 0.3217 |
- 0.0311 |
- 0.0311 |
- 0.3333 |
-
-
-
-
-
-
-> train: [0.29989987 0.70010013]
-> validation: [0.30006676 0.69993324]
-> evaluate_bin_sld: 197.848s
-> evaluate_mul_sld: 55.610s
-> kfcv: 40.783s
-> atc_mc: 42.124s
-> atc_ne: 41.370s
-> doc_feat: 35.340s
-> tot: 199.287s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0051 |
- 0.0059 |
- 0.0530 |
- 0.0059 |
- 0.0059 |
- 0.0422 |
-
-
- (0.05, 0.95) |
- 0.0108 |
- 0.0082 |
- 0.0455 |
- 0.0063 |
- 0.0063 |
- 0.0347 |
-
-
- (0.1, 0.9) |
- 0.0127 |
- 0.0110 |
- 0.0356 |
- 0.0062 |
- 0.0062 |
- 0.0250 |
-
-
- (0.15, 0.85) |
- 0.0147 |
- 0.0145 |
- 0.0265 |
- 0.0076 |
- 0.0076 |
- 0.0160 |
-
-
- (0.2, 0.8) |
- 0.0158 |
- 0.0162 |
- 0.0173 |
- 0.0071 |
- 0.0071 |
- 0.0086 |
-
-
- (0.25, 0.75) |
- 0.0147 |
- 0.0158 |
- 0.0091 |
- 0.0070 |
- 0.0070 |
- 0.0075 |
-
-
- (0.3, 0.7) |
- 0.0134 |
- 0.0162 |
- 0.0073 |
- 0.0080 |
- 0.0080 |
- 0.0127 |
-
-
- (0.35, 0.65) |
- 0.0138 |
- 0.0178 |
- 0.0132 |
- 0.0100 |
- 0.0100 |
- 0.0230 |
-
-
- (0.4, 0.6) |
- 0.0130 |
- 0.0180 |
- 0.0204 |
- 0.0096 |
- 0.0096 |
- 0.0306 |
-
-
- (0.45, 0.55) |
- 0.0102 |
- 0.0149 |
- 0.0297 |
- 0.0102 |
- 0.0102 |
- 0.0397 |
-
-
- (0.5, 0.5) |
- 0.0094 |
- 0.0160 |
- 0.0405 |
- 0.0111 |
- 0.0111 |
- 0.0504 |
-
-
- (0.55, 0.45) |
- 0.0095 |
- 0.0135 |
- 0.0516 |
- 0.0123 |
- 0.0123 |
- 0.0615 |
-
-
- (0.6, 0.4) |
- 0.0086 |
- 0.0132 |
- 0.0596 |
- 0.0122 |
- 0.0122 |
- 0.0693 |
-
-
- (0.65, 0.35) |
- 0.0102 |
- 0.0123 |
- 0.0717 |
- 0.0149 |
- 0.0149 |
- 0.0814 |
-
-
- (0.7, 0.3) |
- 0.0098 |
- 0.0115 |
- 0.0797 |
- 0.0160 |
- 0.0160 |
- 0.0894 |
-
-
- (0.75, 0.25) |
- 0.0111 |
- 0.0108 |
- 0.0880 |
- 0.0160 |
- 0.0160 |
- 0.0975 |
-
-
- (0.8, 0.2) |
- 0.0112 |
- 0.0093 |
- 0.0996 |
- 0.0206 |
- 0.0206 |
- 0.1091 |
-
-
- (0.85, 0.15) |
- 0.0149 |
- 0.0119 |
- 0.1094 |
- 0.0197 |
- 0.0197 |
- 0.1187 |
-
-
- (0.9, 0.1) |
- 0.0167 |
- 0.0137 |
- 0.1178 |
- 0.0216 |
- 0.0216 |
- 0.1271 |
-
-
- (0.95, 0.05) |
- 0.0184 |
- 0.0145 |
- 0.1275 |
- 0.0222 |
- 0.0222 |
- 0.1367 |
-
-
- (1.0, 0.0) |
- 0.0007 |
- 0.0099 |
- 0.1371 |
- 0.0238 |
- 0.0238 |
- 0.1462 |
-
-
-
-
-
-
-> train: [0.40003338 0.59996662]
-> validation: [0.40003338 0.59996662]
-> evaluate_bin_sld: 197.597s
-> evaluate_mul_sld: 55.556s
-> kfcv: 40.650s
-> atc_mc: 41.687s
-> atc_ne: 41.375s
-> doc_feat: 34.998s
-> tot: 198.892s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0013 |
- 0.0048 |
- 0.0194 |
- 0.0071 |
- 0.0071 |
- 0.0126 |
-
-
- (0.05, 0.95) |
- 0.0076 |
- 0.0084 |
- 0.0184 |
- 0.0071 |
- 0.0071 |
- 0.0111 |
-
-
- (0.1, 0.9) |
- 0.0092 |
- 0.0107 |
- 0.0161 |
- 0.0078 |
- 0.0078 |
- 0.0093 |
-
-
- (0.15, 0.85) |
- 0.0127 |
- 0.0149 |
- 0.0134 |
- 0.0070 |
- 0.0070 |
- 0.0077 |
-
-
- (0.2, 0.8) |
- 0.0183 |
- 0.0200 |
- 0.0110 |
- 0.0066 |
- 0.0066 |
- 0.0075 |
-
-
- (0.25, 0.75) |
- 0.0208 |
- 0.0230 |
- 0.0090 |
- 0.0075 |
- 0.0075 |
- 0.0069 |
-
-
- (0.3, 0.7) |
- 0.0235 |
- 0.0260 |
- 0.0080 |
- 0.0076 |
- 0.0076 |
- 0.0073 |
-
-
- (0.35, 0.65) |
- 0.0273 |
- 0.0306 |
- 0.0065 |
- 0.0079 |
- 0.0079 |
- 0.0095 |
-
-
- (0.4, 0.6) |
- 0.0296 |
- 0.0335 |
- 0.0074 |
- 0.0072 |
- 0.0072 |
- 0.0099 |
-
-
- (0.45, 0.55) |
- 0.0283 |
- 0.0313 |
- 0.0080 |
- 0.0085 |
- 0.0085 |
- 0.0116 |
-
-
- (0.5, 0.5) |
- 0.0267 |
- 0.0317 |
- 0.0087 |
- 0.0085 |
- 0.0085 |
- 0.0147 |
-
-
- (0.55, 0.45) |
- 0.0273 |
- 0.0331 |
- 0.0131 |
- 0.0086 |
- 0.0086 |
- 0.0196 |
-
-
- (0.6, 0.4) |
- 0.0239 |
- 0.0320 |
- 0.0136 |
- 0.0082 |
- 0.0082 |
- 0.0202 |
-
-
- (0.65, 0.35) |
- 0.0208 |
- 0.0290 |
- 0.0171 |
- 0.0084 |
- 0.0084 |
- 0.0241 |
-
-
- (0.7, 0.3) |
- 0.0186 |
- 0.0288 |
- 0.0213 |
- 0.0084 |
- 0.0084 |
- 0.0281 |
-
-
- (0.75, 0.25) |
- 0.0158 |
- 0.0261 |
- 0.0219 |
- 0.0090 |
- 0.0090 |
- 0.0288 |
-
-
- (0.8, 0.2) |
- 0.0130 |
- 0.0235 |
- 0.0269 |
- 0.0089 |
- 0.0089 |
- 0.0338 |
-
-
- (0.85, 0.15) |
- 0.0084 |
- 0.0180 |
- 0.0284 |
- 0.0083 |
- 0.0083 |
- 0.0352 |
-
-
- (0.9, 0.1) |
- 0.0057 |
- 0.0134 |
- 0.0322 |
- 0.0092 |
- 0.0092 |
- 0.0390 |
-
-
- (0.95, 0.05) |
- 0.0050 |
- 0.0091 |
- 0.0339 |
- 0.0101 |
- 0.0101 |
- 0.0406 |
-
-
- (1.0, 0.0) |
- 0.0007 |
- 0.0064 |
- 0.0379 |
- 0.0106 |
- 0.0106 |
- 0.0447 |
-
-
-
-
-
-
-> train: [0.5 0.5]
-> validation: [0.5 0.5]
-> evaluate_bin_sld: 197.283s
-> evaluate_mul_sld: 54.736s
-> kfcv: 40.375s
-> atc_mc: 41.898s
-> atc_ne: 41.366s
-> doc_feat: 35.145s
-> tot: 198.630s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0004 |
- 0.0035 |
- 0.0257 |
- 0.0289 |
- 0.0289 |
- 0.0344 |
-
-
- (0.05, 0.95) |
- 0.0075 |
- 0.0085 |
- 0.0224 |
- 0.0253 |
- 0.0253 |
- 0.0310 |
-
-
- (0.1, 0.9) |
- 0.0081 |
- 0.0122 |
- 0.0205 |
- 0.0239 |
- 0.0239 |
- 0.0292 |
-
-
- (0.15, 0.85) |
- 0.0102 |
- 0.0148 |
- 0.0180 |
- 0.0205 |
- 0.0205 |
- 0.0267 |
-
-
- (0.2, 0.8) |
- 0.0139 |
- 0.0198 |
- 0.0165 |
- 0.0211 |
- 0.0211 |
- 0.0248 |
-
-
- (0.25, 0.75) |
- 0.0194 |
- 0.0245 |
- 0.0141 |
- 0.0170 |
- 0.0170 |
- 0.0224 |
-
-
- (0.3, 0.7) |
- 0.0230 |
- 0.0287 |
- 0.0137 |
- 0.0164 |
- 0.0164 |
- 0.0222 |
-
-
- (0.35, 0.65) |
- 0.0309 |
- 0.0338 |
- 0.0132 |
- 0.0168 |
- 0.0168 |
- 0.0210 |
-
-
- (0.4, 0.6) |
- 0.0350 |
- 0.0371 |
- 0.0097 |
- 0.0144 |
- 0.0144 |
- 0.0164 |
-
-
- (0.45, 0.55) |
- 0.0358 |
- 0.0390 |
- 0.0086 |
- 0.0125 |
- 0.0125 |
- 0.0150 |
-
-
- (0.5, 0.5) |
- 0.0369 |
- 0.0386 |
- 0.0073 |
- 0.0122 |
- 0.0122 |
- 0.0138 |
-
-
- (0.55, 0.45) |
- 0.0373 |
- 0.0398 |
- 0.0071 |
- 0.0110 |
- 0.0110 |
- 0.0128 |
-
-
- (0.6, 0.4) |
- 0.0368 |
- 0.0398 |
- 0.0064 |
- 0.0085 |
- 0.0085 |
- 0.0103 |
-
-
- (0.65, 0.35) |
- 0.0357 |
- 0.0385 |
- 0.0074 |
- 0.0103 |
- 0.0103 |
- 0.0105 |
-
-
- (0.7, 0.3) |
- 0.0319 |
- 0.0370 |
- 0.0067 |
- 0.0082 |
- 0.0082 |
- 0.0086 |
-
-
- (0.75, 0.25) |
- 0.0298 |
- 0.0358 |
- 0.0079 |
- 0.0066 |
- 0.0066 |
- 0.0070 |
-
-
- (0.8, 0.2) |
- 0.0235 |
- 0.0302 |
- 0.0073 |
- 0.0083 |
- 0.0083 |
- 0.0069 |
-
-
- (0.85, 0.15) |
- 0.0154 |
- 0.0244 |
- 0.0097 |
- 0.0077 |
- 0.0077 |
- 0.0066 |
-
-
- (0.9, 0.1) |
- 0.0083 |
- 0.0157 |
- 0.0108 |
- 0.0082 |
- 0.0082 |
- 0.0069 |
-
-
- (0.95, 0.05) |
- 0.0055 |
- 0.0098 |
- 0.0131 |
- 0.0080 |
- 0.0080 |
- 0.0066 |
-
-
- (1.0, 0.0) |
- 0.0007 |
- 0.0046 |
- 0.0145 |
- 0.0088 |
- 0.0088 |
- 0.0082 |
-
-
-
-
-
-
-> train: [0.59996662 0.40003338]
-> validation: [0.59996662 0.40003338]
-> evaluate_bin_sld: 194.960s
-> evaluate_mul_sld: 53.330s
-> kfcv: 40.320s
-> atc_mc: 41.904s
-> atc_ne: 41.423s
-> doc_feat: 35.289s
-> tot: 196.151s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0003 |
- 0.0055 |
- 0.0815 |
- 0.0285 |
- 0.0285 |
- 0.0825 |
-
-
- (0.05, 0.95) |
- 0.0065 |
- 0.0127 |
- 0.0747 |
- 0.0278 |
- 0.0278 |
- 0.0758 |
-
-
- (0.1, 0.9) |
- 0.0072 |
- 0.0172 |
- 0.0677 |
- 0.0224 |
- 0.0224 |
- 0.0688 |
-
-
- (0.15, 0.85) |
- 0.0100 |
- 0.0257 |
- 0.0627 |
- 0.0218 |
- 0.0218 |
- 0.0638 |
-
-
- (0.2, 0.8) |
- 0.0135 |
- 0.0308 |
- 0.0548 |
- 0.0180 |
- 0.0180 |
- 0.0560 |
-
-
- (0.25, 0.75) |
- 0.0165 |
- 0.0338 |
- 0.0491 |
- 0.0160 |
- 0.0160 |
- 0.0503 |
-
-
- (0.3, 0.7) |
- 0.0205 |
- 0.0409 |
- 0.0438 |
- 0.0168 |
- 0.0168 |
- 0.0450 |
-
-
- (0.35, 0.65) |
- 0.0248 |
- 0.0459 |
- 0.0374 |
- 0.0156 |
- 0.0156 |
- 0.0386 |
-
-
- (0.4, 0.6) |
- 0.0284 |
- 0.0491 |
- 0.0277 |
- 0.0112 |
- 0.0112 |
- 0.0290 |
-
-
- (0.45, 0.55) |
- 0.0318 |
- 0.0515 |
- 0.0224 |
- 0.0099 |
- 0.0099 |
- 0.0237 |
-
-
- (0.5, 0.5) |
- 0.0342 |
- 0.0516 |
- 0.0159 |
- 0.0081 |
- 0.0081 |
- 0.0170 |
-
-
- (0.55, 0.45) |
- 0.0374 |
- 0.0519 |
- 0.0111 |
- 0.0073 |
- 0.0073 |
- 0.0121 |
-
-
- (0.6, 0.4) |
- 0.0410 |
- 0.0537 |
- 0.0069 |
- 0.0079 |
- 0.0079 |
- 0.0075 |
-
-
- (0.65, 0.35) |
- 0.0444 |
- 0.0517 |
- 0.0064 |
- 0.0076 |
- 0.0076 |
- 0.0064 |
-
-
- (0.7, 0.3) |
- 0.0438 |
- 0.0502 |
- 0.0100 |
- 0.0085 |
- 0.0085 |
- 0.0090 |
-
-
- (0.75, 0.25) |
- 0.0458 |
- 0.0483 |
- 0.0171 |
- 0.0089 |
- 0.0089 |
- 0.0157 |
-
-
- (0.8, 0.2) |
- 0.0412 |
- 0.0419 |
- 0.0218 |
- 0.0105 |
- 0.0105 |
- 0.0204 |
-
-
- (0.85, 0.15) |
- 0.0319 |
- 0.0348 |
- 0.0291 |
- 0.0117 |
- 0.0117 |
- 0.0276 |
-
-
- (0.9, 0.1) |
- 0.0192 |
- 0.0254 |
- 0.0358 |
- 0.0147 |
- 0.0147 |
- 0.0343 |
-
-
- (0.95, 0.05) |
- 0.0079 |
- 0.0154 |
- 0.0427 |
- 0.0166 |
- 0.0166 |
- 0.0412 |
-
-
- (1.0, 0.0) |
- 0.0005 |
- 0.0034 |
- 0.0490 |
- 0.0190 |
- 0.0190 |
- 0.0474 |
-
-
-
-
-
-
-> train: [0.69993324 0.30006676]
-> validation: [0.70010013 0.29989987]
-> evaluate_bin_sld: 196.856s
-> evaluate_mul_sld: 54.245s
-> kfcv: 41.167s
-> atc_mc: 42.203s
-> atc_ne: 41.565s
-> doc_feat: 34.998s
-> tot: 198.332s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0003 |
- 0.0071 |
- 0.1570 |
- 0.0625 |
- 0.0625 |
- 0.1677 |
-
-
- (0.05, 0.95) |
- 0.0089 |
- 0.0102 |
- 0.1428 |
- 0.0548 |
- 0.0548 |
- 0.1536 |
-
-
- (0.1, 0.9) |
- 0.0078 |
- 0.0121 |
- 0.1327 |
- 0.0521 |
- 0.0521 |
- 0.1435 |
-
-
- (0.15, 0.85) |
- 0.0073 |
- 0.0155 |
- 0.1227 |
- 0.0517 |
- 0.0517 |
- 0.1336 |
-
-
- (0.2, 0.8) |
- 0.0081 |
- 0.0196 |
- 0.1094 |
- 0.0464 |
- 0.0464 |
- 0.1203 |
-
-
- (0.25, 0.75) |
- 0.0095 |
- 0.0225 |
- 0.1001 |
- 0.0427 |
- 0.0427 |
- 0.1111 |
-
-
- (0.3, 0.7) |
- 0.0117 |
- 0.0272 |
- 0.0885 |
- 0.0400 |
- 0.0400 |
- 0.0995 |
-
-
- (0.35, 0.65) |
- 0.0131 |
- 0.0309 |
- 0.0774 |
- 0.0368 |
- 0.0368 |
- 0.0885 |
-
-
- (0.4, 0.6) |
- 0.0144 |
- 0.0333 |
- 0.0626 |
- 0.0307 |
- 0.0307 |
- 0.0737 |
-
-
- (0.45, 0.55) |
- 0.0179 |
- 0.0365 |
- 0.0528 |
- 0.0297 |
- 0.0297 |
- 0.0640 |
-
-
- (0.5, 0.5) |
- 0.0183 |
- 0.0359 |
- 0.0418 |
- 0.0259 |
- 0.0259 |
- 0.0531 |
-
-
- (0.55, 0.45) |
- 0.0189 |
- 0.0369 |
- 0.0313 |
- 0.0222 |
- 0.0222 |
- 0.0426 |
-
-
- (0.6, 0.4) |
- 0.0220 |
- 0.0379 |
- 0.0201 |
- 0.0190 |
- 0.0190 |
- 0.0314 |
-
-
- (0.65, 0.35) |
- 0.0218 |
- 0.0364 |
- 0.0104 |
- 0.0160 |
- 0.0160 |
- 0.0208 |
-
-
- (0.7, 0.3) |
- 0.0229 |
- 0.0371 |
- 0.0067 |
- 0.0119 |
- 0.0119 |
- 0.0096 |
-
-
- (0.75, 0.25) |
- 0.0250 |
- 0.0378 |
- 0.0161 |
- 0.0101 |
- 0.0101 |
- 0.0067 |
-
-
- (0.8, 0.2) |
- 0.0237 |
- 0.0333 |
- 0.0259 |
- 0.0082 |
- 0.0082 |
- 0.0143 |
-
-
- (0.85, 0.15) |
- 0.0227 |
- 0.0282 |
- 0.0381 |
- 0.0060 |
- 0.0060 |
- 0.0265 |
-
-
- (0.9, 0.1) |
- 0.0180 |
- 0.0202 |
- 0.0499 |
- 0.0049 |
- 0.0049 |
- 0.0382 |
-
-
- (0.95, 0.05) |
- 0.0097 |
- 0.0117 |
- 0.0607 |
- 0.0072 |
- 0.0072 |
- 0.0489 |
-
-
- (1.0, 0.0) |
- 0.0014 |
- 0.0024 |
- 0.0724 |
- 0.0103 |
- 0.0103 |
- 0.0606 |
-
-
-
-
-
-
-> train: [0.79989987 0.20010013]
-> validation: [0.80006676 0.19993324]
-> evaluate_bin_sld: 197.725s
-> evaluate_mul_sld: 53.526s
-> kfcv: 40.971s
-> atc_mc: 41.975s
-> atc_ne: 41.358s
-> doc_feat: 35.091s
-> tot: 199.051s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0009 |
- 0.0082 |
- 0.3148 |
- 0.0571 |
- 0.0571 |
- 0.3213 |
-
-
- (0.05, 0.95) |
- 0.0297 |
- 0.0223 |
- 0.2925 |
- 0.0492 |
- 0.0492 |
- 0.2991 |
-
-
- (0.1, 0.9) |
- 0.0283 |
- 0.0209 |
- 0.2733 |
- 0.0493 |
- 0.0493 |
- 0.2800 |
-
-
- (0.15, 0.85) |
- 0.0247 |
- 0.0182 |
- 0.2528 |
- 0.0447 |
- 0.0447 |
- 0.2596 |
-
-
- (0.2, 0.8) |
- 0.0216 |
- 0.0156 |
- 0.2328 |
- 0.0407 |
- 0.0407 |
- 0.2397 |
-
-
- (0.25, 0.75) |
- 0.0170 |
- 0.0136 |
- 0.2136 |
- 0.0425 |
- 0.0425 |
- 0.2205 |
-
-
- (0.3, 0.7) |
- 0.0146 |
- 0.0126 |
- 0.1941 |
- 0.0384 |
- 0.0384 |
- 0.2012 |
-
-
- (0.35, 0.65) |
- 0.0125 |
- 0.0113 |
- 0.1734 |
- 0.0331 |
- 0.0331 |
- 0.1806 |
-
-
- (0.4, 0.6) |
- 0.0113 |
- 0.0110 |
- 0.1510 |
- 0.0272 |
- 0.0272 |
- 0.1583 |
-
-
- (0.45, 0.55) |
- 0.0093 |
- 0.0135 |
- 0.1328 |
- 0.0247 |
- 0.0247 |
- 0.1402 |
-
-
- (0.5, 0.5) |
- 0.0088 |
- 0.0135 |
- 0.1131 |
- 0.0222 |
- 0.0222 |
- 0.1206 |
-
-
- (0.55, 0.45) |
- 0.0092 |
- 0.0155 |
- 0.0919 |
- 0.0207 |
- 0.0207 |
- 0.0995 |
-
-
- (0.6, 0.4) |
- 0.0092 |
- 0.0173 |
- 0.0742 |
- 0.0190 |
- 0.0190 |
- 0.0819 |
-
-
- (0.65, 0.35) |
- 0.0087 |
- 0.0178 |
- 0.0544 |
- 0.0161 |
- 0.0161 |
- 0.0621 |
-
-
- (0.7, 0.3) |
- 0.0093 |
- 0.0197 |
- 0.0323 |
- 0.0124 |
- 0.0124 |
- 0.0401 |
-
-
- (0.75, 0.25) |
- 0.0101 |
- 0.0218 |
- 0.0114 |
- 0.0093 |
- 0.0093 |
- 0.0187 |
-
-
- (0.8, 0.2) |
- 0.0117 |
- 0.0208 |
- 0.0098 |
- 0.0088 |
- 0.0088 |
- 0.0063 |
-
-
- (0.85, 0.15) |
- 0.0103 |
- 0.0178 |
- 0.0285 |
- 0.0064 |
- 0.0064 |
- 0.0204 |
-
-
- (0.9, 0.1) |
- 0.0103 |
- 0.0164 |
- 0.0480 |
- 0.0062 |
- 0.0062 |
- 0.0398 |
-
-
- (0.95, 0.05) |
- 0.0092 |
- 0.0117 |
- 0.0684 |
- 0.0071 |
- 0.0071 |
- 0.0601 |
-
-
- (1.0, 0.0) |
- 0.0011 |
- 0.0019 |
- 0.0887 |
- 0.0097 |
- 0.0097 |
- 0.0803 |
-
-
-
-
-
-
-> train: [0.90003338 0.09996662]
-> validation: [0.90003338 0.09996662]
-> evaluate_bin_sld: 201.315s
-> evaluate_mul_sld: 50.974s
-> kfcv: 40.175s
-> atc_mc: 41.663s
-> atc_ne: 41.058s
-> doc_feat: 35.055s
-> tot: 202.573s
-
-
-
-
- |
- bin |
- mul |
- kfcv |
- atc_mc |
- atc_ne |
- doc_feat |
-
-
-
-
- (0.0, 1.0) |
- 0.0321 |
- 0.0184 |
- 0.6421 |
- 0.1336 |
- 0.1336 |
- 0.6454 |
-
-
- (0.05, 0.95) |
- 0.0835 |
- 0.0729 |
- 0.6056 |
- 0.1244 |
- 0.1244 |
- 0.6090 |
-
-
- (0.1, 0.9) |
- 0.1080 |
- 0.0976 |
- 0.5703 |
- 0.1204 |
- 0.1204 |
- 0.5739 |
-
-
- (0.15, 0.85) |
- 0.1154 |
- 0.0971 |
- 0.5354 |
- 0.1147 |
- 0.1147 |
- 0.5390 |
-
-
- (0.2, 0.8) |
- 0.1081 |
- 0.0916 |
- 0.5007 |
- 0.1064 |
- 0.1064 |
- 0.5045 |
-
-
- (0.25, 0.75) |
- 0.1032 |
- 0.0830 |
- 0.4632 |
- 0.1005 |
- 0.1005 |
- 0.4671 |
-
-
- (0.3, 0.7) |
- 0.0945 |
- 0.0775 |
- 0.4274 |
- 0.0916 |
- 0.0916 |
- 0.4313 |
-
-
- (0.35, 0.65) |
- 0.0966 |
- 0.0709 |
- 0.3914 |
- 0.0843 |
- 0.0843 |
- 0.3954 |
-
-
- (0.4, 0.6) |
- 0.0795 |
- 0.0639 |
- 0.3543 |
- 0.0748 |
- 0.0748 |
- 0.3584 |
-
-
- (0.45, 0.55) |
- 0.0735 |
- 0.0533 |
- 0.3210 |
- 0.0728 |
- 0.0728 |
- 0.3253 |
-
-
- (0.5, 0.5) |
- 0.0716 |
- 0.0473 |
- 0.2829 |
- 0.0633 |
- 0.0633 |
- 0.2873 |
-
-
- (0.55, 0.45) |
- 0.0550 |
- 0.0393 |
- 0.2465 |
- 0.0568 |
- 0.0568 |
- 0.2509 |
-
-
- (0.6, 0.4) |
- 0.0505 |
- 0.0317 |
- 0.2117 |
- 0.0509 |
- 0.0509 |
- 0.2162 |
-
-
- (0.65, 0.35) |
- 0.0403 |
- 0.0226 |
- 0.1741 |
- 0.0438 |
- 0.0438 |
- 0.1788 |
-
-
- (0.7, 0.3) |
- 0.0372 |
- 0.0178 |
- 0.1387 |
- 0.0348 |
- 0.0348 |
- 0.1434 |
-
-
- (0.75, 0.25) |
- 0.0262 |
- 0.0122 |
- 0.1009 |
- 0.0256 |
- 0.0256 |
- 0.1057 |
-
-
- (0.8, 0.2) |
- 0.0248 |
- 0.0110 |
- 0.0651 |
- 0.0194 |
- 0.0194 |
- 0.0701 |
-
-
- (0.85, 0.15) |
- 0.0181 |
- 0.0075 |
- 0.0298 |
- 0.0128 |
- 0.0128 |
- 0.0348 |
-
-
- (0.9, 0.1) |
- 0.0129 |
- 0.0093 |
- 0.0069 |
- 0.0080 |
- 0.0080 |
- 0.0037 |
-
-
- (0.95, 0.05) |
- 0.0077 |
- 0.0085 |
- 0.0426 |
- 0.0046 |
- 0.0046 |
- 0.0373 |
-
-
- (1.0, 0.0) |
- 0.0010 |
- 0.0010 |
- 0.0789 |
- 0.0088 |
- 0.0088 |
- 0.0735 |
-
-
-
-
-
-
diff --git a/poetry.lock b/poetry.lock
index 7cb982e..7d7365d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -956,6 +956,65 @@ files = [
{file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
]
+[[package]]
+name = "pyyaml"
+version = "6.0.1"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+ {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+ {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+ {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
+ {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+ {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
+ {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
+ {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
+ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+]
+
[[package]]
name = "quapy"
version = "0.1.7"
@@ -1164,4 +1223,4 @@ test = ["pytest", "pytest-cov"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "72e3afd9a24b88fc8a8f5f55e1c408f65090fce9015a442f6f41638191276b6f"
+content-hash = "0ce0e6b058900e7db2939e7eb047a1f868c88de67def370c1c1fa0ba532df0b0"
diff --git a/pyproject.toml b/pyproject.toml
index 9805ca9..d9ce79a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,7 @@ python = "^3.11"
quapy = "^0.1.7"
pandas = "^2.0.3"
jinja2 = "^3.1.2"
+pyyaml = "^6.0.1"
[tool.poetry.scripts]
main = "quacc.main:main"
diff --git a/quacc/environ.py b/quacc/environ.py
index cc2f13c..1177964 100644
--- a/quacc/environ.py
+++ b/quacc/environ.py
@@ -1,21 +1,33 @@
-from pathlib import Path
+import yaml
defalut_env = {
"DATASET_NAME": "rcv1",
"DATASET_TARGET": "CCAT",
+ "METRICS": ["acc", "f1"],
"COMP_ESTIMATORS": [
- "OUR_BIN_SLD",
- "OUR_MUL_SLD",
- "KFCV",
- "ATC_MC",
- "ATC_NE",
- "DOC_FEAT",
- # "RCA",
- # "RCA_STAR",
+ "our_bin_SLD",
+ "our_bin_SLD_nbvs",
+ "our_bin_SLD_bcts",
+ "our_bin_SLD_ts",
+ "our_bin_SLD_vs",
+ "our_bin_CC",
+ "our_mul_SLD",
+ "our_mul_SLD_nbvs",
+ "our_mul_SLD_bcts",
+ "our_mul_SLD_ts",
+ "our_mul_SLD_vs",
+ "our_mul_CC",
+ "ref",
+ "kfcv",
+ "atc_mc",
+ "atc_ne",
+ "doc_feat",
+ "rca",
+ "rca_star",
],
"DATASET_N_PREVS": 9,
- "OUT_DIR": Path("out"),
- "PLOT_OUT_DIR": Path("out/plot"),
+ "OUT_DIR_NAME": "output",
+ "PLOT_DIR_NAME": "plot",
"PROTOCOL_N_PREVS": 21,
"PROTOCOL_REPEATS": 100,
"SAMPLE_SIZE": 1000,
@@ -24,8 +36,37 @@ defalut_env = {
class Environ:
def __init__(self, **kwargs):
- for k, v in kwargs.items():
+ self.exec = []
+ self.confs = {}
+ self.__setdict(kwargs)
+
+ def __setdict(self, d):
+ for k, v in d.items():
self.__setattr__(k, v)
+ def load_conf(self):
+ with open("conf.yaml", "r") as f:
+ confs = yaml.safe_load(f)
+
+ for common in confs["commons"]:
+ name = common["DATASET_NAME"]
+ if "DATASET_TARGET" in common:
+ name += "_" + common["DATASET_TARGET"]
+ for k, d in confs["confs"].items():
+ _k = f"{name}_{k}"
+ self.confs[_k] = common | d
+ self.exec.append(_k)
+
+ if "exec" in confs:
+ if len(confs["exec"]) > 0:
+ self.exec = confs["exec"]
+
+ def __iter__(self):
+ self.load_conf()
+ for _conf in self.exec:
+ if _conf in self.confs:
+ self.__setdict(self.confs[_conf])
+ yield _conf
+
env = Environ(**defalut_env)
diff --git a/quacc/error.py b/quacc/error.py
index 116cc42..6ed7dd4 100644
--- a/quacc/error.py
+++ b/quacc/error.py
@@ -1,13 +1,15 @@
import quapy as qp
+
def from_name(err_name):
- if err_name == 'f1e':
+ if err_name == "f1e":
return f1e
- elif err_name == 'f1':
+ elif err_name == "f1":
return f1
else:
return qp.error.from_name(err_name)
-
+
+
# def f1(prev):
# # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
# if prev[0] == 0 and prev[1] == 0 and prev[2] == 0:
@@ -18,18 +20,21 @@ def from_name(err_name):
# return float('NaN')
# else:
# recall = prev[0] / (prev[0] + prev[1])
-# precision = prev[0] / (prev[0] + prev[2])
+# precision = prev[0] / (prev[0] + prev[2])
# return 2 * (precision * recall) / (precision + recall)
+
def f1(prev):
- den = (2*prev[3]) + prev[1] + prev[2]
+ den = (2 * prev[3]) + prev[1] + prev[2]
if den == 0:
return 0.0
else:
- return (2*prev[3])/den
+ return (2 * prev[3]) / den
+
def f1e(prev):
return 1 - f1(prev)
+
def acc(prev):
- return (prev[1] + prev[2]) / sum(prev)
\ No newline at end of file
+ return (prev[0] + prev[3]) / sum(prev)
diff --git a/quacc/estimator.py b/quacc/estimator.py
index 4516b6d..2f9a92c 100644
--- a/quacc/estimator.py
+++ b/quacc/estimator.py
@@ -1,9 +1,9 @@
-from abc import abstractmethod
import math
+from abc import abstractmethod
import numpy as np
from quapy.data import LabelledCollection
-from quapy.method.aggregative import SLD
+from quapy.method.aggregative import CC, SLD
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
@@ -15,7 +15,7 @@ class AccuracyEstimator:
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
if not pred_proba:
pred_proba = self.c_model.predict_proba(base.X)
- return ExtendedCollection.extend_collection(base, pred_proba)
+ return ExtendedCollection.extend_collection(base, pred_proba), pred_proba
@abstractmethod
def fit(self, train: LabelledCollection | ExtendedCollection):
@@ -27,9 +27,15 @@ class AccuracyEstimator:
class MulticlassAccuracyEstimator(AccuracyEstimator):
- def __init__(self, c_model: BaseEstimator):
+ def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
self.c_model = c_model
- self.q_model = SLD(LogisticRegression())
+ if q_model == "SLD":
+ available_args = ["recalib"]
+ sld_args = {k: v for k, v in kwargs.items() if k in available_args}
+ self.q_model = SLD(LogisticRegression(), **sld_args)
+ elif q_model == "CC":
+ self.q_model = CC(LogisticRegression())
+
self.e_train = None
def fit(self, train: LabelledCollection | ExtendedCollection):
@@ -67,10 +73,17 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
- def __init__(self, c_model: BaseEstimator):
+ def __init__(self, c_model: BaseEstimator, q_model="SLD", **kwargs):
self.c_model = c_model
- self.q_model_0 = SLD(LogisticRegression())
- self.q_model_1 = SLD(LogisticRegression())
+ if q_model == "SLD":
+ available_args = ["recalib"]
+ sld_args = {k: v for k, v in kwargs.items() if k in available_args}
+ self.q_model_0 = SLD(LogisticRegression(), **sld_args)
+ self.q_model_1 = SLD(LogisticRegression(), **sld_args)
+ elif q_model == "CC":
+ self.q_model_0 = CC(LogisticRegression())
+ self.q_model_1 = CC(LogisticRegression())
+
self.e_train = None
def fit(self, train: LabelledCollection | ExtendedCollection):
@@ -83,7 +96,7 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
elif isinstance(train, ExtendedCollection):
- self.e_train = train
+ self.e_train = train
self.n_classes = self.e_train.n_classes
[e_train_0, e_train_1] = self.e_train.split_by_pred()
diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py
index f4e969d..e36a492 100644
--- a/quacc/evaluation/baseline.py
+++ b/quacc/evaluation/baseline.py
@@ -34,14 +34,14 @@ def kfcv(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
- report = EvaluationReport(prefix="kfcv")
+ report = EvaluationReport(name="kfcv")
for test in protocol():
test_preds = c_model_predict(test.X)
meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds))
report.append_row(
test.prevalence(),
- acc_score=(1.0 - acc_score),
+ acc_score=acc_score,
f1_score=f1_score,
acc=meta_acc,
f1=meta_f1,
@@ -57,13 +57,13 @@ def reference(
):
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
c_model_predict = getattr(c_model, "predict_proba")
- report = EvaluationReport(prefix="ref")
+ report = EvaluationReport(name="ref")
for test in protocol():
test_probs = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1)
report.append_row(
test.prevalence(),
- acc_score=(1 - metrics.accuracy_score(test.y, test_preds)),
+ acc_score=metrics.accuracy_score(test.y, test_preds),
f1_score=metrics.f1_score(test.y, test_preds),
)
@@ -89,7 +89,7 @@ def atc_mc(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
- report = EvaluationReport(prefix="atc_mc")
+ report = EvaluationReport(name="atc_mc")
for test in protocol():
## Load OOD test data probs
test_probs = c_model_predict(test.X)
@@ -102,7 +102,7 @@ def atc_mc(
report.append_row(
test.prevalence(),
acc=meta_acc,
- acc_score=1.0 - atc_accuracy,
+ acc_score=atc_accuracy,
f1_score=f1_score,
f1=meta_f1,
)
@@ -129,7 +129,7 @@ def atc_ne(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
- report = EvaluationReport(prefix="atc_ne")
+ report = EvaluationReport(name="atc_ne")
for test in protocol():
## Load OOD test data probs
test_probs = c_model_predict(test.X)
@@ -142,7 +142,7 @@ def atc_ne(
report.append_row(
test.prevalence(),
acc=meta_acc,
- acc_score=(1.0 - atc_accuracy),
+ acc_score=atc_accuracy,
f1_score=f1_score,
f1=meta_f1,
)
@@ -182,14 +182,14 @@ def doc_feat(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
- report = EvaluationReport(prefix="doc_feat")
+ report = EvaluationReport(name="doc_feat")
for test in protocol():
test_probs = c_model_predict(test.X)
test_preds = np.argmax(test_probs, axis=-1)
test_scores = np.max(test_probs, axis=-1)
score = (v1acc + doc.get_doc(val_scores, test_scores)) / 100.0
meta_acc = abs(score - metrics.accuracy_score(test.y, test_preds))
- report.append_row(test.prevalence(), acc=meta_acc, acc_score=(1.0 - score))
+ report.append_row(test.prevalence(), acc=meta_acc, acc_score=score)
return report
@@ -206,17 +206,15 @@ def rca_score(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
- report = EvaluationReport(prefix="rca")
+ report = EvaluationReport(name="rca")
for test in protocol():
try:
test_pred = c_model_predict(test.X)
c_model2 = rca.clone_fit(c_model, test.X, test_pred)
c_model2_predict = getattr(c_model2, predict_method)
val_pred2 = c_model2_predict(validation.X)
- rca_score = rca.get_score(val_pred1, val_pred2, validation.y)
- meta_score = abs(
- rca_score - (1 - metrics.accuracy_score(test.y, test_pred))
- )
+ rca_score = 1.0 - rca.get_score(val_pred1, val_pred2, validation.y)
+ meta_score = abs(rca_score - metrics.accuracy_score(test.y, test_pred))
report.append_row(test.prevalence(), acc=meta_score, acc_score=rca_score)
except ValueError:
report.append_row(
@@ -244,17 +242,15 @@ def rca_star_score(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
- report = EvaluationReport(prefix="rca_star")
+ report = EvaluationReport(name="rca_star")
for test in protocol():
try:
test_pred = c_model_predict(test.X)
c_model2 = rca.clone_fit(c_model, test.X, test_pred)
c_model2_predict = getattr(c_model2, predict_method)
val2_pred2 = c_model2_predict(validation2.X)
- rca_star_score = rca.get_score(val2_pred1, val2_pred2, validation2.y)
- meta_score = abs(
- rca_star_score - (1 - metrics.accuracy_score(test.y, test_pred))
- )
+ rca_star_score = 1.0 - rca.get_score(val2_pred1, val2_pred2, validation2.y)
+ meta_score = abs(rca_star_score - metrics.accuracy_score(test.y, test_pred))
report.append_row(
test.prevalence(), acc=meta_score, acc_score=rca_star_score
)
diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py
index ccc4e18..b8c403b 100644
--- a/quacc/evaluation/comp.py
+++ b/quacc/evaluation/comp.py
@@ -1,5 +1,6 @@
import multiprocessing
import time
+import traceback
from typing import List
import pandas as pd
@@ -19,14 +20,25 @@ pd.set_option("display.float_format", "{:.4f}".format)
class CompEstimator:
__dict = {
- "OUR_BIN_SLD": method.evaluate_bin_sld,
- "OUR_MUL_SLD": method.evaluate_mul_sld,
- "KFCV": baseline.kfcv,
- "ATC_MC": baseline.atc_mc,
- "ATC_NE": baseline.atc_ne,
- "DOC_FEAT": baseline.doc_feat,
- "RCA": baseline.rca_score,
- "RCA_STAR": baseline.rca_star_score,
+ "our_bin_SLD": method.evaluate_bin_sld,
+ "our_mul_SLD": method.evaluate_mul_sld,
+ "our_bin_SLD_nbvs": method.evaluate_bin_sld_nbvs,
+ "our_mul_SLD_nbvs": method.evaluate_mul_sld_nbvs,
+ "our_bin_SLD_bcts": method.evaluate_bin_sld_bcts,
+ "our_mul_SLD_bcts": method.evaluate_mul_sld_bcts,
+ "our_bin_SLD_ts": method.evaluate_bin_sld_ts,
+ "our_mul_SLD_ts": method.evaluate_mul_sld_ts,
+ "our_bin_SLD_vs": method.evaluate_bin_sld_vs,
+ "our_mul_SLD_vs": method.evaluate_mul_sld_vs,
+ "our_bin_CC": method.evaluate_bin_cc,
+ "our_mul_CC": method.evaluate_mul_cc,
+ "ref": baseline.reference,
+ "kfcv": baseline.kfcv,
+ "atc_mc": baseline.atc_mc,
+ "atc_ne": baseline.atc_ne,
+ "doc_feat": baseline.doc_feat,
+ "rca": baseline.rca_score,
+ "rca_star": baseline.rca_star_score,
}
def __class_getitem__(cls, e: str | List[str]):
@@ -55,7 +67,17 @@ def fit_and_estimate(_estimate, train, validation, test):
test, n_prevalences=env.PROTOCOL_N_PREVS, repeats=env.PROTOCOL_REPEATS
)
start = time.time()
- result = _estimate(model, validation, protocol)
+ try:
+ result = _estimate(model, validation, protocol)
+ except Exception as e:
+ print(f"Method {_estimate.__name__} failed.")
+            traceback.print_exc()
+ return {
+ "name": _estimate.__name__,
+ "result": None,
+ "time": 0,
+ }
+
end = time.time()
print(f"{_estimate.__name__}: {end-start:.2f}s")
@@ -69,22 +91,33 @@ def fit_and_estimate(_estimate, train, validation, test):
def evaluate_comparison(
dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
) -> EvaluationReport:
- with multiprocessing.Pool(8) as pool:
+ with multiprocessing.Pool(len(estimators)) as pool:
dr = DatasetReport(dataset.name)
for d in dataset():
print(f"train prev.: {d.train_prev}")
start = time.time()
tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]]
results = [pool.apply_async(fit_and_estimate, t) for t in tasks]
- results = list(map(lambda r: r.get(), results))
+
+ results_got = []
+ for _r in results:
+ try:
+ r = _r.get()
+ if r["result"] is not None:
+ results_got.append(r)
+ except Exception as e:
+ print(e)
+
er = EvaluationReport.combine_reports(
- *list(map(lambda r: r["result"], results)), name=dataset.name
+ *[r["result"] for r in results_got],
+ name=dataset.name,
+ train_prev=d.train_prev,
+ valid_prev=d.validation_prev,
)
- times = {r["name"]: r["time"] for r in results}
+ times = {r["name"]: r["time"] for r in results_got}
end = time.time()
times["tot"] = end - start
er.times = times
- er.train_prevs = d.prevs
dr.add(er)
print()
diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py
index e42f203..67f8878 100644
--- a/quacc/evaluation/method.py
+++ b/quacc/evaluation/method.py
@@ -1,3 +1,5 @@
+import numpy as np
+import sklearn.metrics as metrics
from quapy.data import LabelledCollection
from quapy.protocol import (
AbstractStochasticSeededProtocol,
@@ -22,15 +24,17 @@ def estimate(
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
- base_prevs, true_prevs, estim_prevs = [], [], []
+ base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
for sample in protocol():
- e_sample = estimator.extend(sample)
+ e_sample, pred_proba = estimator.extend(sample)
estim_prev = estimator.estimate(e_sample.X, ext=True)
base_prevs.append(sample.prevalence())
true_prevs.append(e_sample.prevalence())
estim_prevs.append(estim_prev)
+ pred_probas.append(pred_proba)
+ labels.append(sample.y)
- return base_prevs, true_prevs, estim_prevs
+ return base_prevs, true_prevs, estim_prevs, pred_probas, labels
def evaluation_report(
@@ -38,16 +42,21 @@ def evaluation_report(
protocol: AbstractStochasticSeededProtocol,
method: str,
) -> EvaluationReport:
- base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
- report = EvaluationReport(prefix=method)
+ base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
+ estimator, protocol
+ )
+ report = EvaluationReport(name=method)
- for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
+ for base_prev, true_prev, estim_prev, pred_proba, label in zip(
+ base_prevs, true_prevs, estim_prevs, pred_probas, labels
+ ):
+ pred = np.argmax(pred_proba, axis=-1)
acc_score = error.acc(estim_prev)
f1_score = error.f1(estim_prev)
report.append_row(
base_prev,
- acc_score=1.0 - acc_score,
- acc=abs(error.acc(true_prev) - acc_score),
+ acc_score=acc_score,
+ acc=abs(metrics.accuracy_score(label, pred) - acc_score),
f1_score=f1_score,
f1=abs(error.f1(true_prev) - f1_score),
)
@@ -60,13 +69,18 @@ def evaluate(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
method: str,
+ q_model: str,
+ **kwargs,
):
estimator: AccuracyEstimator = {
"bin": BinaryQuantifierAccuracyEstimator,
"mul": MulticlassAccuracyEstimator,
- }[method](c_model)
+ }[method](c_model, q_model=q_model, **kwargs)
estimator.fit(validation)
- return evaluation_report(estimator, protocol, method)
+ _method = f"{method}_{q_model}"
+ for k, v in kwargs.items():
+ _method += f"_{v}"
+ return evaluation_report(estimator, protocol, _method)
def evaluate_bin_sld(
@@ -74,7 +88,7 @@ def evaluate_bin_sld(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
- return evaluate(c_model, validation, protocol, "bin")
+ return evaluate(c_model, validation, protocol, "bin", "SLD")
def evaluate_mul_sld(
@@ -82,4 +96,84 @@ def evaluate_mul_sld(
validation: LabelledCollection,
protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
- return evaluate(c_model, validation, protocol, "mul")
+ return evaluate(c_model, validation, protocol, "mul", "SLD")
+
+
+def evaluate_bin_sld_nbvs(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="nbvs")
+
+
+def evaluate_mul_sld_nbvs(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="nbvs")
+
+
+def evaluate_bin_sld_bcts(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="bcts")
+
+
+def evaluate_mul_sld_bcts(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="bcts")
+
+
+def evaluate_bin_sld_ts(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="ts")
+
+
+def evaluate_mul_sld_ts(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="ts")
+
+
+def evaluate_bin_sld_vs(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "bin", "SLD", recalib="vs")
+
+
+def evaluate_mul_sld_vs(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "mul", "SLD", recalib="vs")
+
+
+def evaluate_bin_cc(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "bin", "CC")
+
+
+def evaluate_mul_cc(
+ c_model: BaseEstimator,
+ validation: LabelledCollection,
+ protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+ return evaluate(c_model, validation, protocol, "mul", "CC")
diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py
index ff8862f..3d14203 100644
--- a/quacc/evaluation/report.py
+++ b/quacc/evaluation/report.py
@@ -1,22 +1,24 @@
-import statistics as stats
+from pathlib import Path
from typing import List, Tuple
import numpy as np
import pandas as pd
from quacc import plot
+from quacc.environ import env
from quacc.utils import fmt_line_md
class EvaluationReport:
- def __init__(self, prefix=None):
+ def __init__(self, name=None):
self._prevs = []
self._dict = {}
self._g_prevs = None
self._g_dict = None
- self.name = prefix if prefix is not None else "default"
+ self.name = name if name is not None else "default"
self.times = {}
- self.train_prevs = {}
+ self.train_prev = None
+ self.valid_prev = None
self.target = "default"
def append_row(self, base: np.ndarray | Tuple, **row):
@@ -34,23 +36,40 @@ class EvaluationReport:
def columns(self):
return self._dict.keys()
- def groupby_prevs(self, metric: str = None):
+ def group_by_prevs(self, metric: str = None):
if self._g_dict is None:
self._g_prevs = []
self._g_dict = {k: [] for k in self._dict.keys()}
- last_end = 0
- for ind, bp in enumerate(self._prevs):
- if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
- continue
+ for col, vals in self._dict.items():
+ col_grouped = {}
+ for bp, v in zip(self._prevs, vals):
+ if bp not in col_grouped:
+ col_grouped[bp] = []
+ col_grouped[bp].append(v)
- self._g_prevs.append(bp)
- for col in self._dict.keys():
- self._g_dict[col].append(
- stats.mean(self._dict[col][last_end : ind + 1])
- )
+ self._g_dict[col] = [
+ vs
+ for bp, vs in sorted(col_grouped.items(), key=lambda cg: cg[0][1])
+ ]
- last_end = ind + 1
+ self._g_prevs = sorted(
+ [(p0, p1) for [p0, p1] in np.unique(self._prevs, axis=0).tolist()],
+ key=lambda bp: bp[1],
+ )
+
+ # last_end = 0
+ # for ind, bp in enumerate(self._prevs):
+ # if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
+ # continue
+
+ # self._g_prevs.append(bp)
+ # for col in self._dict.keys():
+ # self._g_dict[col].append(
+ # stats.mean(self._dict[col][last_end : ind + 1])
+ # )
+
+ # last_end = ind + 1
filtered_g_dict = self._g_dict
if metric is not None:
@@ -60,30 +79,83 @@ class EvaluationReport:
return self._g_prevs, filtered_g_dict
+ def avg_by_prevs(self, metric: str = None):
+ g_prevs, g_dict = self.group_by_prevs(metric=metric)
+
+ a_dict = {}
+ for col, vals in g_dict.items():
+ a_dict[col] = [np.mean(vs) for vs in vals]
+
+ return g_prevs, a_dict
+
+ def avg_all(self, metric: str = None):
+ f_dict = self._dict
+ if metric is not None:
+ f_dict = {c1: ls for ((c0, c1), ls) in self._dict.items() if c0 == metric}
+
+ a_dict = {}
+ for col, vals in f_dict.items():
+ a_dict[col] = [np.mean(vals)]
+
+ return a_dict
+
def get_dataframe(self, metric="acc"):
- g_prevs, g_dict = self.groupby_prevs(metric=metric)
+ g_prevs, g_dict = self.avg_by_prevs(metric=metric)
+ a_dict = self.avg_all(metric=metric)
+ for col in g_dict.keys():
+ g_dict[col].extend(a_dict[col])
return pd.DataFrame(
g_dict,
- index=g_prevs,
+ index=g_prevs + ["tot"],
columns=g_dict.keys(),
)
- def get_plot(self, mode="delta", metric="acc"):
- g_prevs, g_dict = self.groupby_prevs(metric=metric)
- t_prev = int(round(self.train_prevs["train"][0] * 100))
- title = f"{self.name}_{t_prev}_{metric}"
- plot.plot_delta(g_prevs, g_dict, metric, title)
+ def get_plot(self, mode="delta", metric="acc") -> Path:
+ if mode == "delta":
+ g_prevs, g_dict = self.group_by_prevs(metric=metric)
+ return plot.plot_delta(
+ g_prevs,
+ g_dict,
+ metric=metric,
+ name=self.name,
+ train_prev=self.train_prev,
+ )
+ elif mode == "diagonal":
+ _, g_dict = self.avg_by_prevs(metric=metric + "_score")
+ f_dict = {k: v for k, v in g_dict.items() if k != "ref"}
+            reference = g_dict["ref"]
+            return plot.plot_diagonal(
+                reference,
+ f_dict,
+ metric=metric,
+ name=self.name,
+ train_prev=self.train_prev,
+ )
+ elif mode == "shift":
+ g_prevs, g_dict = self.avg_by_prevs(metric=metric)
+ return plot.plot_shift(
+ g_prevs,
+ g_dict,
+ metric=metric,
+ name=self.name,
+ train_prev=self.train_prev,
+ )
def to_md(self, *metrics):
res = ""
- for k, v in self.train_prevs.items():
- res += fmt_line_md(f"{k}: {str(v)}")
+ res += fmt_line_md(f"train: {str(self.train_prev)}")
+ res += fmt_line_md(f"validation: {str(self.valid_prev)}")
for k, v in self.times.items():
res += fmt_line_md(f"{k}: {v:.3f}s")
res += "\n"
for m in metrics:
res += self.get_dataframe(metric=m).to_html() + "\n\n"
- self.get_plot(metric=m)
+ op_delta = self.get_plot(mode="delta", metric=m)
+ res += f"![plot_delta]({str(op_delta.relative_to(env.OUT_DIR))})\n"
+ op_diag = self.get_plot(mode="diagonal", metric=m)
+ res += f"![plot_diagonal]({str(op_diag.relative_to(env.OUT_DIR))})\n"
+ op_shift = self.get_plot(mode="shift", metric=m)
+ res += f"![plot_shift]({str(op_shift.relative_to(env.OUT_DIR))})\n"
return res
@@ -91,8 +163,9 @@ class EvaluationReport:
if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
raise ValueError("other has not same base prevalences of self")
- if len(set(self._dict.keys()).intersection(set(other._dict.keys()))) > 0:
- raise ValueError("self and other have matching keys")
+ inters_keys = set(self._dict.keys()).intersection(set(other._dict.keys()))
+ if len(inters_keys) > 0:
+ raise ValueError(f"self and other have matching keys {str(inters_keys)}.")
report = EvaluationReport()
report._prevs = self._prevs
@@ -100,12 +173,14 @@ class EvaluationReport:
return report
@staticmethod
- def combine_reports(*args, name="default"):
+ def combine_reports(*args, name="default", train_prev=None, valid_prev=None):
er = args[0]
for r in args[1:]:
er = er.merge(r)
er.name = name
+ er.train_prev = train_prev
+ er.valid_prev = valid_prev
return er
diff --git a/quacc/main.py b/quacc/main.py
index c900a98..6c2cc4c 100644
--- a/quacc/main.py
+++ b/quacc/main.py
@@ -1,16 +1,39 @@
+import os
+import shutil
+from pathlib import Path
+
import quacc.evaluation.comp as comp
from quacc.dataset import Dataset
from quacc.environ import env
+def create_out_dir(dir_name):
+ dir_path = Path(env.OUT_DIR_NAME) / dir_name
+ env.OUT_DIR = dir_path
+ shutil.rmtree(dir_path, ignore_errors=True)
+    os.makedirs(dir_path)
+ plot_dir_path = dir_path / "plot"
+ env.PLOT_OUT_DIR = plot_dir_path
+ os.mkdir(plot_dir_path)
+
+
def estimate_comparison():
- dataset = Dataset(
- env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS
- )
- output_path = env.OUT_DIR / f"{dataset.name}.md"
- with open(output_path, "w") as f:
- dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
- f.write(dr.to_md("acc"))
+ for conf in env:
+ create_out_dir(conf)
+ dataset = Dataset(
+ env.DATASET_NAME,
+ target=env.DATASET_TARGET,
+ n_prevalences=env.DATASET_N_PREVS,
+ )
+ output_path = env.OUT_DIR / f"{dataset.name}.md"
+ try:
+ dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
+ for m in env.METRICS:
+ output_path = env.OUT_DIR / f"{conf}_{m}.md"
+ with open(output_path, "w") as f:
+ f.write(dr.to_md(m))
+ except Exception as e:
+ print(f"Configuration {conf} failed. {e}")
# print(df.to_latex(float_format="{:.4f}".format))
# print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
diff --git a/quacc/plot.py b/quacc/plot.py
index 79977d7..93170f2 100644
--- a/quacc/plot.py
+++ b/quacc/plot.py
@@ -1,16 +1,191 @@
+from pathlib import Path
+
import matplotlib.pyplot as plt
+import numpy as np
from quacc.environ import env
-def plot_delta(base_prevs, dict_vals, metric, title):
- fig, ax = plt.subplots()
+def _get_markers(n: int):
+ return [
+ "o",
+ "v",
+ "x",
+ "+",
+ "s",
+ "D",
+ "p",
+ "h",
+ "*",
+ "^",
+ ][:n]
- base_prevs = [f for f, p in base_prevs]
+
+def plot_delta(
+ base_prevs,
+ dict_vals,
+ *,
+ pos_class=1,
+ metric="acc",
+ name="default",
+ train_prev=None,
+ legend=True,
+) -> Path:
+ if train_prev is not None:
+ t_prev_pos = int(round(train_prev[pos_class] * 100))
+ title = f"delta_{name}_{t_prev_pos}_{metric}"
+ else:
+ title = f"delta_{name}_{metric}"
+
+ fig, ax = plt.subplots()
+ ax.set_aspect("auto")
+ ax.grid()
+
+ NUM_COLORS = len(dict_vals)
+ cm = plt.get_cmap("tab10")
+ if NUM_COLORS > 10:
+ cm = plt.get_cmap("tab20")
+ ax.set_prop_cycle(
+ color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
+ )
+
+ base_prevs = [bp[pos_class] for bp in base_prevs]
for method, deltas in dict_vals.items():
+ avg = np.array([np.mean(d, axis=-1) for d in deltas])
+ # std = np.array([np.std(d, axis=-1) for d in deltas])
ax.plot(
base_prevs,
+ avg,
+ label=method,
+ linestyle="-",
+ marker="o",
+ markersize=3,
+ zorder=2,
+ )
+ # ax.fill_between(base_prevs, avg - std, avg + std, alpha=0.25)
+
+ ax.set(xlabel="test prevalence", ylabel=metric, title=title)
+
+ if legend:
+ ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+ output_path = env.PLOT_OUT_DIR / f"{title}.png"
+ fig.savefig(output_path, bbox_inches="tight")
+
+ return output_path
+
+
+def plot_diagonal(
+ reference,
+ dict_vals,
+ *,
+ pos_class=1,
+ metric="acc",
+ name="default",
+ train_prev=None,
+ legend=True,
+):
+ if train_prev is not None:
+ t_prev_pos = int(round(train_prev[pos_class] * 100))
+ title = f"diagonal_{name}_{t_prev_pos}_{metric}"
+ else:
+ title = f"diagonal_{name}_{metric}"
+
+ fig, ax = plt.subplots()
+ ax.set_aspect("auto")
+ ax.grid()
+
+ NUM_COLORS = len(dict_vals)
+ cm = plt.get_cmap("tab10")
+ ax.set_prop_cycle(
+ marker=_get_markers(NUM_COLORS) * 2,
+ color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)] * 2,
+ )
+
+ reference = np.array(reference)
+ x_ticks = np.unique(reference)
+ x_ticks.sort()
+
+ for _, deltas in dict_vals.items():
+ deltas = np.array(deltas)
+ ax.plot(
+ reference,
deltas,
+ linestyle="None",
+ markersize=3,
+ zorder=2,
+ )
+
+ for method, deltas in dict_vals.items():
+ deltas = np.array(deltas)
+ x_interp = x_ticks[[0, -1]]
+ y_interp = np.interp(x_interp, reference, deltas)
+ ax.plot(
+ x_interp,
+ y_interp,
+ label=method,
+ linestyle="-",
+ markersize="0",
+ zorder=1,
+ )
+
+ ax.set(xlabel="test prevalence", ylabel=metric, title=title)
+
+ if legend:
+ ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+ output_path = env.PLOT_OUT_DIR / f"{title}.png"
+ fig.savefig(output_path, bbox_inches="tight")
+ return output_path
+
+
+def plot_shift(
+ base_prevs,
+ dict_vals,
+ *,
+ pos_class=1,
+ metric="acc",
+ name="default",
+ train_prev=None,
+ legend=True,
+) -> Path:
+ if train_prev is None:
+ raise AttributeError("train_prev cannot be None.")
+
+ train_prev = train_prev[pos_class]
+ t_prev_pos = int(round(train_prev * 100))
+ title = f"shift_{name}_{t_prev_pos}_{metric}"
+
+ fig, ax = plt.subplots()
+ ax.set_aspect("auto")
+ ax.grid()
+
+ NUM_COLORS = len(dict_vals)
+ cm = plt.get_cmap("tab10")
+ if NUM_COLORS > 10:
+ cm = plt.get_cmap("tab20")
+ ax.set_prop_cycle(
+ color=[cm(1.0 * i / NUM_COLORS) for i in range(NUM_COLORS)],
+ )
+
+ base_prevs = np.around(
+ [abs(bp[pos_class] - train_prev) for bp in base_prevs], decimals=2
+ )
+ for method, deltas in dict_vals.items():
+ delta_bins = {}
+ for bp, delta in zip(base_prevs, deltas):
+ if bp not in delta_bins:
+ delta_bins[bp] = []
+ delta_bins[bp].append(delta)
+
+ bp_unique, delta_avg = zip(
+ *sorted(
+ {k: np.mean(v) for k, v in delta_bins.items()}.items(),
+ key=lambda db: db[0],
+ )
+ )
+
+ ax.plot(
+ bp_unique,
+ delta_avg,
label=method,
linestyle="-",
marker="o",
@@ -19,8 +194,10 @@ def plot_delta(base_prevs, dict_vals, metric, title):
)
ax.set(xlabel="test prevalence", ylabel=metric, title=title)
- # ax.set_ylim(0, 1)
- # ax.set_xlim(0, 1)
- ax.legend()
+
+ if legend:
+ ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
output_path = env.PLOT_OUT_DIR / f"{title}.png"
- plt.savefig(output_path)
+ fig.savefig(output_path, bbox_inches="tight")
+
+ return output_path