From ea7a574185c9226c65acca5eee13dc8c5e820677 Mon Sep 17 00:00:00 2001
From: Alex Moreo <alejandro.moreo@isti.cnr.it>
Date: Sun, 3 Mar 2024 19:25:00 +0100
Subject: [PATCH] adding support for f1 to doc

---
 ClassifierAccuracy/models_multiclass.py |  7 +-
 ClassifierAccuracy/util/commons.py      | 95 +++++++++++++++----------
 2 files changed, 63 insertions(+), 39 deletions(-)

diff --git a/ClassifierAccuracy/models_multiclass.py b/ClassifierAccuracy/models_multiclass.py
index cdc256b..15b80cd 100644
--- a/ClassifierAccuracy/models_multiclass.py
+++ b/ClassifierAccuracy/models_multiclass.py
@@ -416,7 +416,7 @@ class QuAcc:
             X_dot = np.hstack(add_covs)
 
         if self.add_X:
-            X_dot = safehstack(X, add_covs)
+            X_dot = safehstack(X, X_dot)
 
         return X_dot
 
@@ -616,8 +616,9 @@ class ATC(ClassifierAccuracyPrediction):
 
 class DoC(ClassifierAccuracyPrediction):
 
-    def __init__(self, h, sample_size, num_samples=100):
+    def __init__(self, h, acc, sample_size, num_samples=500):
         self.h = h
+        self.acc = acc
         self.sample_size = sample_size
         self.num_samples = num_samples
 
@@ -625,7 +626,7 @@ class DoC(ClassifierAccuracyPrediction):
         P = get_posteriors_from_h(self.h, X)
         mc = max_conf(P)
         pred_labels = np.argmax(P, axis=-1)
-        acc = (y == pred_labels).mean()
+        acc = self.acc(y, pred_labels)
         return mc, acc
 
     def _doc(self, mc1, mc2):
diff --git a/ClassifierAccuracy/util/commons.py b/ClassifierAccuracy/util/commons.py
index 1846208..7d1a55e 100644
--- a/ClassifierAccuracy/util/commons.py
+++ b/ClassifierAccuracy/util/commons.py
@@ -6,7 +6,7 @@ from glob import glob
 from pathlib import Path
 from time import time
 import numpy as np
-
+from sklearn.metrics import accuracy_score, f1_score
 
 from sklearn.datasets import fetch_rcv1
 from sklearn.model_selection import GridSearchCV
@@ -74,7 +74,7 @@ def gen_CAP(h, acc_fn, with_oracle=False)->[str, ClassifierAccuracyPrediction]:
     #yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
     yield 'ATC-MC', ATC(h, acc_fn, scoring_fn='maxconf')
     #yield 'ATC-NE', ATC(h, acc_fn, scoring_fn='neg_entropy')
-    yield 'DoC', DoC(h, sample_size=qp.environ['SAMPLE_SIZE'])
+    yield 'DoC', DoC(h, acc_fn, sample_size=qp.environ['SAMPLE_SIZE'])
 
 
 def gen_CAP_cont_table(h)->[str,CAPContingencyTable]:
@@ -103,7 +103,7 @@ def get_method_names():
 
 def gen_acc_measure():
     yield 'vanilla_accuracy', vanilla_acc_fn
-    #yield 'macro-F1', macrof1
+    yield 'macro-F1', macrof1_fn
 
 
 def split(data: LabelledCollection):
@@ -156,7 +156,30 @@ def true_acc(h:BaseEstimator, acc_fn: callable, U: LabelledCollection):
     return acc_fn(conf_table)
 
 
-def vanilla_acc_fn(cont_table):
+def from_contingency_table(param1, param2):
+    if param2 is None and isinstance(param1, np.ndarray) and param1.ndim==2 and (param1.shape[0]==param1.shape[1]):
+        return True
+    elif isinstance(param1, np.ndarray) and isinstance(param2, np.ndarray) and param1.shape==param2.shape:
+        return False
+    else:
+        raise ValueError('parameters for evaluation function not understood')
+
+
+def vanilla_acc_fn(param1, param2=None):
+    if from_contingency_table(param1, param2):
+        return _vanilla_acc_from_ct(param1)
+    else:
+        return accuracy_score(param1, param2)
+
+
+def macrof1_fn(param1, param2=None):
+    if from_contingency_table(param1, param2):
+        return macro_f1_from_ct(param1)
+    else:
+        return f1_score(param1, param2, average='macro')
+
+
+def _vanilla_acc_from_ct(cont_table):
     return np.diag(cont_table).sum() / cont_table.sum()
 
 
@@ -167,7 +190,7 @@ def _f1_bin(tp, fp, fn):
         return (2 * tp) / (2 * tp + fp + fn)
 
 
-def macrof1(cont_table):
+def macro_f1_from_ct(cont_table):
     n = cont_table.shape[0]
 
     if n==2:
@@ -182,6 +205,7 @@ def macrof1(cont_table):
         fp = cont_table[:,i].sum() - tp
         fn = cont_table[i,:].sum() - tp
         f1_per_class.append(_f1_bin(tp, fp, fn))
+
     return np.mean(f1_per_class)
 
 
@@ -269,7 +293,6 @@ def gen_tables(basedir, datasets):
     mock_h = LogisticRegression(),
     methods = [method for method, _ in gen_CAP(mock_h, None)] + [method for method, _ in gen_CAP_cont_table(mock_h)]
     classifiers = [classifier for classifier, _ in gen_classifiers()]
-    measures = [measure for measure, _ in gen_acc_measure()]
 
     os.makedirs('./tables', exist_ok=True)
 
@@ -288,39 +311,39 @@ def gen_tables(basedir, datasets):
     """
 
     classifier = classifiers[0]
-    metric = "vanilla_accuracy"
+    for metric in [measure for measure, _ in gen_acc_measure()]:
 
-    table = Table(datasets, methods)
-    for method, dataset in itertools.product(methods, datasets):
-        path = getpath(basedir, classifier, metric, dataset, method)
-        if not os.path.exists(path):
-            print('missing ', path)
-            continue
-        results = json.load(open(path, 'r'))
-        true_acc = results['true_acc']
-        estim_acc = np.asarray(results['estim_acc'])
-        if any(np.isnan(estim_acc)):
-            print(f'nan values found in {method=} {dataset=}')
-            continue
-        if any(estim_acc>1.00001):
-            print(f'values >1 found in {method=} {dataset=} [max={estim_acc.max()}]')
-            continue
-        if any(estim_acc<-0.00001):
-            print(f'values <0 found in {method=} {dataset=} [min={estim_acc.min()}]')
-            continue
-        errors = cap_errors(true_acc, estim_acc)
-        table.add(dataset, method, errors)
+        table = Table(datasets, methods)
+        for method, dataset in itertools.product(methods, datasets):
+            path = getpath(basedir, classifier, metric, dataset, method)
+            if not os.path.exists(path):
+                print('missing ', path)
+                continue
+            results = json.load(open(path, 'r'))
+            true_acc = results['true_acc']
+            estim_acc = np.asarray(results['estim_acc'])
+            if any(np.isnan(estim_acc)):
+                print(f'nan values found in {method=} {dataset=}')
+                continue
+            if any(estim_acc>1.00001):
+                print(f'values >1 found in {method=} {dataset=} [max={estim_acc.max()}]')
+                continue
+            if any(estim_acc<-0.00001):
+                print(f'values <0 found in {method=} {dataset=} [min={estim_acc.min()}]')
+                continue
+            errors = cap_errors(true_acc, estim_acc)
+            table.add(dataset, method, errors)
 
-    tex = table.latexTabular()
-    table_name = f'{basedir}_{classifier}_{metric}.tex'
-    with open(f'./tables/{table_name}', 'wt') as foo:
-        foo.write('\\resizebox{\\textwidth}{!}{%\n')
-        foo.write('\\begin{tabular}{c|'+('c'*len(methods))+'}\n')
-        foo.write(tex)
-        foo.write('\\end{tabular}%\n')
-        foo.write('}\n')
+        tex = table.latexTabular()
+        table_name = f'{basedir}_{classifier}_{metric}.tex'
+        with open(f'./tables/{table_name}', 'wt') as foo:
+            foo.write('\\resizebox{\\textwidth}{!}{%\n')
+            foo.write('\\begin{tabular}{c|'+('c'*len(methods))+'}\n')
+            foo.write(tex)
+            foo.write('\\end{tabular}%\n')
+            foo.write('}\n')
 
-    tex_doc += "\input{" + table_name + "}\n"
+        tex_doc += "\input{" + table_name + "}\n\n"
 
     tex_doc += """
     \\end{document}