From 79dcef35ae2e93fc110a1e6e1c39401314539beb Mon Sep 17 00:00:00 2001
From: Alex Moreo <alejandro.moreo@isti.cnr.it>
Date: Sat, 24 Feb 2024 14:36:04 +0100
Subject: [PATCH] half implementation of the equations method

---
 ClassifierAccuracy/models_multiclass.py | 61 ++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/ClassifierAccuracy/models_multiclass.py b/ClassifierAccuracy/models_multiclass.py
index ba1fde7..0d098db 100644
--- a/ClassifierAccuracy/models_multiclass.py
+++ b/ClassifierAccuracy/models_multiclass.py
@@ -12,6 +12,7 @@ from sklearn.model_selection import cross_val_predict
 
 from quapy.method.base import BaseQuantifier
 from quapy.method.aggregative import PACC
+import quapy.functional as F
 
 
 class ClassifierAccuracyPrediction(ABC):
@@ -117,11 +118,67 @@ class ContTableWithHTransferCAP(ClassifierAccuracyPrediction):
         :param test: test collection (ignored)
         :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
         """
-        prev_hat = self.q.quantify(test)
-        adjustment = prev_hat / self.train_prev
+        test_prev_estim = self.q.quantify(test)
+        adjustment = test_prev_estim / self.train_prev
         return self.cont_table * adjustment[:, np.newaxis]
 
 
+class NsquaredEquationsCAP(ClassifierAccuracyPrediction):
+    """
+
+    """
+    def __init__(self, h: BaseEstimator, acc: callable, q_class):
+        super().__init__(h, acc)
+        self.q = q_class(classifier=h)
+
+    def fit(self, val: LabelledCollection):
+        y_hat = self.h.predict(val.X)
+        y_true = val.y
+        self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_)
+        self.q.fit(val, fit_classifier=False, val_split=val)
+        return self
+
+    def predict_ct(self, test):
+        """
+        :param test: test collection
+        :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
+        """
+
+        # we need an n x n matrix of unknowns
+
+        n = self.cont_table.shape[1]
+        I = np.arange(n*n).reshape(n,n)
+        h_label_preds = self.h.predict(test)
+        cc_prev_estim = F.prevalence_from_labels(h_label_preds, self.h.classes_)
+        q_prev_estim = self.q.quantify(test)
+
+        A = np.zeros((n*n, n*n))
+        b = np.zeros(n*n)
+
+        # first equation: the sum of all unknowns is 1
+        eq_no = 0
+        A[eq_no, :] = 1
+        b[eq_no] = 1
+        eq_no += 1
+
+        # n-1 equations: the sum of class-cond predictions must equal the sum of predictions
+        for i in range(n-1):
+            A[eq_no + i, I[:, i+1]] = 1
+            b[eq_no + i] = cc_prev_estim[i+1]
+        eq_no += (n-1)
+
+        # n-1 equations: the sum of true class-conditional positives must equal the class prev label in test
+        for i in range(n-1):
+            A[eq_no + i, I[i+1, :]] = 1
+            b[eq_no + i] = q_prev_estim[i+1]
+
+        # (n-1)*(n-1) equations: the class cond ratios should be the same in training and in test due to the
+        # PPS assumptions
+    
+
+
+
+
 
 
 class UpperBound(ClassifierAccuracyPrediction):