from typing import Union

import numpy as np
from scipy.spatial.distance import cdist
from sklearn import clone
from sklearn.linear_model import LogisticRegression

from quapy.data import LabelledCollection
from quapy.method.aggregative import PACC, _training_helper, PCC
from quapy.method.base import BaseQuantifier

from sklearn.preprocessing import normalize

# ideas: the observation shows that having a validation set drawn from the target distribution "repairs"
# the predictions of the classifier. This might sound like a triviality, but note that the classifier is trained on
# another distribution. So one could look at the test set (w/o labels) and extract, from the entire labelled
# collection, a portion that matches the test set well, keeping the remainder as the training set on which to train
# the classifier. (The version implemented so far follows a different heuristic, based on having a validation split
# which is iid wrt the training set, and using this validation split to extract another validation split closer to
# the test distribution; a sketch of the first idea follows.)
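
# A minimal sketch of the first idea above (assumed, illustrative only; not used by the classes below):
# keep as validation the fraction of labelled datapoints closest to the test set, and train on the rest.
def split_by_test_proximity(data: LabelledCollection, test_instances, val_prop=0.4, metric='euclidean'):
    """
    Splits `data` into (training, validation) so that the validation part consists of the `val_prop`
    fraction of datapoints closest to the test set (closeness measured as the minimum distance to any
    test point). `val_prop` and `metric` are illustrative defaults.
    """
    dist = cdist(data.instances, test_instances, metric=metric).min(axis=1)
    order = np.argsort(dist)  # ascending: closest to the test set first
    n_val = int(len(data) * val_prop)
    validation = data.sampling_from_index(order[:n_val])
    training = data.sampling_from_index(order[n_val:])
    return training, validation
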

# note: the T3 variant (the iterative one) admits two sub-variants: (i) the estimated test prevalence is used to
# draw, via artificial sampling, a sample from the validation pool that reflects the desired prevalence; (ii) the
# estimated test prevalence is used to compute weights that compensate (i.e., rebalance) the relative importance of
# the current samples wrt the believed prevalence. Both are implemented, but the active one is (ii); (i) is
# commented out (a sketch follows).
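
# A minimal sketch of sub-variant (i) (assumed, illustrative only; it mirrors the lines commented out in
# TransductiveInvdistanceIterativePACC.quantify): resample the validation pool, via quapy's
# prevalence-conditioned sampling, so that it reflects the estimated test prevalence.
def sample_validation_at_prevalence(validation_pool: LabelledCollection, test_prev_estim):
    size = len(validation_pool)
    # sampling(size, *prevs) expects the prevalence values of all classes except the last one
    return validation_pool.sampling(size, *test_prev_estim[:-1])
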


class TransductivePACC(BaseQuantifier):
    """
    PACC works by adjusting the PCC estimate through a linear correction. This correction assumes P(X|Y) is fixed
    between the training and test distributions, meaning that the misclassification rates estimated on the training
    distribution (e.g., by means of a train/val split, or by means of k-FCV) are representative of the
    misclassification rates in the test. In situations in which the training and test distributions are shifted, and
    in which P(X|Y) cannot be assumed to remain constant (e.g., in contexts of covariate shift), this adjustment
    can be arbitrarily harmful. Transductive quantifiers decide the correction as a function of the test set.
    TransductivePACC in particular implements this intuition by picking a validation subset from the training set
    such that it is close to the test set. In this preliminary example, we simply rely on distances for choosing
    points close to every test point. The misclassification rates are estimated on this "transductive" validation
    split.

    :param learner: a probabilistic classifier (e.g., a scikit-learn estimator implementing predict_proba)
    :param how_many: how many validation points to select for every test point
    :param metric: distance metric (see `scipy.spatial.distance.cdist`)
    """

    def __init__(self, learner, how_many=1, metric='euclidean'):
        self.learner = learner
        self.how_many = how_many
        self.metric = metric

    def quantify(self, instances):
        validation_index = self.get_closer_val_instances(instances, how_many=self.how_many, metric=self.metric)
        validation_selected = self.validation_pool.sampling_from_index(validation_index)
        pacc = PACC(self.learner, val_split=validation_selected)
        pacc.fit(None, fit_learner=False)
        self.to_show_val_selected = validation_selected  # todo: remove
        return pacc.quantify(instances)

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = 0.4):
        if isinstance(val_split, float):
            self.training, self.validation_pool = data.split_stratified(1-val_split)
        elif isinstance(val_split, LabelledCollection):
            self.training = data
            self.validation_pool = val_split
        else:
            raise ValueError('val_split data type not understood')
        self.learner, _ = _training_helper(self.learner, self.training, fit_learner=fit_learner, ensure_probabilistic=True)
        return self

    def get_closer_val_instances(self, T, how_many=1, metric='euclidean'):
        """
        Takes the "how_many" instances (indices) from the validation pool that are the closest to every instance in T

        :param T: test instances
        :param how_many: how many samples to choose for every test datapoint
        :param metric: distance metric (see `scipy.spatial.distance.cdist`)
        :return: ndarray with indices of validation_pool's datapoints
        """
        dist = cdist(T, self.validation_pool.instances, metric=metric)
        indexes = np.argsort(dist, axis=1)[:, :how_many].flatten()
        return indexes


class TransductiveInvdistancePACC(BaseQuantifier):
    """
    This is a modification of TransductivePACC. The idea is that, instead of choosing the closest validation points,
    we could select all validation points, weighted inversely proportional to their distance to the test points.
    The main objective here is to repair the performance of the transductive quantifier in cases of PPS.

    :param learner: a probabilistic classifier (e.g., a scikit-learn estimator implementing predict_proba)
    :param metric: distance metric (see `scipy.spatial.distance.cdist`)
    """

    def __init__(self, learner, metric='euclidean'):
        self.learner = learner
        self.metric = metric

    def quantify(self, instances):
        validation_similarities = self.get_val_similarities(instances, metric=self.metric)
        validation_weight = validation_similarities.sum(axis=0)
        validation_posteriors = self.learner.predict_proba(self.validation_pool.instances)
        positive_posteriors = validation_posteriors[self.validation_pool.labels == 1][:,1]
        negative_posteriors = validation_posteriors[self.validation_pool.labels == 0][:,1]
        positive_weights = validation_weight[self.validation_pool.labels == 1]
        negative_weights = validation_weight[self.validation_pool.labels == 0]

        soft_tpr = (positive_posteriors*positive_weights).sum()/(positive_weights.sum())
        soft_fpr = (negative_posteriors*negative_weights).sum()/(negative_weights.sum())

        pcc = PCC(learner=self.learner).quantify(instances)
        adjusted = (pcc[1] - soft_fpr)/(soft_tpr-soft_fpr)
        adjusted = np.clip(adjusted, 0, 1)
        return np.asarray([1-adjusted,adjusted])

    def set_params(self, **parameters):
        pass

    def get_params(self, deep=True):
        pass

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = 0.4):
        if isinstance(val_split, float):
            self.training, self.validation_pool = data.split_stratified(1-val_split)
        elif isinstance(val_split, LabelledCollection):
            self.training = data
            self.validation_pool = val_split
        else:
            raise ValueError('val_split data type not understood')
        self.learner, _ = _training_helper(self.learner, self.training, fit_learner=fit_learner, ensure_probabilistic=True)
        return self

    def get_val_similarities(self, T, metric='euclidean'):
        """
        Takes "how_many" instances (indices) from X that are the closes to every instance in T
        :param T: test instances
        :param metric: similarity function (see `scipy.spatial.distance.cdist`)
        :return: ndarray with indices of validation_pool's datapoints
        """
        # dist = cdist(T, self.validation_pool.instances, metric=metric)
        # norm_dist = (dist/np.max(dist))
        # sim = 1 - norm_dist  # other variants: divide by the max distance for each test point, and not overall distance
        # norm_sim = normalize(sim**2, norm='l1') # <-- this kind of helps
        # return norm_sim

        dist = cdist(T, self.validation_pool.instances, metric=metric)
        # dist = dist**4 # <--
        norm_dist = (dist / np.max(dist))
        sim = 1 - norm_dist  # other variants: divide by the max distance for each test point, rather than by the overall max
        norm_sim = normalize(sim**4, norm='l1')  # <-- this kind of helps a lot; unclear why
        return norm_sim

        # this doesn't work at all (don't know why)
        # cut_dist = np.median(dist)/3
        # dist[dist>cut_dist]=cut_dist
        # norm_dist = (dist / cut_dist)
        # sim = 1 - norm_dist  # other variants: divide by the max distance for each test point, and not overall distance
        # norm_sim = normalize(sim, norm='l1')
        # return norm_sim


class TransductiveInvdistanceIterativePACC(BaseQuantifier):
    """
    This is a modification of TransductiveInvdistancePACC.
    The idea is to also factor into the weights the importance ratio prev_test / prev_train (where prev_test has to
    be estimated by means of an auxiliary quantifier).

    :param learner: a probabilistic classifier (e.g., a scikit-learn estimator implementing predict_proba)
    :param metric: distance metric (see `scipy.spatial.distance.cdist`)
    :param oracle_test_prev: if provided, the true test prevalence to be used in place of the auxiliary estimate
    """

    def __init__(self, learner, metric='euclidean', oracle_test_prev=None):
        self.learner = learner
        self.metric = metric
        self.oracle_test_prev = oracle_test_prev

    def quantify(self, instances):

        if self.oracle_test_prev is None:
            proxy = TransductiveInvdistancePACC(learner=clone(self.learner)).fit(self.training, val_split=self.validation_pool)
            test_prev = proxy.quantify(instances)
            #print(f'\ttest_prev_estimated={F.strprev(test_prev)}')
        else:
            test_prev = self.oracle_test_prev

        #size = len(self.validation_pool)
        #validation = self.validation_pool.sampling(size, *test_prev[:-1])
        validation = self.validation_pool

        validation_similarities = self.get_val_similarities(instances, validation, metric=self.metric, test_prev_estim=test_prev)
        validation_weight = validation_similarities.sum(axis=0)
        validation_posteriors = self.learner.predict_proba(validation.instances)
        positive_posteriors = validation_posteriors[validation.labels == 1][:,1]
        negative_posteriors = validation_posteriors[validation.labels == 0][:,1]
        positive_weights = validation_weight[validation.labels == 1]
        negative_weights = validation_weight[validation.labels == 0]

        soft_tpr = (positive_posteriors*positive_weights).sum()/(positive_weights.sum())
        soft_fpr = (negative_posteriors*negative_weights).sum()/(negative_weights.sum())

        pcc = PCC(learner=self.learner).quantify(instances)
        adjusted = (pcc[1] - soft_fpr)/(soft_tpr-soft_fpr)
        adjusted = np.clip(adjusted, 0, 1)
        return np.asarray([1-adjusted, adjusted])

    def set_params(self, **parameters):
        pass

    def get_params(self, deep=True):
        pass

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = 0.4):
        if isinstance(val_split, float):
            self.training, self.validation_pool = data.split_stratified(1-val_split)
        elif isinstance(val_split, LabelledCollection):
            self.training = data
            self.validation_pool = val_split
        else:
            raise ValueError('val_split data type not understood')
        self.learner, _ = _training_helper(self.learner, self.training, fit_learner=fit_learner, ensure_probabilistic=True)
        return self

    def get_val_similarities(self, T, validation, metric='euclidean', test_prev_estim=None):
        """
        Takes "how_many" instances (indices) from X that are the closes to every instance in T
        :param T: test instances
        :param metric: similarity function (see `scipy.spatial.distance.cdist`)
        :return: ndarray with indices of validation_pool's datapoints
        """

        dist = cdist(T, validation.instances, metric=metric)
        # dist = dist**4 # <--
        norm_dist = (dist / np.max(dist))
        sim = 1 - norm_dist  # other variants: divide by the max distance for each test point, rather than by the overall max
        norm_sim = normalize(sim ** 4, norm='l1')  # <-- this kind of helps a lot; unclear why

        if test_prev_estim is not None:
            pos_reweight = test_prev_estim[1] / validation.prevalence()[1]
            neg_reweight = test_prev_estim[0] / validation.prevalence()[0]

            total_reweight = pos_reweight + neg_reweight
            pos_reweight /= total_reweight
            neg_reweight /= total_reweight

            rebalance_weight = np.zeros(len(validation))
            rebalance_weight[validation.labels == 1] = pos_reweight
            rebalance_weight[validation.labels == 0] = neg_reweight

            rebalance_weight /= rebalance_weight.sum()

            # norm_sim = normalize(sim, norm='l1')
            norm_sim *= rebalance_weight
            norm_sim = normalize(norm_sim**3, norm='l1')

        return norm_sim

        # norm_sim = normalize(sim, norm='l1')  # <-- this kind of helps a lot; unclear why
        # norm_sim = normalize(norm_sim**2, norm='l1')  # <-- this kind of helps a lot; unclear why
        #return norm_sim


def plot_samples(val_orig:LabelledCollection, val_sel:LabelledCollection, test):
    import matplotlib.pyplot as plt
    import matplotlib

    font = {'family': 'sans-serif',
            'weight': 'bold',
            'size': 10}
    matplotlib.rc('font', **font)
    size=0.5
    alpha=0.25

    # plot 1:
    instances, labels = val_orig.Xy
    x1 = instances[:,0]
    x2 = instances[:,1]

    # plt.ion()
    # plt.show()

    plt.subplot(1, 3, 1)
    plt.scatter(x1[labels==0], x2[labels==0], s=size, alpha=alpha)
    plt.scatter(x1[labels==1], x2[labels==1], s=size, alpha=alpha)
    plt.title('Validation Pool')

    # plot 2:
    instances, labels = val_sel.Xy
    x1 = instances[:, 0]
    x2 = instances[:, 1]

    plt.subplot(1, 3, 2)
    plt.scatter(x1[labels == 0], x2[labels == 0], s=size, alpha=alpha)
    plt.scatter(x1[labels == 1], x2[labels == 1], s=size, alpha=alpha)
    plt.title('Validation Chosen')

    # plot 3:
    instances, labels = test.Xy
    x1 = instances[:, 0]
    x2 = instances[:, 1]

    plt.subplot(1, 3, 3)
    # plt.scatter(x1, x2, s=size, alpha=alpha)
    plt.scatter(x1[labels == 0], x2[labels == 0], s=size, alpha=alpha)
    plt.scatter(x1[labels == 1], x2[labels == 1], s=size, alpha=alpha)
    plt.title('Test')

    # plt.draw()
    # plt.pause(0.001)
    plt.show()


class Distribution:
    def sample(self, n): pass


class ThreeGMDist(Distribution):
    """
    Three Gaussian Mixture Distribution, with one negative normal, and two positive normals
    """
    def __init__(self, mean_neg, cov_neg, mean_pos_A, cov_pos_A, mean_pos_B, cov_pos_B, prior_pos, prior_A):
        assert 0 <= prior_pos <= 1, 'prior_pos out of range'
        assert len(mean_neg) == len(mean_pos_A) == len(mean_pos_B), 'dimension mismatch'
        #todo check for cov dimensions
        self.mean_neg = mean_neg
        self.cov_neg = cov_neg
        self.mean_pos_A = mean_pos_A
        self.cov_pos_A = cov_pos_A
        self.mean_pos_B = mean_pos_B
        self.cov_pos_B = cov_pos_B
        self.prior_pos = prior_pos
        self.prior_A = prior_A

    def sample(self, n):
        npos = int(n*self.prior_pos)
        nneg = n-npos
        nposA = int(npos*self.prior_A)
        nposB = npos-nposA
        neg = np.random.multivariate_normal(mean=self.mean_neg, cov=self.cov_neg, size=nneg)
        pos_A = np.random.multivariate_normal(mean=self.mean_pos_A, cov=self.cov_pos_A, size=nposA)  # hard
        pos_B = np.random.multivariate_normal(mean=self.mean_pos_B, cov=self.cov_pos_B, size=nposB)  # easy
        return LabelledCollection(
            instances=np.concatenate([neg, pos_A, pos_B]),
            labels=[0]*nneg + [1]*(nposA+nposB)
        )



if __name__ == '__main__':
    import quapy as qp
    import quapy.functional as F
    print('proof of concept')

    def test(q, testset, methodtag, show=False, scores=None):
        estim_prev = q.quantify(testset.instances)
        ae = qp.error.ae(testset.prevalence(), estim_prev)
        print(f'{methodtag}\tpredicts={F.strprev(estim_prev)} true={F.strprev(testset.prevalence())} with an AE of {ae:.4f}')
        if show:
            plot_samples(q.validation_pool, q.to_show_val_selected, testset)
        if scores is not None:
            scores.append(ae)
        return ae

    def rand():
        return np.random.rand()

    def cls():
        return LogisticRegression()

    def scores():
        return {
            'i-PACC': [],
            'i-PCC': [],
            't-PACC': [],
            't2-PACC': [],
            't3-PACC': [],
        }

    score_shift = {
        'pps': scores(),
        'cov': scores(),
        'covs': scores(),
    }

    for i in range(1000):

        mneg, covneg = [0, 0], [[1, 0], [0, 1]]
        mposA, covposA = [2, 0], [[1, 0], [0, 1]]
        mposB, covposB = [3, 3], [[1, 0], [0, 1]]
        source_dist = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=0.5, prior_A=0.5)
        target_dist_pps = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=rand(), prior_A=0.5)
        target_dist_covs = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=0.5, prior_A=rand())
        target_dist_covs_pps = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=rand(), prior_A=rand())

        training = source_dist.sample(1000)
        validation_iid = source_dist.sample(1000)
        test_pps = target_dist_pps.sample(1000)
        val_pps  = target_dist_pps.sample(1000)
        test_cov = target_dist_covs.sample(1000)
        val_cov  = target_dist_covs.sample(1000)
        test_cov_pps = target_dist_covs_pps.sample(1000)
        val_cov_pps = target_dist_covs_pps.sample(1000)

        #print('observacion:')
        #inductive_pacc = PACC(cls())
        #inductive_pacc.fit(training, val_split=val_cov)
        #test(inductive_pacc, test_cov, 'i-PACC (val covs) on covariate shift')
        #inductive_pacc.fit(training, val_split=val_cov_pps)
        #test(inductive_pacc, test_cov_pps, 'i-PACC (val val_cov_pps) on covariate & prior shift')

        inductive_pacc = PACC(cls())
        inductive_pacc.fit(training, val_split=validation_iid)

        inductive_pcc = PCC(cls())
        inductive_pcc.fit(training)

        transductive_pacc = TransductivePACC(cls(), how_many=1)
        transductive_pacc.fit(training, val_split=validation_iid)

        transductive_pacc2 = TransductiveInvdistancePACC(cls())
        transductive_pacc2.fit(training, val_split=validation_iid)

        transductive_pacc3 = TransductiveInvdistanceIterativePACC(cls())
        transductive_pacc3.fit(training, val_split=validation_iid)

        print('\nPrior Probability Shift')
        print('-'*80)
        test(inductive_pacc, test_pps, 'i-PACC', scores=score_shift['pps']['i-PACC'])
        test(inductive_pcc, test_pps, 'i-PCC', scores=score_shift['pps']['i-PCC'])
        test(transductive_pacc, test_pps, 't-PACC', show=False, scores=score_shift['pps']['t-PACC'])
        test(transductive_pacc2, test_pps, 't2-PACC', show=False, scores=score_shift['pps']['t2-PACC'])
        test(transductive_pacc3, test_pps, 't3-PACC', show=False, scores=score_shift['pps']['t3-PACC'])

        print('\nCovariate Shift')
        print('-' * 80)
        test(inductive_pacc, test_cov, 'i-PACC', scores=score_shift['cov']['i-PACC'])
        test(inductive_pcc, test_cov, 'i-PCC', scores=score_shift['cov']['i-PCC'])
        test(transductive_pacc, test_cov, 't-PACC', show=False, scores=score_shift['cov']['t-PACC'])
        test(transductive_pacc2, test_cov, 't2-PACC', show=False, scores=score_shift['cov']['t2-PACC'])
        test(transductive_pacc3, test_cov, 't3-PACC', show=False, scores=score_shift['cov']['t3-PACC'])

        print('\nCovariate Shift - Type II')
        print('-' * 80)
        test(inductive_pacc, test_cov_pps, 'i-PACC', scores=score_shift['covs']['i-PACC'])
        test(inductive_pcc, test_cov_pps, 'i-PCC', scores=score_shift['covs']['i-PCC'])
        test(transductive_pacc, test_cov_pps, 't-PACC', show=False, scores=score_shift['covs']['t-PACC'])
        test(transductive_pacc2, test_cov_pps, 't2-PACC', scores=score_shift['covs']['t2-PACC'])
        test(transductive_pacc3, test_cov_pps, 't3-PACC', scores=score_shift['covs']['t3-PACC'])

        for shift in score_shift.keys():
            print(shift)
            for method in score_shift[shift]:
                print(f'\t{method}: {np.mean(score_shift[shift][method]):.4f}')

        # print()
        # print('-'*80)
        # # proposed method
        #
        # transductive_pacc = TransductiveInvdistanceIterativePACC(cls(), oracle_test_prev=test_pps.prevalence())
        # transductive_pacc.fit(training, val_split=validation_iid)
        # test(transductive_pacc, test_pps, 't3(oracle)-PACC on prior probability shift', show=False)
        #
        # transductive_pacc = TransductiveInvdistanceIterativePACC(cls(), oracle_test_prev=test_cov.prevalence())
        # transductive_pacc.fit(training, val_split=validation_iid)
        # test(transductive_pacc, test_cov, 't3(oracle)-PACC on covariate shift', show=False)
        #
        # transductive_pacc = TransductiveInvdistanceIterativePACC(cls(), oracle_test_prev=test_cov_pps.prevalence())
        # transductive_pacc.fit(training, val_split=validation_iid)
        # test(transductive_pacc, test_cov_pps, 't3(oracle)-PACC on covariate & prior shift')