From 7febaa269356033d448c0938df4d472a901231bc Mon Sep 17 00:00:00 2001
From: Alejandro Moreo <alejandro.moreo@isti.cnr.it>
Date: Wed, 29 May 2024 11:23:17 +0200
Subject: [PATCH] documenting the regressor

---
 LeQua2024/regressor.py | 38 ++++++++++++++++++++++++++++++++------
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/LeQua2024/regressor.py b/LeQua2024/regressor.py
index 064a925..03293f9 100644
--- a/LeQua2024/regressor.py
+++ b/LeQua2024/regressor.py
@@ -18,6 +18,7 @@ from scripts.evaluate import normalized_match_distance, match_distance
 
 def projection_simplex_sort(unnormalized_arr) -> np.ndarray:
     """Projects a point onto the probability simplex.
+    [This code is taken from the devel branch, that will correspond to the future QuaPy 0.1.9]
 
     The code is adapted from Mathieu Blondel's BSD-licensed
     `implementation <https://gist.github.com/mblondel/6f3b7aaad90606b98f71>`_
@@ -42,22 +43,51 @@ def projection_simplex_sort(unnormalized_arr) -> np.ndarray:
 
 
 class RegressionToSimplex(BaseEstimator):
+    """
+    A very simple regressor of probability distributions.
+    Internally, this class works by invoking an SVR regressor multioutput
+    followed by a mapping onto the probability simplex.
+
+    :param C: regularziation parameter for SVR
+    """
+
     def __init__(self, C=1):
         self.C = C
 
     def fit(self, X, y):
+        """
+        Learns the correction
+
+        :param X: array-like of shape `(n_instances, n_classes)` with uncorrected prevalence vectors
+        :param y: array-like of shape `(n_instances, n_classes)` with true prevalence vectors
+        :return: self
+        """
         self.reg = MultiOutputRegressor(SVR(C=self.C), n_jobs=-1)
         self.reg.fit(X, y)
         return self
 
     def predict(self, X):
+        """
+        Corrects the a vector of prevalence values
+
+        :param X: array-like of shape `(n_classes,)` with one vector of uncorrected prevalence values
+        :return: array-like of shape `(n_classes,)` with one vector of corrected prevalence values
+        """
         y_ = self.reg.predict(X)
-        # y_ = F.normalize_prevalence(y_)
         y_ = np.asarray([projection_simplex_sort(y_i) for y_i in y_])
         return y_
 
 
 class KDEyRegressor(BaseQuantifier):
+    """
+    This class implements a regressor-based correction on top of a quantifier.
+    The quantifier is taken to be KDEy-ML, which is considered to be already trained (this
+    method simply loads a pickled object).
+    The method then optimizes a regressor that corrects prevalence vectors using the
+    validation samples as training data.
+    The regressor is based on a multioutput SVR and relies on a post-processing to guarantee
+    that the output lies on the probability simplex (see also RegressionToSimplex)
+    """
 
     def __init__(self, kde_path, Cs=np.logspace(-3,3,7)):
         self.kde_path = kde_path
@@ -96,12 +126,8 @@ class KDEyRegressor(BaseQuantifier):
 
 
 if __name__ == '__main__':
-    train, gen_val, gen_test = fetch_lequa2024(task='T3', data_home='./data', merge_T3=True)
+    train, gen_val, _ = fetch_lequa2024(task='T3', data_home='./data', merge_T3=True)
     kdey_r = KDEyRegressor('./models/T3/KDEy-ML.pkl')
     kdey_r.fit(gen_val)
-    prev_hat_tr = kdey_r.quantify(train.X)
-    print(prev_hat_tr)
-    print(train.prevalence())
-
     pickle.dump(kdey_r, open('./models/T3/KDEyRegressor.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)