documenting the regressor
This commit is contained in:
parent
3264e66cc9
commit
7febaa2693
|
@ -18,6 +18,7 @@ from scripts.evaluate import normalized_match_distance, match_distance
|
||||||
|
|
||||||
def projection_simplex_sort(unnormalized_arr) -> np.ndarray:
|
def projection_simplex_sort(unnormalized_arr) -> np.ndarray:
|
||||||
"""Projects a point onto the probability simplex.
|
"""Projects a point onto the probability simplex.
|
||||||
|
[This code is taken from the devel branch, which will correspond to the future QuaPy 0.1.9]
|
||||||
|
|
||||||
The code is adapted from Mathieu Blondel's BSD-licensed
|
The code is adapted from Mathieu Blondel's BSD-licensed
|
||||||
`implementation <https://gist.github.com/mblondel/6f3b7aaad90606b98f71>`_
|
`implementation <https://gist.github.com/mblondel/6f3b7aaad90606b98f71>`_
|
||||||
|
@ -42,22 +43,51 @@ def projection_simplex_sort(unnormalized_arr) -> np.ndarray:
|
||||||
|
|
||||||
|
|
||||||
class RegressionToSimplex(BaseEstimator):
|
class RegressionToSimplex(BaseEstimator):
|
||||||
|
"""
|
||||||
|
A very simple regressor of probability distributions.
|
||||||
|
Internally, this class works by invoking an SVR regressor multioutput
|
||||||
|
followed by a mapping onto the probability simplex.
|
||||||
|
|
||||||
|
:param C: regularization parameter for SVR
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, C=1):
|
def __init__(self, C=1):
|
||||||
self.C = C
|
self.C = C
|
||||||
|
|
||||||
def fit(self, X, y):
|
def fit(self, X, y):
|
||||||
|
"""
|
||||||
|
Learns the correction
|
||||||
|
|
||||||
|
:param X: array-like of shape `(n_instances, n_classes)` with uncorrected prevalence vectors
|
||||||
|
:param y: array-like of shape `(n_instances, n_classes)` with true prevalence vectors
|
||||||
|
:return: self
|
||||||
|
"""
|
||||||
self.reg = MultiOutputRegressor(SVR(C=self.C), n_jobs=-1)
|
self.reg = MultiOutputRegressor(SVR(C=self.C), n_jobs=-1)
|
||||||
self.reg.fit(X, y)
|
self.reg.fit(X, y)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
|
"""
|
||||||
|
Corrects a vector of prevalence values
|
||||||
|
|
||||||
|
:param X: array-like of shape `(n_instances, n_classes)` with one uncorrected prevalence vector per row
|
||||||
|
:return: array of shape `(n_instances, n_classes)` with one corrected prevalence vector per row, each projected onto the probability simplex
|
||||||
|
"""
|
||||||
y_ = self.reg.predict(X)
|
y_ = self.reg.predict(X)
|
||||||
# y_ = F.normalize_prevalence(y_)
|
|
||||||
y_ = np.asarray([projection_simplex_sort(y_i) for y_i in y_])
|
y_ = np.asarray([projection_simplex_sort(y_i) for y_i in y_])
|
||||||
return y_
|
return y_
|
||||||
|
|
||||||
|
|
||||||
class KDEyRegressor(BaseQuantifier):
|
class KDEyRegressor(BaseQuantifier):
|
||||||
|
"""
|
||||||
|
This class implements a regressor-based correction on top of a quantifier.
|
||||||
|
The quantifier is taken to be KDEy-ML, which is considered to be already trained (this
|
||||||
|
method simply loads a pickled object).
|
||||||
|
The method then optimizes a regressor that corrects prevalence vectors using the
|
||||||
|
validation samples as training data.
|
||||||
|
The regressor is based on a multioutput SVR and relies on a post-processing step to guarantee
|
||||||
|
that the output lies on the probability simplex (see also RegressionToSimplex)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, kde_path, Cs=np.logspace(-3,3,7)):
|
def __init__(self, kde_path, Cs=np.logspace(-3,3,7)):
|
||||||
self.kde_path = kde_path
|
self.kde_path = kde_path
|
||||||
|
@ -96,12 +126,8 @@ class KDEyRegressor(BaseQuantifier):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train, gen_val, gen_test = fetch_lequa2024(task='T3', data_home='./data', merge_T3=True)
|
train, gen_val, _ = fetch_lequa2024(task='T3', data_home='./data', merge_T3=True)
|
||||||
kdey_r = KDEyRegressor('./models/T3/KDEy-ML.pkl')
|
kdey_r = KDEyRegressor('./models/T3/KDEy-ML.pkl')
|
||||||
kdey_r.fit(gen_val)
|
kdey_r.fit(gen_val)
|
||||||
prev_hat_tr = kdey_r.quantify(train.X)
|
|
||||||
print(prev_hat_tr)
|
|
||||||
print(train.prevalence())
|
|
||||||
|
|
||||||
pickle.dump(kdey_r, open('./models/T3/KDEyRegressor.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(kdey_r, open('./models/T3/KDEyRegressor.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue