forked from moreo/QuaPy
39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
|
from sklearn.decomposition import TruncatedSVD
|
||
|
from sklearn.linear_model import LogisticRegression
|
||
|
|
||
|
|
||
|
class PCALR:
    """Logistic regression trained on a truncated-SVD projection of the input.

    ``fit`` first reduces the documents to ``n_components`` dimensions with
    ``TruncatedSVD`` and then trains a ``LogisticRegression`` on the reduced
    representation. ``predict`` and ``predict_proba`` apply the same fitted
    projection before delegating to the classifier.

    Attributes set by ``fit``:
        pca: the fitted ``TruncatedSVD`` instance.
        classes_: copied from the fitted classifier's ``classes_``.
    """

    def __init__(self, n_components=300, C=10, class_weight=None):
        """Store the projection size and build the (unfitted) classifier.

        Args:
            n_components: number of components requested from ``TruncatedSVD``.
            C: forwarded to ``LogisticRegression``.
            class_weight: forwarded to ``LogisticRegression``.
        """
        self.n_components = n_components
        self.learner = LogisticRegression(C=C, class_weight=class_weight, max_iter=1000)

    def get_params(self, deep=True):
        """Return ``n_components`` merged with the wrapped classifier's parameters.

        Args:
            deep: forwarded to the classifier's ``get_params``. Accepted so
                that callers following the scikit-learn estimator convention
                (``get_params(deep=...)``) do not fail with a ``TypeError``.

        Returns:
            A new dict; the classifier's entries are applied on top of
            ``{'n_components': ...}``.
        """
        params = {'n_components': self.n_components}
        params.update(self.learner.get_params(deep=deep))
        return params

    def set_params(self, **params):
        """Update ``n_components`` and/or parameters of the wrapped classifier.

        ``n_components`` is consumed here; every other keyword is passed to the
        classifier's ``set_params``. A new ``n_components`` only takes effect
        on the next call to ``fit``, because the projection is built there.

        Returns:
            ``self``, so that calls can be chained.
        """
        if 'n_components' in params:
            # ``params`` is this call's own kwargs dict, so popping is safe.
            self.n_components = params.pop('n_components')
        self.learner.set_params(**params)
        return self

    def fit(self, documents, labels):
        """Fit the projection on ``documents`` and the classifier on its output.

        Returns:
            ``self``.
        """
        # A fresh projection is created on every fit so that a changed
        # ``n_components`` is honoured.
        self.pca = TruncatedSVD(self.n_components)
        embedded = self.pca.fit_transform(documents, labels)
        self.learner.fit(embedded, labels)
        self.classes_ = self.learner.classes_
        return self

    def predict(self, documents):
        """Project ``documents`` and return the classifier's predictions."""
        embedded = self.transform(documents)
        return self.learner.predict(embedded)

    def predict_proba(self, documents):
        """Project ``documents`` and return the classifier's ``predict_proba`` output."""
        embedded = self.transform(documents)
        return self.learner.predict_proba(embedded)

    def transform(self, documents):
        """Apply the projection fitted in ``fit`` to ``documents``.

        Requires a prior call to ``fit``; ``self.pca`` does not exist before it.
        """
        return self.pca.transform(documents)
|