from typing import Callable

import numpy as np
from scipy import optimize
from sklearn.base import BaseEstimator
from tqdm import tqdm

import quapy as qp
import quapy.functional as F
from quapy.data import LabelledCollection
from quapy.method._kdey import KDEBase
from quapy.method.aggregative import KDEyML
from quapy.protocol import UPP

epsilon = 1e-10  # small constant added to the likelihood to avoid computing log(0)

class KDEyMLauto(KDEyML):
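    """
    KDEyML variant that selects the kernel bandwidth transductively, i.e., at `aggregate` time, jointly with
    (or alternating with) the test prevalence. The `optim` parameter selects the strategy:

    - 'two_steps': alternates between optimizing the prevalence (bandwidth fixed) and the bandwidth
      (prevalence fixed) until both converge
    - 'both': optimizes the prevalence and a single shared bandwidth jointly
    - 'both_fine': optimizes the prevalence and one bandwidth per class jointly
    - 'max_likelihood': selects the bandwidth maximizing the test likelihood via grid search
    - 'max_likelihood2': selects the bandwidth maximizing the test likelihood via SLSQP search
    """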
    def __init__(self, classifier: BaseEstimator = None, val_split=5, random_state=None, optim='two_steps'):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
        self.bandwidth = None
        self.random_state = random_state
        self.optim = optim

    def choose_bandwidth(self, train, test_instances):
        classif_predictions = self.classifier_fit_predict(train, fit_classifier=True, predict_on=self.val_split)
        te_posteriors = self.classify(test_instances)
        return self.transduce(classif_predictions, te_posteriors)

    def transduce(self, classif_predictions, te_posteriors):
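        """
        Estimates the test prevalence directly from the classifier predictions on training data and the test
        posteriors, choosing the bandwidth according to `self.optim` (see the class docstring).
        """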
        tr_posteriors, tr_y = classif_predictions.Xy
        classes = classif_predictions.classes_
        n_classes = len(classes)

        current_bandwidth = 0.05
        if self.optim == 'both_fine':
            current_bandwidth = np.full(fill_value=current_bandwidth, shape=(n_classes,))
        current_prevalence = np.full(fill_value=1 / n_classes, shape=(n_classes,))

        if self.optim == 'max_likelihood':
            current_prevalence, current_bandwidth = self.optim_minimize_like(tr_posteriors, tr_y, te_posteriors, classes, grid=True)
        elif self.optim == 'max_likelihood2':
            current_prevalence, current_bandwidth = self.optim_minimize_like(tr_posteriors, tr_y, te_posteriors, classes, grid=False)
        else:
            iterations = 0
            convergence = False
            with qp.util.temp_seed(self.random_state):

                while not convergence:
                    previous_bandwidth = current_bandwidth
                    previous_prevalence = current_prevalence

                    iterations += 1
                    print(f'{iterations}:')

                    if self.optim == 'two_steps':
                        current_prevalence = self.optim_minimize_prevalence(current_bandwidth, current_prevalence, tr_posteriors, tr_y, te_posteriors, classes)
                        print(f'\testim-prev={F.strprev(current_prevalence)}')
                        current_bandwidth = self.optim_minimize_bandwidth(current_bandwidth, current_prevalence, tr_posteriors, tr_y, te_posteriors, classes)
                        print(f'\tbandwidth={current_bandwidth}')
                    elif self.optim == 'both':
                        current_prevalence, current_bandwidth = self.optim_minimize_both(current_bandwidth, current_prevalence, tr_posteriors, tr_y, te_posteriors, classes)
                    elif self.optim == 'both_fine':
                        current_prevalence, current_bandwidth = self.optim_minimize_both_fine(current_bandwidth, current_prevalence, tr_posteriors, tr_y, te_posteriors, classes)

                    # check convergence
                    prev_convergence = all(np.isclose(previous_prevalence, current_prevalence, atol=0.0001))
                    if isinstance(current_bandwidth, float):
                        band_convergence = np.isclose(previous_bandwidth, current_bandwidth, atol=0.0001)
                    else:
                        band_convergence = all(np.isclose(previous_bandwidth, current_bandwidth, atol=0.0001))

                    convergence = band_convergence and prev_convergence

        self.bandwidth = current_bandwidth
        print('bandwidth=', current_bandwidth)
        print('prevalence=', current_prevalence)
        return current_prevalence

    def optim_minimize_prevalence(self, current_bandwidth, current_prev, tr_posteriors, tr_y, te_posteriors, classes):
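        """Minimizes the negative log-likelihood over the prevalence vector, keeping the bandwidth fixed."""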
        mix_densities = self.get_mixture_components(tr_posteriors, tr_y, classes, current_bandwidth)
        test_densities = [self.pdf(kde_i, te_posteriors) for kde_i in mix_densities]

        def neg_loglikelihood_prev(prev):
            test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
            test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
            return -np.sum(test_loglikelihood)

        return optim_minimize(neg_loglikelihood_prev, current_prev)

    def optim_minimize_bandwidth(self, current_bandwidth, current_prev, tr_posteriors, tr_y, te_posteriors, classes):
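        """Minimizes the negative log-likelihood over a single shared bandwidth, keeping the prevalence fixed."""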
        def neg_loglikelihood_bandwidth(bandwidth):
            mix_densities = self.get_mixture_components(tr_posteriors, tr_y, classes, bandwidth[0])
            test_densities = [self.pdf(kde_i, te_posteriors) for kde_i in mix_densities]
            test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(current_prev, test_densities))
            test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
            return -np.sum(test_loglikelihood)

        bounds = [(0.00001, 1)]
        r = optimize.minimize(neg_loglikelihood_bandwidth, x0=[current_bandwidth], method='SLSQP', bounds=bounds)
        print(f'iterations-bandwidth={r.nit}')
        return r.x[0]

    def optim_minimize_both(self, current_bandwidth, current_prev, tr_posteriors, tr_y, te_posteriors, classes):
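        """Minimizes the negative log-likelihood jointly over the prevalence vector and a single shared bandwidth."""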
        n_classes = len(current_prev)
        def neg_loglikelihood_bandwidth(prevalence_bandwidth):
            bandwidth = prevalence_bandwidth[-1]
            prevalence = prevalence_bandwidth[:-1]
            mix_densities = self.get_mixture_components(tr_posteriors, tr_y, classes, bandwidth)
            test_densities = [self.pdf(kde_i, te_posteriors) for kde_i in mix_densities]
            test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prevalence, test_densities))
            test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
            return -np.sum(test_loglikelihood)

        bounds = [(0, 1) for _ in range(n_classes)] + [(0.00001, 1)]
        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x[:n_classes])})
        prevalence_bandwidth = np.append(current_prev, current_bandwidth)
        r = optimize.minimize(neg_loglikelihood_bandwidth, x0=prevalence_bandwidth, method='SLSQP', bounds=bounds, constraints=constraints)
        print(f'iterations-both={r.nit}')
        prev_band = r.x
        current_prevalence = prev_band[:-1]
        current_bandwidth = prev_band[-1]
        return current_prevalence, current_bandwidth

    def optim_minimize_both_fine(self, current_bandwidth, current_prev, tr_posteriors, tr_y, te_posteriors, classes):
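        """Minimizes the negative log-likelihood jointly over the prevalence vector and one bandwidth per class."""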
        n_classes = len(current_prev)
        def neg_loglikelihood_bandwidth(prevalence_bandwidth):
            prevalence = prevalence_bandwidth[:n_classes]
            bandwidth = prevalence_bandwidth[n_classes:]
            mix_densities = self.get_mixture_components(tr_posteriors, tr_y, classes, bandwidth)
            test_densities = [self.pdf(kde_i, te_posteriors) for kde_i in mix_densities]
            test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prevalence, test_densities))
            test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
            return -np.sum(test_loglikelihood)

        bounds = [(0, 1) for _ in range(n_classes)] + [(0.00001, 1) for _ in range(n_classes)]
        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x[:n_classes])})
        prevalence_bandwidth = np.concatenate((current_prev, current_bandwidth))
        r = optimize.minimize(neg_loglikelihood_bandwidth, x0=prevalence_bandwidth, method='SLSQP', bounds=bounds, constraints=constraints)
        print(f'iterations-both-fine={r.nit}')
        prev_band = r.x
        current_prevalence = prev_band[:n_classes]
        current_bandwidth = prev_band[n_classes:]
        return current_prevalence, current_bandwidth

    def optim_minimize_like(self, tr_posteriors, tr_y, te_posteriors, classes, reduction=100, grid=True):
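        """
        Selects the bandwidth maximizing the test likelihood (via grid search if `grid=True`, or via SLSQP
        otherwise) on subsamples of `reduction` examples per class (training) and `reduction` examples (test)
        to speed up the computation; the prevalence is then re-estimated on the full test posteriors.
        """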
        n_classes = len(classes)

        # reduce samples to speed up computation
        posteriors_subsample = LabelledCollection(tr_posteriors, tr_y)
        posteriors_subsample = posteriors_subsample.sampling(reduction*n_classes)
        n_test = te_posteriors.shape[0]
        subsample_index = np.random.choice(n_test, size=min(reduction, n_test), replace=False)  # subsample without repetitions
        te_posterior_subsample = te_posteriors[subsample_index]

        if grid:
            _, best_band = self.choose_bandwidth_maxlikelihood_grid(*posteriors_subsample.Xy, te_posterior_subsample, classes)
        else:
            best_band = self.choose_bandwidth_maxlikelihood_search(*posteriors_subsample.Xy, te_posterior_subsample, classes)

        mix_densities = self.get_mixture_components(tr_posteriors, tr_y, classes, best_band)
        test_densities = [self.pdf(kde_i, te_posteriors) for kde_i in mix_densities]

        def neg_loglikelihood_prev(prev):
            test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
            test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
            return -np.sum(test_loglikelihood)

        init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))
        pred_prev, neglikelihood = optim_minimize(neg_loglikelihood_prev, init_prev, return_loss=True)

        return pred_prev, best_band


    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
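        """
        This quantifier is transductive: fitting only stores the classifier predictions; the actual
        optimization is deferred to `aggregate`, where the test posteriors become available.
        """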
        self.classif_predictions = classif_predictions
        return self

    def aggregate(self, posteriors: np.ndarray):
        return self.transduce(self.classif_predictions, posteriors)

    def choose_bandwidth_maxlikelihood_grid(self, tr_posteriors, tr_y, te_posteriors, classes):
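        """
        Explores a logarithmic grid of bandwidths and returns the (prevalence, bandwidth) pair attaining
        the highest test likelihood.
        """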
        n_classes = len(classes)
        best_band = None
        best_like = None
        best_prev = None
        init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))
        for bandwidth in np.logspace(-4, 0.5, 50):  # 50 values from 1e-4 to 10**0.5 (~3.16)
            mix_densities = self.get_mixture_components(tr_posteriors, tr_y, classes, bandwidth)
            test_densities = [self.pdf(kde_i, te_posteriors) for kde_i in mix_densities]

            def neg_loglikelihood_prev(prev):
                test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
                test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
                return -np.sum(test_loglikelihood)

            pred_prev, neglikelihood = optim_minimize(neg_loglikelihood_prev, init_prev, return_loss=True)

            if best_like is None or neglikelihood < best_like:
                best_like = neglikelihood
                best_band = bandwidth
                best_prev = pred_prev

        print(f'best-like={best_like:.4f}')
        print(f'best-band={best_band:.4f}')
        return best_prev, best_band

    def choose_bandwidth_maxlikelihood_search(self, tr_posteriors, tr_y, te_posteriors, classes):
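        """Searches for the bandwidth attaining the highest test likelihood using scipy's SLSQP routine."""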
        n_classes = len(classes)
        init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))

        def neglikelihood_band(bandwidth):
            mix_densities = self.get_mixture_components(tr_posteriors, tr_y, classes, bandwidth)
            test_densities = [self.pdf(kde_i, te_posteriors) for kde_i in mix_densities]

            def neg_loglikelihood_prev(prev):
                test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
                test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
                return -np.sum(test_loglikelihood)

            pred_prev, neglikelihood = optim_minimize(neg_loglikelihood_prev, init_prev, return_loss=True)

            return neglikelihood

        bounds = [(0.0001, 1)]
        r = optimize.minimize(neglikelihood_band, x0=[0.001], method='SLSQP', bounds=bounds)

        best_band = r.x[0]
        print(f'solved in nit={r.nit}')
        return best_band


def optim_minimize(loss: Callable, init_prev: np.ndarray, return_loss=False):
    """
    Searches for the optimal prevalence values, i.e., an `n_classes`-dimensional vector of the (`n_classes`-1)-simplex
    that yields the smallest loss. This optimization is carried out by means of a constrained search using scipy's
    SLSQP routine.

    :param loss: (callable) the function to minimize
    :param init_prev: (ndarray) the initial prevalence vector from which the search starts
    :param return_loss: (bool) whether to also return the loss attained at the minimizer (default False)
    :return: (ndarray) the best prevalence vector found; if `return_loss=True`, a tuple (prevalence, loss)
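
    Example (illustrative only; a hypothetical quadratic loss whose minimum lies inside the simplex):

    >>> import numpy as np
    >>> target = np.asarray([0.2, 0.3, 0.5])
    >>> prev = optim_minimize(lambda p: np.sum((p - target) ** 2), init_prev=np.full(3, 1 / 3))
    >>> np.round(prev, 2)  # doctest: +SKIP
    array([0.2, 0.3, 0.5])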
    """

    n_classes = len(init_prev)
    # solutions are bounded to those contained in the unit-simplex
    bounds = tuple((0, 1) for _ in range(n_classes))  # values in [0,1]
    constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
    r = optimize.minimize(loss, x0=init_prev, method='SLSQP', bounds=bounds, constraints=constraints, tol=1e-10)
    # print(f'iterations-prevalence={r.nit}')
    if return_loss:
        return r.x, r.fun
    else:
        return r.x



class KDEyMLauto2(KDEyML):

    def __init__(self, classifier: BaseEstimator=None, val_split=5, bandwidth=0.1, random_state=None, reduction=100, max_reduced=500, target='likelihood'):
        """
        reduction: number of examples per class for automatically setting the bandwidth
        """
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
        if bandwidth == 'auto':
            self.bandwidth = bandwidth
        else:
            self.bandwidth = KDEBase._check_bandwidth(bandwidth)
        self.reduction = reduction
        self.max_reduced = max_reduced
        self.random_state = random_state
        assert target in ['likelihood', 'likelihood+'] or target in qp.error.QUANTIFICATION_ERROR_NAMES, 'unknown target for auto'
        self.target = target

    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
        if self.bandwidth == 'auto':
            self.auto_bandwidth_likelihood(classif_predictions)
        else:
            self.bandwidth_ = self.bandwidth
        self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth_)
        return self

    def auto_bandwidth_likelihood(self, classif_predictions: LabelledCollection):
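        """
        Selects the bandwidth on a held-out split of the classifier predictions: the training half fits the
        class-conditional KDEs, while the validation half generates artificial samples (via UPP) on which the
        candidate bandwidths are scored, either by the attained likelihood or by a quantification error.
        """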
        n_classes = classif_predictions.n_classes

        train, val = classif_predictions.split_stratified(train_prop=0.5, random_state=self.random_state)

        if self.reduction is not None:
            # reduce samples to speed up computation
            tr_length = min(self.reduction * n_classes, self.max_reduced)
            if len(train) > tr_length:
                train = train.sampling(tr_length)

        init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))
        repeats = 25
        prot = UPP(val, sample_size=self.reduction, repeats=repeats, random_state=self.random_state)

        if self.target == 'likelihood+':
            def neg_loglikelihood_band_(bandwidth):
                # the optimizer passes a 1-dim array; KernelDensity expects a scalar bandwidth
                mix_densities = self.get_mixture_components(*train.Xy, train.classes_, bandwidth[0])
                loss_accum = 0

                for (sample, prev) in tqdm(prot(), total=repeats):
                    test_densities = [self.pdf(kde_i, sample) for kde_i in mix_densities]

                    def neg_loglikelihood_prev_(prev_hat):
                        test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev_hat, test_densities))
                        test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
                        return -np.sum(test_loglikelihood)

                    pred_prev, loss_val = optim_minimize(neg_loglikelihood_prev_, init_prev, return_loss=True)
                    loss_accum += loss_val

                return loss_accum

            # upper bound matches the top of the grid explored in the 'likelihood' branch (0.2);
            # the original np.log10(0.2) is negative and would make the bounds interval empty
            bounds = [(0.0001, 0.2)]
            init_bandwidth = 0.1
            r = optimize.minimize(neg_loglikelihood_band_, x0=[init_bandwidth], method='SLSQP', bounds=bounds)
            best_band = r.x[0]

        else:
            best_band = None
            best_loss_val = None
            init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))
            for bandwidth in np.logspace(-4, np.log10(0.2), 20):
                mix_densities = self.get_mixture_components(*train.Xy, train.classes_, bandwidth)

                loss_accum = 0
                for (sample, prev) in tqdm(prot(), total=repeats):
                    test_densities = [self.pdf(kde_i, sample) for kde_i in mix_densities]

                    def neg_loglikelihood_prev_(prev_hat):
                        test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev_hat, test_densities))
                        test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
                        return -np.sum(test_loglikelihood)

                    if self.target == 'likelihood':
                        loss_fn = neg_loglikelihood_prev_
                    else:
                        loss_fn = lambda prev_hat: qp.error.from_name(self.target)(prev, prev_hat)

                    pred_prev, loss_val = optim_minimize(loss_fn, init_prev, return_loss=True)
                    loss_accum += loss_val

                if best_loss_val is None or loss_accum < best_loss_val:
                    best_loss_val = loss_accum
                    best_band = bandwidth

        print(f'found bandwidth={best_band:.4f}')
        self.bandwidth_ = best_band
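

if __name__ == '__main__':
    # Minimal smoke test, illustrative only (synthetic data and an arbitrary classifier choice,
    # not part of the method itself): fits KDEyMLauto2 with automatic bandwidth selection and
    # compares the estimated test prevalence against the true one.
    from sklearn.linear_model import LogisticRegression

    rng = np.random.RandomState(0)
    X = rng.randn(2000, 2)
    y = (X[:, 0] + X[:, 1] + 0.5 * rng.randn(2000) > 0).astype(int)
    data = LabelledCollection(X, y)
    train, test = data.split_stratified(train_prop=0.7, random_state=0)

    quantifier = KDEyMLauto2(classifier=LogisticRegression(), bandwidth='auto', random_state=0)
    quantifier.fit(train)
    estim_prev = quantifier.quantify(test.X)
    print(f'true prevalence:      {F.strprev(test.prevalence())}')
    print(f'estimated prevalence: {F.strprev(estim_prev)}')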