import itertools
from typing import Iterable

from densratio import densratio
from scipy.sparse import issparse, vstack
from scipy.stats import multivariate_normal
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_val_predict

import quapy as qp
from Transduction_office.grid_naive_quantif import GridQuantifier, binned_indexer, Indexer, GridQuantifier2, \
    classifier_indexer
from method.non_aggregative import MLPE
from quapy.protocol import AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol, UPP
from quapy.data import LabelledCollection
from quapy.method.aggregative import *
import quapy.functional as F
from time import time
from scipy.spatial.distance import cdist
from Transduction.pykliep import DensityRatioEstimator
import numpy as np
from abc import abstractmethod


plotting = False


def gaussian(mean, cov=0.1, label=0, size=100, random_state=0):
    """
    Creates a labelled collection in which the instances are drawn from a Gaussian distribution with the
    specified parameters, and in which all data points receive the same label.

    :param mean: ndarray of shape (n_dimensions,) with the center of the Gaussian
    :param cov: ndarray of shape (n_dimensions, n_dimensions) with the covariance matrix, or a scalar that
        multiplies the identity matrix
    :param label: the class label for the collection
    :param size: number of instances
    :param random_state: allows for replicating experiments
    :return: an instance of LabelledCollection
    """
    mean = np.asarray(mean)
    assert mean.ndim==1, 'wrong shape for mean'
    n_features = mean.shape[0]
    if isinstance(cov, (int, float)):
        cov = np.eye(n_features) * cov
    instances = multivariate_normal.rvs(mean, cov, size, random_state=random_state)
    return LabelledCollection(instances, labels=[label]*size)


def _internal_plot(train, val, test):
    if plotting:
        xmin = min(train.X[:, 0].min(), val.X[:, 0].min(), test[:, 0].min())
        xmax = max(train.X[:, 0].max(), val.X[:, 0].max(), test[:, 0].max())
        ymin = min(train.X[:, 1].min(), val.X[:, 1].min(), test[:, 1].min())
        ymax = max(train.X[:, 1].max(), val.X[:, 1].max(), test[:, 1].max())
        plot(train, 'sel_train.png', xlim=(xmin, xmax), ylim=(ymin, ymax))
        plot(val, 'sel_val.png', xlim=(xmin, xmax), ylim=(ymin, ymax))
        plot(test, 'test.png', xlim=(xmin, xmax), ylim=(ymin, ymax))

def plot(data: LabelledCollection, path, xlim=None, ylim=None):
    import matplotlib.pyplot as plt
    plt.clf()
    if isinstance(data, LabelledCollection):
        if data.instances.shape[1] != 2:
            return

        negative, positive = data.separate()
        plt.scatter(negative.X[:,0], negative.X[:,1], label='neg', alpha=0.5)
        plt.scatter(positive.X[:, 0], positive.X[:, 1], label='pos', alpha=0.5)
    else:
        if data.shape[1] != 2:
            return
        plt.scatter(data[:, 0], data[:, 1], label='test', alpha=0.5)
    if xlim is not None:
        plt.xlim(*xlim)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.legend()
    plt.savefig(path)

# ------------------------------------------------------------------------------------
# Protocol for generating prior probability shift + covariate shift by mixing "domains"
# ------------------------------------------------------------------------------------
class CovPriorShift(AbstractStochasticSeededProtocol):
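    """
    Generates samples exhibiting both prior probability shift and covariate shift by mixing a set of "domains"
    (each given as a LabelledCollection) in proportions drawn uniformly at random from the simplex.
    """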

    def __init__(self, domains: Iterable[LabelledCollection], sample_size=None, repeats=100, min_support=0, random_state=0,
                 return_type='sample_prev'):
        super(CovPriorShift, self).__init__(random_state)
        self.domains = list(itertools.chain.from_iterable(lc.separate() for lc in domains))
        self.sample_size = qp._get_sample_size(sample_size)
        self.repeats = repeats
        self.min_support = min_support
        self.collator = OnLabelledCollectionProtocol.get_collator(return_type)

    def samples_parameters(self):
        """
        Return all the necessary parameters to replicate the samples of this protocol. The mixing proportions
        across domains are drawn uniformly at random from the simplex (as in UPP, but across domains instead
        of classes).

        :return: a list of sampling indexes (one per domain) that realize each sample
        """
        indexes = []
        tentatives = 0
        while len(indexes) < self.repeats:
            alpha = F.uniform_simplex_sampling(n_classes=len(self.domains))
            sizes = (alpha * self.sample_size).astype(int)
            if all(sizes > self.min_support):
                indexes_i = [lc.sampling_index(size) for lc, size in zip(self.domains, sizes)]
                indexes.append(indexes_i)
                tentatives = 0
            else:
                tentatives += 1
            if tentatives > 100:
                raise ValueError('the min_support constraint is too strict; it is difficult '
                                 'or impossible to generate valid samples')
        return indexes

    def sample(self, params):
        indexes = params
        lcs = [lc.sampling_from_index(index) for index, lc in zip(indexes, self.domains)]
        return LabelledCollection.join(*lcs)

    def total(self):
        """
        Returns the number of samples that will be generated

        :return: int
        """
        return self.repeats


# ---------------------------------------------------------------------------------------
# Importance-weighting methods, e.g., based on density ratio estimation (KLIEP, uLSIF) or
# discriminative train-vs-test classification (logistic regression)
# ---------------------------------------------------------------------------------------
class ImportanceWeight:
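    """
    Abstract base class for importance-weighting methods; subclasses return one weight per training instance,
    quantifying how representative that instance is of the test distribution.
    """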
    @abstractmethod
    def weights(self, Xtr, ytr, Xte):
        pass


class KLIEP(ImportanceWeight):
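    """
    Estimates the train-to-test density ratio with KLIEP (the Kullback-Leibler Importance Estimation Procedure).
    """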

    def __init__(self):
        pass

    def weights(self, Xtr, ytr, Xte):
        kliep = DensityRatioEstimator()
        kliep.fit(Xtr, Xte)
        return kliep.predict(Xtr)


class USILF(ImportanceWeight):
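    """
    Estimates the density ratio with unconstrained Least-Squares Importance Fitting (uLSIF); alpha > 0
    corresponds to the relative (RuLSIF) variant.
    """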

    def __init__(self, alpha=0.):
        self.alpha = alpha

    def weights(self, Xtr, ytr, Xte):
        dense_ratio_obj = densratio(Xtr, Xte, alpha=self.alpha, verbose=False)
        return dense_ratio_obj.compute_density_ratio(Xtr)


class LogReg(ImportanceWeight):
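    """
    Estimates the density ratio discriminatively, using a probabilistic classifier trained to separate
    training from test instances.
    """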

    def __init__(self):
        pass

    def weights(self, Xtr, ytr, Xte):
        # check "Direct Density Ratio Estimation for
        # Large-scale Covariate Shift Adaptation", Eq.28
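        # the density ratio can be recovered from a probabilistic train-vs-test classifier via Bayes' rule:
        #   w(x) = p_te(x)/p_tr(x) = (P(tr)/P(te)) * (P(te|x)/P(tr|x)) = (n_tr/n_te) * (P(te|x)/P(tr|x))
        # where the selection priors P(tr) and P(te) are estimated from the relative sizes of the two sets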

        if issparse(Xtr):
            X = vstack([Xtr, Xte])
        else:
            X = np.concatenate([Xtr, Xte])

        y = [0]*Xtr.shape[0] + [1]*Xte.shape[0]

        logreg = GridSearchCV(
            LogisticRegression(),
            param_grid={'C':np.logspace(-3,3,7), 'class_weight': ['balanced', None]},
            n_jobs=-1
        )
        logreg.fit(X, y)
        probs = logreg.predict_proba(Xtr)
        prob_train, prob_test = probs[:, 0], probs[:, 1]
        n_train = Xtr.shape[0]
        n_test = Xte.shape[0]
        w = (n_train/n_test)*(prob_test/prob_train)
        return w


class MostTest(ImportanceWeight):

    def __init__(self):
        pass

    def weights(self, Xtr, ytr, Xte):
        # check "Direct Density Ratio Estimation for
        # Large-scale Covariate Shift Adaptation", Eq.28

        if issparse(Xtr):
            X = vstack([Xtr, Xte])
        else:
            X = np.concatenate([Xtr, Xte])

        y = [0]*Xtr.shape[0] + [1]*Xte.shape[0]

        logreg = GridSearchCV(
            LogisticRegression(),
            param_grid={'C':np.logspace(-3,3,7), 'class_weight': ['balanced', None]},
            n_jobs=-1
        )
        # out-of-fold predictions prevent scoring the training points with a classifier fit on them
        prob_test = cross_val_predict(logreg, X, y, n_jobs=-1, method='predict_proba')[:Xtr.shape[0], 1]
        return prob_test


class Random(ImportanceWeight):

    def __init__(self):
        pass

    def weights(self, Xtr, ytr, Xte):
        return np.random.rand(len(Xtr))


class MostSimilarK(ImportanceWeight):
    # weights each training document by its average similarity to its k closest test points

    def __init__(self, k):
        self.k = k

    def weights(self, Xtr, ytr, Xte):
        distances = cdist(Xtr, Xte)
        min_dist = np.min(distances)
        max_dist = np.max(distances)
        distances = (distances-min_dist)/(max_dist-min_dist)
        similarities = 1 / (1+distances)
        top_k_sim = np.sort(similarities, axis=1)[:,-self.k:]
        ave_sim = np.mean(top_k_sim, axis=1)
        return ave_sim

class MostSimilarTest(ImportanceWeight):
    # retains the training documents that are most similar to at least one test document,
    # i.e., for each test point, selects the k most similar training instances

    def __init__(self, k=1):
        self.k = k

    def weights(self, Xtr, ytr, Xte):
        distances = cdist(Xtr, Xte)
        most_similar_idx = np.argsort(distances, axis=0)[:self.k, :].flatten()
        weights = np.zeros(shape=Xtr.shape[0])
        weights[most_similar_idx] = 1
        return weights

# --------------------------------------------------------------------------------------------
# Quantification Methods that rely on Importance Weight for reweighting the training instances
# --------------------------------------------------------------------------------------------
class TransductiveQuantifier(BaseQuantifier):
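    """
    Base class for transductive quantifiers, i.e., quantifiers that defer all learning until the test
    instances are available; fit simply stores the training collection.
    """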

    def fit(self, data: LabelledCollection):
        self.training_ = data
        return self

    @property
    def training(self):
        return self.training_


class ReweightingAggregative(TransductiveQuantifier):
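    """
    Fits the classifier with one importance weight per training instance (via sample_weight) and then applies
    an aggregative quantification method on top of the reweighted classifier.
    """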

    def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=CC):
        self.classifier = classifier
        self.weighter = weighter
        self.quantif_method = quantif_method

    def quantify(self, instances):
        # profiling note: computing the weights dominates the cost (time_weight = 2.95340 vs time_train = 0.00619)
        w = self.weighter.weights(*self.training.Xy, instances)
        self.classifier.fit(*self.training.Xy, sample_weight=w)
        quantifier = self.quantif_method(self.classifier).fit(self.training, fit_classifier=False)
        return quantifier.quantify(instances)


# --------------------------------------------------------------------------------------------
# Quantification Methods that rely on Importance Weight for selecting a validation partition
# --------------------------------------------------------------------------------------------




class SelectorQuantifiersTrainVal(TransductiveQuantifier):
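    """
    Uses the importance weights to select the training examples most similar to the test set, and splits the
    selection into a training and a validation partition for quantifiers that require one (e.g., ACC).
    """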

    def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=ACC, val_split=0.4, only_positives=False):
        self.classifier = classifier
        self.weighter = weighter
        self.quantif_method = quantif_method
        self.val_split = val_split
        self.only_positives = only_positives

    def quantify(self, instances):
        w = self.weighter.weights(*self.training.Xy, instances)
        train, val = self.select_from_weights(w, self.training, self.val_split, self.only_positives)
        _internal_plot(train, val, instances)
        # print('\ttraining size', len(train), '\tval size', len(val))
        quantifier = self.quantif_method(self.classifier).fit(train, val_split=val)
        return quantifier.quantify(instances)

    def select_from_weights(self, w, data: LabelledCollection, val_prop=0.4, only_positives=False):
        order = np.argsort(w)
        if only_positives:
            val_prop = np.mean(w > 0)
        split_point = int(len(w) * val_prop)
        similar_idx = order[-split_point:]
        similar = data.sampling_from_index(similar_idx)
        train, val = similar.split_stratified(0.6)
        return train, val


class SelectorQuantifiersTrain(TransductiveQuantifier):
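    """
    Uses the importance weights to retain only the training examples most similar to the test set, and fits
    the quantifier on that selection alone.
    """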

    def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=ACC, only_positives=False):
        self.classifier = classifier
        self.weighter = weighter
        self.quantif_method = quantif_method
        self.only_positives = only_positives

    def quantify(self, instances):
        w = self.weighter.weights(*self.training.Xy, instances)
        # rely on the default select_prop (0.5); passing select_prop=None would crash when only_positives is False
        train = self.select_from_weights(w, self.training, only_positives=self.only_positives)
        # _internal_plot(train, None, instances)
        # print('\ttraining size', len(train))
        quantifier = self.quantif_method(self.classifier).fit(train)
        return quantifier.quantify(instances)

    def select_from_weights(self, w, data: LabelledCollection, select_prop=0.5, only_positives=False):
        order = np.argsort(w)
        if only_positives:
            select_prop = np.mean(w > 0)
        split_point = int(len(w) * select_prop)
        similar_idx = order[-split_point:]
        return data.sampling_from_index(similar_idx)


if __name__ == '__main__':
    qp.environ['SAMPLE_SIZE'] = 500

    dA_l0 = gaussian(mean=[0,0], label=0, size=5000)
    dA_l1 = gaussian(mean=[1,0], label=1, size=5000)
    dB_l0 = gaussian(mean=[0,1], label=0, size=5000)
    dB_l1 = gaussian(mean=[1,1], label=1, size=5000)

    dA = LabelledCollection.join(dA_l0, dA_l1)
    dB = LabelledCollection.join(dB_l0, dB_l1)

    dA_train, dA_test = dA.split_stratified(0.5, random_state=0)
    dB_train, dB_test = dB.split_stratified(0.5, random_state=0)

    train = LabelledCollection.join(dA_train, dB_train)

    plot(train, 'train.png')

    def lr():
        return LogisticRegression()



    # EMQ.MAX_ITER*=10
    # val_split = 0.5
    k_sim = 10
    Q = ACC
    methods = [
        ('MLPE', MLPE()),
        ('CC', CC(lr())),
        ('PCC', PCC(lr())),
        ('ACC', ACC(lr())),
        ('PACC', PACC(lr())),
        ('HDy', HDy(lr())),
        ('EMQ', EMQ(lr())),
        ('GridQ', GridQuantifier2(classifier=lr())),
        # ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=2)), cell_quantifier=Q(lr()))),
        # ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=4)), cell_quantifier=Q(lr()))),
        # ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=6)), cell_quantifier=Q(lr()))),
        # ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=8)), cell_quantifier=Q(lr()))),
        # ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=10)), cell_quantifier=Q(lr()))),
        # ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=20)), cell_quantifier=Q(lr()))),
        # ('kSim-ACC', SelectorQuantifiers(lr(), MostSimilar(k_sim), ACC, val_split=val_split)),
        # ('kSim-PACC', SelectorQuantifiers(lr(), MostSimilar(k_sim), PACC, val_split=val_split)),
        # ('kSim-HDy', SelectorQuantifiers(lr(), MostSimilar(k_sim), HDy, val_split=val_split)),
        # ('Sel-CC', SelectorQuantifiersTrain(lr(), MostSimilarTest(k=k_sim), CC, only_positives=True)),
        # ('Sel-PCC', SelectorQuantifiersTrain(lr(), MostSimilarTest(k=k_sim), PCC, only_positives=True)),
        # ('Sel-ACC', SelectorQuantifiersTrainVal(lr(), MostSimilarTest(k=k_sim), ACC, only_positives=True)),
        # ('Sel-PACC', SelectorQuantifiersTrainVal(lr(), MostSimilarTest(k=k_sim), PACC, only_positives=True)),
        # ('Sel-HDy', SelectorQuantifiersTrainVal(lr(), MostSimilarTest(k=k_sim), HDy, only_positives=True)),
        # ('Sel-EMQ', SelectorQuantifiersTrain(lr(), MostSimilarTest(k=k_sim), EMQ, only_positives=True)),
        # ('Sel-EMQ', SelectorQuantifiersTrainVal(lr(), USILF(), PACC, only_positives=False)),
        # ('Sel-PACC', SelectorQuantifiers(lr(), MostTest(), PACC)),
        # ('Sel-HDy', SelectorQuantifiers(lr(), MostTest(), HDy)),
        # ('LogReg-CC', ReweightingAggregative(lr(), LogReg(), CC)),
        # ('LogReg-PCC', ReweightingAggregative(lr(), LogReg(), PCC)),
        # ('LogReg-EMQ', ReweightingAggregative(lr(), LogReg(), EMQ)),
        # ('KLIEP-CC', ReweightingAggregative(lr(), KLIEP(), CC)),
        # ('KLIEP-PCC', ReweightingAggregative(lr(), KLIEP(), PCC)),
        # ('KLIEP-EMQ', ReweightingAggregative(lr(), KLIEP(), EMQ)),
        # ('SILF-CC', ReweightingAggregative(lr(), USILF(), CC)),
        # ('SILF-PCC', ReweightingAggregative(lr(), USILF(), PCC)),
        # ('SILF-EMQ', ReweightingAggregative(lr(), USILF(), EMQ))
    ]

    for name, model in methods:
        with qp.util.temp_seed(5):
            # print('original training size', len(train))
            model.fit(train)

            prot = CovPriorShift([dA_test, dB_test], repeats=1 if plotting else 150)
            # prot = UPP(dA_test+dB_test, repeats=1 if plottting else 150)
            mae = qp.evaluation.evaluate(model, protocol=prot, error_metric='mae')
            print(f'{name}: {mae = :.4f}')
            # mrae = qp.evaluation.evaluate(model, protocol=prot, error_metric='mrae')
            # print(f'{name}: {mrae = :.4f}')