import os
import pickle
from pathlib import Path

from sklearn.linear_model import LogisticRegression

from quapy.method.aggregative import PACC, EMQ, KDEyML

"""
Ideas:
Try kernel based on feature covariance matrix, with dot product and with another kernel
Try Cauchy-Schwarz kernel

"""

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, GenericKernelMixin, Kernel
from sklearn.metrics.pairwise import pairwise_distances, pairwise_kernels

from quapy.data import LabelledCollection
from quapy.protocol import UPP
from quapy.method.base import BaseQuantifier
import quapy.functional as F
from result_table.src.table import Table

np.random.seed(0)


class FeatCovKernel(GenericKernelMixin, Kernel):
    """
    Kernel between samples (sets of instances): compares the feature-feature
    correlation-distance matrices of the two samples, and maps the Frobenius
    norm of their difference through exp(-d) so that samples with identical
    feature structure yield similarity 1.
    """

    def __init__(self, dimensions):
        self.dimensions = dimensions

    def _f(self, sample1, sample2):
        """
        kernel value between a pair of samples (each flattened to 1D)
        """
        sample1 = sample1.reshape(-1, self.dimensions)
        sample2 = sample2.reshape(-1, self.dimensions)
        # feature-by-feature correlation-distance matrices (dimensions x dimensions)
        featCov1 = pairwise_distances(sample1.T, metric='correlation')
        featCov2 = pairwise_distances(sample2.T, metric='correlation')
        featDiffNorm = np.linalg.norm(featCov1 - featCov2)
        simil = np.exp(-featDiffNorm)
        return simil

    def __call__(self, X, Y=None, eval_gradient=False):
        if Y is None:
            Y = X

        if eval_gradient:
            raise NotImplementedError()
        else:
            return np.array([[self._f(x, y) for y in Y] for x in X])

    def diag(self, X):
        return np.array([self._f(x, x) for x in X])

    def is_stationary(self):
        return True

class AveL2Kernel(GenericKernelMixin, Kernel):
    """
    Kernel between samples of (possibly) different sizes: computes the average
    pairwise L2 distance between the instances of the two samples and maps it
    through exp(-d) to obtain a similarity.
    """

    def __init__(self, dimensions):
        self.dimensions = dimensions

    def _f(self, sample1, sample2):
        """
        kernel value between a pair of samples (each flattened to 1D)
        """
        sample1 = sample1.reshape(-1, self.dimensions)
        sample2 = sample2.reshape(-1, self.dimensions)
        dist = pairwise_distances(sample1, sample2)
        mean_dist = dist.mean()
        closeness = np.exp(-mean_dist)
        return closeness

    def __call__(self, X, Y=None, eval_gradient=False):
        if Y is None:
            Y = X

        if eval_gradient:
            raise NotImplementedError()
        else:
            return np.array([[self._f(x, y) for y in Y] for x in X])

    def diag(self, X):
        return np.array([self._f(x, x) for x in X])

    def is_stationary(self):
        return True
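

# The Cauchy-Schwarz kernel mentioned in the ideas at the top of this file.
# This is only a sketch under the assumption that the intended construction is
# exp(-D_CS), where D_CS is the Cauchy-Schwarz divergence between Gaussian KDEs
# fitted on the two samples; with Gaussian kernels, D_CS can be estimated in
# closed form from pairwise RBF sums. The gamma parameter is a hypothetical,
# untuned default.
class CauchySchwarzKernel(GenericKernelMixin, Kernel):
    """
    Kernel between samples: exp(-D_CS) between Gaussian KDEs (sketch).
    """

    def __init__(self, dimensions, gamma=1.0):
        self.dimensions = dimensions
        self.gamma = gamma

    def _f(self, sample1, sample2):
        X = sample1.reshape(-1, self.dimensions)
        Y = sample2.reshape(-1, self.dimensions)
        # D_CS(p,q) = -log( <p,q> / sqrt(<p,p><q,q>) ); the inner products between
        # the KDEs reduce to means of pairwise RBF evaluations (the normalization
        # constants cancel in the ratio)
        cross = pairwise_kernels(X, Y, metric='rbf', gamma=self.gamma).mean()
        self_x = pairwise_kernels(X, X, metric='rbf', gamma=self.gamma).mean()
        self_y = pairwise_kernels(Y, Y, metric='rbf', gamma=self.gamma).mean()
        div = -np.log(cross / np.sqrt(self_x * self_y))
        return np.exp(-div)

    def __call__(self, X, Y=None, eval_gradient=False):
        if Y is None:
            Y = X
        if eval_gradient:
            raise NotImplementedError()
        return np.array([[self._f(x, y) for y in Y] for x in X])

    def diag(self, X):
        return np.array([self._f(x, x) for x in X])

    def is_stationary(self):
        return True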


class RJSDkernel(GenericKernelMixin, Kernel):
    """
    Kernel between samples based on the representation Jensen-Shannon
    divergence (RJSD): similarity is exp(-RJSDk(sample1, sample2)).
    """

    def __init__(self):
        pass

    def _f(self, sample1, sample2):
        """
        kernel value between a pair of samples (each flattened to 1D)
        """
        div = RJSDk(sample1, sample2)
        closeness = np.exp(-div)
        return closeness

    def __call__(self, X, Y=None, eval_gradient=False):
        if Y is None:
            Y = X

        if eval_gradient:
            raise NotImplementedError()
        else:
            return np.array([[self._f(x, y) for y in Y] for x in X])

    def diag(self, X):
        return np.array([self._f(x, x) for x in X])

    def is_stationary(self):
        return True


def RJSDk(sample_1, sample_2):
    # representation Jensen-Shannon divergence: the von Neumann entropy of the
    # joint kernel matrix minus the weighted entropies of the per-sample blocks.
    # Note: the dimensionality is hard-coded to 3 here.
    sample_1 = sample_1.reshape(-1, 3)
    sample_2 = sample_2.reshape(-1, 3)
    n1 = sample_1.shape[0]
    n2 = sample_2.shape[0]
    pi1 = n1 / (n1 + n2)
    pi2 = n2 / (n1 + n2)
    Z = np.concatenate([sample_1, sample_2])
    Kz = pairwise_kernels(Z, metric='rbf', n_jobs=-1)
    # Kz = pairwise_kernels(Z, metric='cosine', n_jobs=-1)
    Kx = Kz[:n1, :n1]
    Ky = Kz[n1:, n1:]

    SKz = S(Kz)
    SKx = S(Kx)
    SKy = S(Ky)

    return SKz - (pi1 * SKx + pi2 * SKy)

def S(K):
    # von Neumann entropy of the trace-normalized kernel matrix:
    # S(K) = -tr(K log K) = -sum_i lambda_i * log(lambda_i).
    # This must be computed from the eigenvalues: np.log is elementwise, so
    # K @ np.log(K) would not give the matrix logarithm.
    K = K / np.trace(K)
    eigval = np.linalg.eigvalsh(K)
    eigval = eigval[eigval > 0]  # 0*log(0) is taken as 0
    return -np.sum(eigval * np.log(eigval))


def target_function(X):
    # synthetic regression target used in the (commented-out) sandbox below
    X = X.reshape(-1, 3)
    return X[:, 0]**3 + 2.1 * X[:, 1]**2 + X[:, 0] + 0.1


# X = np.random.rand(14,3)
# X /= X.sum(axis=1, keepdims=True)
# Y = np.random.rand(10,3)
# Y /= Y.sum(axis=1, keepdims=True)
#
# X = X.flatten()
# Y = Y.flatten()
#
# d = RJSDk(X, Y)
#
# print(d)
#
# d = RJSDk(X, X)
#
# print(d)
#
# import sys ; sys.exit(0)

# X_train = [np.random.rand(10*3) for _ in range(50)]
# y_train = [target_function(X).mean() for X in X_train]
#
# X_test = [np.random.rand(10*3) for _ in range(20)]
# y_test = [target_function(X).mean() for X in X_test]
#
#
# print('fit')
# # kernel = 1 * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))
# kernel = AveL2Kernel(dimensions=3)
# # kernel = RJSDkernel()
# gaussian_process = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)
# gaussian_process.fit(X_train, y_train)
# print('[done]')
#
# print(gaussian_process.kernel_)
#
# y_pred = gaussian_process.predict(X_test)
#
# mse = np.mean((y_test - y_pred)**2)
#
# print(mse)

class GPQuantifier(BaseQuantifier):
    """
    Binary quantifier that fits a Gaussian Process regressor mapping (flattened)
    samples to the prevalence of the positive class, using training samples
    drawn with the UPP protocol from the training collection.
    """

    def __init__(self, dimensions, kernel, num_tr_samples=20, size_tr_samples=50):
        self.dimensions = dimensions
        self.num_tr_samples = num_tr_samples
        self.size_tr_samples = size_tr_samples
        self.gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)

    def fit(self, data: LabelledCollection):
        sampler = UPP(data, sample_size=self.size_tr_samples, repeats=self.num_tr_samples)
        Xs, ps = list(zip(*[(X, p) for X, p in sampler()]))
        ps = [p[1] for p in ps]  # keep only the positive-class prevalence
        Xs = [X.flatten() for X in Xs]  # the GP works on flattened samples
        self.gp.fit(Xs, ps)
        return self

    def quantify(self, instances):
        X = [instances.flatten()]
        p = self.gp.predict(X)[0]
        return F.as_binary_prevalence(p, clip_if_necessary=True)
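
# Minimal usage sketch for GPQuantifier on synthetic data (the names Xsyn and
# ysyn are illustrative only); the actual UCI experiments follow below.
# Xsyn = np.random.rand(500, 3)
# ysyn = (Xsyn.sum(axis=1) > 1.5).astype(int)
# train_syn = LabelledCollection(Xsyn, ysyn)
# gpq = GPQuantifier(dimensions=3, kernel=AveL2Kernel(dimensions=3)).fit(train_syn)
# print(gpq.quantify(np.random.rand(100, 3)))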

import quapy as qp

from quapy.data.datasets import fetch_UCIBinaryDataset, UCI_BINARY_DATASETS

table = Table('avel2')
methodnames = ['AveL2','PACC', 'SLD', 'KDEyML']

for methodname in methodnames:
    errors = []
    for dataset_name in UCI_BINARY_DATASETS:
        if dataset_name in ['balance.2']:
            continue

        result_path = f'./results_gp/{dataset_name}_{methodname}.pkl'
        os.makedirs(Path(result_path).parent, exist_ok=True)
        if os.path.exists(result_path):
            with open(result_path, 'rb') as f:
                aes = pickle.load(f)
        else:
            dataset = fetch_UCIBinaryDataset(dataset_name)
            qp.data.preprocessing.standardize(dataset, inplace=True)
            train, test = dataset.train_test
            d = train.X.shape[1]
            if methodname=='AveL2':
                q = GPQuantifier(dimensions=d, kernel=AveL2Kernel(dimensions=d), num_tr_samples=150, size_tr_samples=100)
            elif methodname=='PACC':
                q = PACC(LogisticRegression())
            elif methodname=='SLD':
                q = EMQ(LogisticRegression())
            elif methodname=='KDEyML':
                q = KDEyML(LogisticRegression(), bandwidth=0.05)
            else:
                raise ValueError('unknown method: ' + methodname)
            q.fit(train)
            aes = qp.evaluation.evaluate(q, UPP(test, sample_size=100), error_metric='ae', verbose=False)
            with open(result_path, 'wb') as f:
                pickle.dump(aes, f, pickle.HIGHEST_PROTOCOL)

        mae = np.mean(aes)
        print(f'{dataset_name}\t{mae:.4f}')

        errors.append(mae)
        table.add(dataset_name, methodname, aes)

    # errors is reset per method, so report the mean inside the method loop
    print(f'\n{methodname} mean={np.mean(errors):.5f}')
table.format.show_std = False
table.format.mean_prec = 4
table.LatexPDF('./table_gp/gp.pdf', tables=[table], resizebox=False)