forked from moreo/QuaPy
66 lines
2.5 KiB
Python
Executable File
66 lines
2.5 KiB
Python
Executable File
from sklearn.svm import LinearSVC
|
|
from sklearn.model_selection import GridSearchCV
|
|
import numpy as np
|
|
from joblib import Parallel, delayed
|
|
from time import time
|
|
|
|
|
|
class MLSVC:
    """
    Multi-Label Support Vector Machine, with individual optimizations per binary problem.

    One independent binary estimator is trained per column of the label
    matrix. Each of them can optionally be wrapped in a ``GridSearchCV`` so
    that hyper-parameters are tuned separately for every class.

    :param n_jobs: number of parallel jobs handed to ``joblib.Parallel`` when
        fitting the per-class estimators
    :param estimator: class (or factory) called once per label column as
        ``estimator(*args, **kwargs)`` to build the binary classifiers
    :param args: positional arguments forwarded to ``estimator``
    :param kwargs: keyword arguments forwarded to ``estimator``; if it contains
        ``verbose`` that value also controls this object's own messages
    """

    def __init__(self, n_jobs=1, estimator=LinearSVC, *args, **kwargs):
        self.n_jobs = n_jobs
        self.args = args
        self.kwargs = kwargs
        self.verbose = self.kwargs.get('verbose', False)
        self.estimator = estimator

    def fit(self, X, y, **grid_search_params):
        """
        Fit one binary estimator per column of ``y``.

        :param X: training instances, passed untouched to every estimator
        :param y: 2-dimensional binary indicator matrix (documents x classes)
            containing both 0s and 1s
        :param grid_search_params: keyword arguments for ``GridSearchCV``.
            Model selection is only activated when a non-empty ``param_grid``
            is given; ``cv`` (default 5) must be an int.
        :return: self
        :raises AssertionError: if ``y`` is not in multi-label format or ``cv``
            is not an int
        """
        tini = time()
        # Kept as asserts (not raises) so the exception type seen by callers does not change.
        assert len(y.shape)==2 and set(np.unique(y).tolist()) == {0,1}, 'data format is not multi-label'
        nD, nC = y.shape
        prevalence = np.sum(y, axis=0)

        # A plain list (rather than np.array) keeps the estimators as opaque objects,
        # even if they happen to implement the sequence protocol.
        self.svms = [self.estimator(*self.args, **self.kwargs) for _ in range(nC)]

        # .get avoids a KeyError when other grid-search options are passed without 'param_grid'
        if grid_search_params.get('param_grid'):
            self._print('grid_search activated with: {}'.format(grid_search_params))
            # Grid search cannot be performed if the category prevalence is less than the parameter cv.
            # In those cases we place a svm instead of a gridsearchcv
            cv = grid_search_params.get('cv', 5)
            assert isinstance(cv, int), 'cv must be an int (other policies are not supported yet)'
            self.svms = [
                GridSearchCV(svm_i, refit=True, **grid_search_params) if prevalence[i] >= cv else svm_i
                for i, svm_i in enumerate(self.svms)
            ]

        # Classes without any positive example cannot be learnt: always reject them.
        for i in np.argwhere(prevalence == 0).flatten():
            self.svms[i] = TrivialRejector()

        # Each fit returns the fitted estimator, so the result replaces self.svms.
        self.svms = Parallel(n_jobs=self.n_jobs)(
            delayed(svm.fit)(X, y[:, c]) for c, svm in enumerate(self.svms)
        )
        self.training_time = time() - tini
        return self

    def predict(self, X):
        """
        Stack the predictions of every per-class estimator.

        :param X: instances to classify
        :return: array with one row per instance and one column per class
        """
        return np.vstack([svm.predict(X) for svm in self.svms]).T

    def predict_proba(self, X):
        """
        Stack, for every class, the probability assigned to the positive label.

        :param X: instances to classify
        :return: array with one row per instance and one column per class
        """
        # np.vstack needs a real sequence; a bare map() iterator is rejected by recent NumPy.
        return np.vstack([self._positive_proba(svm, X) for svm in self.svms]).T

    @staticmethod
    def _positive_proba(svm, X):
        """Return the per-instance score of label 1 produced by a single fitted estimator."""
        proba = np.asarray(svm.predict_proba(X))
        if proba.ndim == 1:
            # Estimators such as TrivialRejector already return one score per
            # instance and expose no classes_ attribute.
            return proba
        positive_column = np.argwhere(svm.classes_ == 1)[0, 0]
        return proba[:, positive_column]

    def _print(self, msg):
        """Print ``msg`` only when verbosity is enabled."""
        if self.verbose>0:
            print(msg)

    def best_params(self):
        """
        Return, per class, the parameters selected by grid search.

        :return: list with ``best_params_`` for classes optimized through
            ``GridSearchCV`` and ``None`` for the remaining ones
        """
        return [svm.best_params_ if isinstance(svm, GridSearchCV) else None for svm in self.svms]
|
|
|
|
|
|
class TrivialRejector:
    """
    Constant classifier that answers 0 for every instance.

    It exposes the same ``fit`` / ``predict`` / ``predict_proba`` method names
    as the binary estimators it stands in for, but learns nothing.
    """

    def fit(self, *args, **kwargs):
        """Ignore all arguments and return this same object."""
        return self

    def predict(self, X):
        """Return a 1-dimensional array of zeros with one entry per row of ``X``."""
        n_instances = X.shape[0]
        return np.zeros(n_instances)

    def predict_proba(self, X):
        """Return a 1-dimensional array of zeros with one entry per row of ``X``."""
        n_instances = X.shape[0]
        return np.zeros(n_instances)
|
|
|