documenting the regressor

This commit is contained in:
Alejandro Moreo Fernandez 2024-05-29 11:23:17 +02:00
parent 3264e66cc9
commit 7febaa2693
1 changed files with 32 additions and 6 deletions

View File

@ -18,6 +18,7 @@ from scripts.evaluate import normalized_match_distance, match_distance
def projection_simplex_sort(unnormalized_arr) -> np.ndarray: def projection_simplex_sort(unnormalized_arr) -> np.ndarray:
"""Projects a point onto the probability simplex. """Projects a point onto the probability simplex.
[This code is taken from the devel branch, which will correspond to the future QuaPy 0.1.9]
The code is adapted from Mathieu Blondel's BSD-licensed The code is adapted from Mathieu Blondel's BSD-licensed
`implementation <https://gist.github.com/mblondel/6f3b7aaad90606b98f71>`_ `implementation <https://gist.github.com/mblondel/6f3b7aaad90606b98f71>`_
@ -42,22 +43,51 @@ def projection_simplex_sort(unnormalized_arr) -> np.ndarray:
class RegressionToSimplex(BaseEstimator): class RegressionToSimplex(BaseEstimator):
"""
A very simple regressor of probability distributions.
Internally, this class works by invoking a multi-output SVR regressor
followed by a mapping onto the probability simplex.
:param C: regularization parameter for SVR
"""
def __init__(self, C=1): def __init__(self, C=1):
self.C = C self.C = C
def fit(self, X, y): def fit(self, X, y):
"""
Learns the correction
:param X: array-like of shape `(n_instances, n_classes)` with uncorrected prevalence vectors
:param y: array-like of shape `(n_instances, n_classes)` with true prevalence vectors
:return: self
"""
self.reg = MultiOutputRegressor(SVR(C=self.C), n_jobs=-1) self.reg = MultiOutputRegressor(SVR(C=self.C), n_jobs=-1)
self.reg.fit(X, y) self.reg.fit(X, y)
return self return self
def predict(self, X): def predict(self, X):
"""
Corrects a vector of prevalence values
:param X: array-like of shape `(n_instances, n_classes)` with uncorrected prevalence vectors, one per row (a single vector must be passed as a one-row 2-D array)
:return: array of shape `(n_instances, n_classes)` with the corrected prevalence vectors, one per row
"""
y_ = self.reg.predict(X) y_ = self.reg.predict(X)
# y_ = F.normalize_prevalence(y_)
y_ = np.asarray([projection_simplex_sort(y_i) for y_i in y_]) y_ = np.asarray([projection_simplex_sort(y_i) for y_i in y_])
return y_ return y_
class KDEyRegressor(BaseQuantifier): class KDEyRegressor(BaseQuantifier):
"""
This class implements a regressor-based correction on top of a quantifier.
The quantifier is taken to be KDEy-ML, which is considered to be already trained (this
method simply loads a pickled object).
The method then optimizes a regressor that corrects prevalence vectors using the
validation samples as training data.
The regressor is based on a multi-output SVR and relies on a post-processing step to guarantee
that the output lies on the probability simplex (see also RegressionToSimplex).
"""
def __init__(self, kde_path, Cs=np.logspace(-3,3,7)): def __init__(self, kde_path, Cs=np.logspace(-3,3,7)):
self.kde_path = kde_path self.kde_path = kde_path
@ -96,12 +126,8 @@ class KDEyRegressor(BaseQuantifier):
if __name__ == '__main__': if __name__ == '__main__':
train, gen_val, gen_test = fetch_lequa2024(task='T3', data_home='./data', merge_T3=True) train, gen_val, _ = fetch_lequa2024(task='T3', data_home='./data', merge_T3=True)
kdey_r = KDEyRegressor('./models/T3/KDEy-ML.pkl') kdey_r = KDEyRegressor('./models/T3/KDEy-ML.pkl')
kdey_r.fit(gen_val) kdey_r.fit(gen_val)
prev_hat_tr = kdey_r.quantify(train.X)
print(prev_hat_tr)
print(train.prevalence())
pickle.dump(kdey_r, open('./models/T3/KDEyRegressor.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(kdey_r, open('./models/T3/KDEyRegressor.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)