trying to understand GP
This commit is contained in:
parent
d0444d3bbb
commit
09ee9efb3f
ClassifierAccuracy
|
@ -0,0 +1,65 @@
|
|||
from sklearn.gaussian_process import GaussianProcessRegressor
|
||||
import numpy as np
|
||||
from sklearn.gaussian_process.kernels import RBF, GenericKernelMixin, Kernel
|
||||
from sklearn.metrics.pairwise import pairwise_distances
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
class MinL2Kernel(GenericKernelMixin, Kernel):
|
||||
"""
|
||||
A minimal (but valid) convolutional kernel for sequences of variable
|
||||
lengths."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _f(self, sample1, sample2):
|
||||
"""
|
||||
kernel value between a pair of sequences
|
||||
"""
|
||||
sample1 = sample1.reshape(-1,3)
|
||||
sample2 = sample2.reshape(-1, 3)
|
||||
dist = pairwise_distances(sample1, sample2)
|
||||
return dist.min(axis=1).mean()
|
||||
|
||||
def __call__(self, X, Y=None, eval_gradient=False):
|
||||
if Y is None:
|
||||
Y = X
|
||||
|
||||
if eval_gradient:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
return np.array([[self._f(x, y) for y in Y] for x in X])
|
||||
|
||||
def diag(self, X):
|
||||
return np.array([self._f(x, x) for x in X])
|
||||
|
||||
def is_stationary(self):
|
||||
return True
|
||||
|
||||
|
||||
def f(X):
|
||||
X = X.reshape(-1,3)
|
||||
return X[:,0]**3 + 2.1*X[:,1]**2 + X[:,0] + 0.1
|
||||
|
||||
|
||||
X_train = [np.random.rand(10*3) for _ in range(11)]
|
||||
y_train = [f(X).mean() for X in X_train]
|
||||
|
||||
X_test = [np.random.rand(10*3) for _ in range(11)]
|
||||
y_test = [f(X).mean() for X in X_test]
|
||||
|
||||
|
||||
#kernel = 1 * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))
|
||||
kernel = MinL2Kernel()
|
||||
gaussian_process = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)
|
||||
gaussian_process.fit(X_train, y_train)
|
||||
|
||||
print(gaussian_process.kernel_)
|
||||
|
||||
y_pred = gaussian_process.predict(X_test)
|
||||
|
||||
mse = np.mean((y_test - y_pred)**2)
|
||||
|
||||
print(mse)
|
Loading…
Reference in New Issue