1
0
Fork 0
QuaPy/MultiLabel/mlclassification.py

34 lines
1.2 KiB
Python

from copy import deepcopy
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler
class MultilabelStackedClassifier:  # aka Funnelling Monolingual
    """Two-tier stacked classifier for multi-label data.

    A first-tier one-vs-rest classifier (``base``) is trained on the input
    features. The posterior probabilities it outputs are standardized with a
    ``StandardScaler`` (``norm``) and then used as the feature space of a
    second-tier one-vs-rest classifier (``meta``), which produces the final
    predictions.
    """

    def __init__(self, base_estimator=LogisticRegression()):
        """Create the two (still unfitted) tiers and the scaler.

        :param base_estimator: estimator wrapped by both one-vs-rest tiers.
            If it does not expose ``predict_proba`` it is wrapped in a
            ``CalibratedClassifierCV`` first, since the second tier is fed
            with the probabilities output by the first one.
        """
        # The default instance is evaluated once at definition time, but it
        # is never stored directly: each tier keeps its own deep copy (below).
        if not hasattr(base_estimator, 'predict_proba'):
            print('the estimator does not seem to be probabilistic: calibrating')
            base_estimator = CalibratedClassifierCV(base_estimator)
        # Independent deep copies, so that the two tiers share no state.
        self.base = deepcopy(OneVsRestClassifier(base_estimator))
        self.meta = deepcopy(OneVsRestClassifier(base_estimator))
        self.norm = StandardScaler()

    def _meta_features(self, X):
        """Map ``X`` into the (already fitted) second-tier feature space.

        :param X: instances in the format accepted by the first tier.
        :return: the first-tier posterior probabilities for ``X``, rescaled
            with the scaler fitted in :meth:`fit`.
        """
        return self.norm.transform(self.base.predict_proba(X))

    def fit(self, X, y):
        """Fit the first tier, the scaler and the second tier, in that order.

        :param X: training instances.
        :param y: label matrix; must expose ``ndim`` and be 2-dimensional
            (presumably a binary indicator matrix of shape
            ``(n_samples, n_labels)`` -- confirm against callers).
        :return: ``self``.
        :raises AssertionError: if ``y`` is not 2-dimensional.
        """
        # Raised explicitly instead of via `assert` so that the check is not
        # stripped when Python runs with -O; the exception type and message
        # are kept so existing callers observe the same error.
        if y.ndim != 2:
            raise AssertionError('the dataset does not seem to be multi-label')
        self.base.fit(X, y)
        # The second tier is trained on the posteriors the first tier yields
        # for its own training instances; the scaler is fitted on them too.
        posteriors = self.base.predict_proba(X)
        self.meta.fit(self.norm.fit_transform(posteriors), y)
        return self

    def predict(self, X):
        """Return the second-tier label predictions for ``X``.

        :param X: instances in the format accepted by the first tier.
        :return: whatever ``self.meta.predict`` returns for the stacked
            representation of ``X``.
        """
        return self.meta.predict(self._meta_features(X))

    def predict_proba(self, X):
        """Return the second-tier posterior probabilities for ``X``.

        :param X: instances in the format accepted by the first tier.
        :return: whatever ``self.meta.predict_proba`` returns for the stacked
            representation of ``X``.
        """
        return self.meta.predict_proba(self._meta_features(X))