from typing import Union
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.cluster import KMeans, OPTICS
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LogisticRegression
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import cross_val_predict
from sklearn.neighbors import NearestNeighbors

from quapy.method.base import BaseQuantifier, BinaryQuantifier
from quapy.data import LabelledCollection
from quapy.method.aggregative import ACC, PACC, PCC
import quapy.functional as F


class RegionAdjustmentQ(BaseQuantifier):
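    """
    Binary quantifier that redefines the problem in terms of regions: each class is partitioned into
    k-means clusters (the regions), the wrapped quantifier is trained to estimate region prevalences,
    and these are summed back into the two class prevalences.
    """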

    def __init__(self, quantifier: BaseQuantifier, k=10):
        self.quantifier = quantifier
        self.k = k  # total number of regions (k-means clusters), distributed across the classes by prevalence

    def fit(self, data: LabelledCollection):
        X, y = data.Xy
        Xp, Xn = X[y==1], X[y==0]

        nk_per_class = (data.prevalence() * self.k).round().astype(int)
        print(f'number of regions per class {nk_per_class}')

        kmeans_neg = KMeans(n_clusters=nk_per_class[0])
        rn = kmeans_neg.fit_predict(Xn)  # regions negative

        kmeans_pos = KMeans(n_clusters=nk_per_class[1])
        rp = kmeans_pos.fit_predict(Xp) + nk_per_class[0]  # regions positive

        # the effective number of regions may deviate from k due to rounding
        n_regions = nk_per_class.sum()
        classes = np.arange(n_regions)
        pos = LabelledCollection(Xp, rp, classes_=classes)
        neg = LabelledCollection(Xn, rn, classes_=classes)

        region_data = pos + neg
        self.quantifier.fit(region_data)

        self.reg2class = {r: (0 if r < nk_per_class[0] else 1) for r in range(n_regions)}

        return self

    def quantify(self, instances):
        region_prevalence = self.quantifier.quantify(instances)
        bin_prevalence = np.zeros(shape=2, dtype=float)
        for r, prev in enumerate(region_prevalence):
            bin_prevalence[self.reg2class[r]] += prev
        return bin_prevalence

    def set_params(self, **parameters):
        pass

    def get_params(self, deep=True):
        pass

    @property
    def classes_(self):
        return np.asarray([0,1])


class RegionAdjustment(ACC):
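    """
    ACC applied at the region level: each class is partitioned into k-means clusters (regions),
    the adjustment is learned over region labels, and the aggregated region prevalences are
    summed back into the two class prevalences.
    """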

    def __init__(self, learner: BaseEstimator, val_split=0.4, k=2):
        self.learner = learner
        self.val_split = val_split
        # k is the total number of regions (k-means clusters), distributed across the classes by prevalence
        self.k = k

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
        X, y = data.Xy
        Xp, Xn = X[y==1], X[y==0]

        nk_per_class = (data.prevalence() * self.k).round().astype(int)
        print(f'number of clusters per class {nk_per_class}')

        kmeans_neg = KMeans(n_clusters=nk_per_class[0])
        rn = kmeans_neg.fit_predict(Xn)  # regions negative

        kmeans_pos = KMeans(n_clusters=nk_per_class[1])
        rp = kmeans_pos.fit_predict(Xp) + nk_per_class[0]  # regions positive

        # the effective number of regions may deviate from k due to rounding
        n_regions = nk_per_class.sum()
        classes = np.arange(n_regions)
        pos = LabelledCollection(Xp, rp, classes_=classes)
        neg = LabelledCollection(Xn, rn, classes_=classes)

        region_data = pos + neg
        super(RegionAdjustment, self).fit(region_data, fit_learner, val_split)

        self.reg2class = {r: (0 if r < nk_per_class[0] else 1) for r in range(n_regions)}

        return self

    def classify(self, data):
        regions = super(RegionAdjustment, self).classify(data)
        return regions

    def aggregate(self, classif_predictions):
        region_prevalence = super(RegionAdjustment, self).aggregate(classif_predictions)
        bin_prevalence = np.zeros(shape=2, dtype=float)
        for r, prev in enumerate(region_prevalence):
            bin_prevalence[self.reg2class[r]] += prev
        return bin_prevalence


class RegionProbAdjustment(PACC):
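    """
    PACC applied at the region level: analogous to RegionAdjustment, but the adjustment is computed
    from posterior probabilities over the regions instead of crisp region assignments.
    """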

    def __init__(self, learner: BaseEstimator, val_split=0.4, k=2):
        self.learner = learner
        self.val_split = val_split
        # k is the total number of regions (k-means clusters), distributed across the classes by prevalence
        self.k = k

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
        X, y = data.Xy
        Xp, Xn = X[y==1], X[y==0]
        nk_per_class = (data.prevalence()*self.k).round().astype(int)
        print(f'number of clusters per class {nk_per_class}')

        kmeans_neg = KMeans(n_clusters=nk_per_class[0])
        rn = kmeans_neg.fit_predict(Xn)  # regions negative

        kmeans_pos = KMeans(n_clusters=nk_per_class[1])
        rp = kmeans_pos.fit_predict(Xp)+nk_per_class[0]  # regions positive

        # the effective number of regions may deviate from k due to rounding
        n_regions = nk_per_class.sum()
        classes = np.arange(n_regions)
        pos = LabelledCollection(Xp, rp, classes_=classes)
        neg = LabelledCollection(Xn, rn, classes_=classes)

        region_data = pos + neg
        super(RegionProbAdjustment, self).fit(region_data, fit_learner, val_split)

        self.reg2class = {r: (0 if r < nk_per_class[0] else 1) for r in range(n_regions)}

        return self

    def classify(self, data):
        regions = super(RegionProbAdjustment, self).classify(data)
        return regions

    def aggregate(self, classif_predictions):
        region_prevalence = super(RegionProbAdjustment, self).aggregate(classif_predictions)
        bin_prevalence = np.zeros(shape=2, dtype=float)
        for r, prev in enumerate(region_prevalence):
            bin_prevalence[self.reg2class[r]] += prev
        return bin_prevalence


class RegionProbAdjustmentGlobal(BaseQuantifier):
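    """
    Clusters the whole training set (via a GMM selected by BIC, k-means, or OPTICS), fits an
    independent quantifier on each region, and quantifies a test set as the region-proportion-weighted
    average of the region-wise prevalence estimates. Regions in which one class is (almost) empty
    fall back to trivial acceptor/rejector quantifiers.
    """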

    def __init__(self, quantifier_fn: BaseQuantifier, k=5, clustering='gmm'):
        self.quantifier_fn = quantifier_fn
        self.k = k
        self.clustering = clustering

    def _find_regions(self, X):
        if self.clustering == 'gmm':
            # guard against feature spaces with fewer than 500 dimensions
            self.svd = TruncatedSVD(n_components=min(500, X.shape[1] - 1))
            X = self.svd.fit_transform(X)

            lowest_bic = np.inf
            bic = []
            for n_components in range(3, 8):
                # Fit a Gaussian mixture with EM
                gmm = GaussianMixture(n_components).fit(X)
                bic.append(gmm.bic(X))
                print(bic)
                if bic[-1] < lowest_bic:
                    lowest_bic = bic[-1]
                    best_gmm = gmm
            print(f'chosen GMM with {len(best_gmm.weights_)} components')
            self.cluster = best_gmm
            regions = self.cluster.predict(X)
        elif self.clustering == 'kmeans':
            print(f'kmeans with k={self.k}')
            self.cluster = KMeans(n_clusters=self.k)
            regions = self.cluster.fit_predict(X)
        elif self.clustering == 'optics':
            print('optics')
            self.svd = TruncatedSVD(n_components=min(500, X.shape[1] - 1))
            X = self.svd.fit_transform(X)
            self.cluster = OPTICS()
            regions = self.cluster.fit_predict(X)
            # OPTICS offers no predict for unseen data; keep a 1-NN index over the reduced
            # training points so that _get_regions can assign new instances to the cluster
            # of their nearest neighbour
            self._nn = NearestNeighbors(n_neighbors=1).fit(X)
        else:
            raise NotImplementedError
        return regions

    def _get_regions(self, X):
        if self.clustering == 'gmm':
            return self.cluster.predict(self.svd.transform(X))
        elif self.clustering == 'kmeans':
            return self.cluster.predict(X)
        elif self.clustering == 'optics':
            # approximate out-of-sample assignment via the nearest training point
            nn_idx = self._nn.kneighbors(self.svd.transform(X), n_neighbors=1, return_distance=False)[:, 0]
            return self.cluster.labels_[nn_idx]
        else:
            raise NotImplementedError


    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
        self.classes = data.classes_

        # first k-means (all classes involved), then PACC local to each cluster
        g = self._find_regions(data.instances)
        X, y = data.Xy
        self.g_quantifiers = {}
        trivial, trivial_data = 0, 0
        for gi in np.unique(g):
            qi_data = LabelledCollection(X[g==gi], y[g==gi], classes_=data.classes_)
            if qi_data.counts()[1] <= 1:
                # (almost) no positive examples in this region; we check <= 1 rather than
                # prevalence == 0 because PACC requires at least two examples per class to
                # perform its stratified split
                self.g_quantifiers[gi] = TrivialRejectorQuantifier()
                trivial += 1
                trivial_data += len(qi_data)
            elif qi_data.counts()[0] <= 1:  # (almost) all positives
                self.g_quantifiers[gi] = TrivialAcceptorQuantifier()
                trivial += 1
                trivial_data += len(qi_data)
            else:
                self.g_quantifiers[gi] = self.quantifier_fn().fit(qi_data)
        print(f'trivials={trivial} amounting to {trivial_data*100.0/len(data):.2f}% of the data')

        return self

    @property
    def classes_(self):
        return self.classes

    def quantify(self, instances):
        g = self._get_regions(instances)
        prevalence = np.zeros(len(self.classes_), dtype=float)
        for gi in np.unique(g):
            proportion_gi = (g==gi).mean()
            prev_gi = self.g_quantifiers[gi].quantify(instances[g==gi])
            prevalence += prev_gi * proportion_gi
        return prevalence


    def get_params(self, deep=True):
        pass

    def set_params(self, **parameters):
        pass


class TrivialRejectorQuantifier(BinaryQuantifier):
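    """Trivial quantifier that assigns all prevalence mass to the negative class."""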
    def fit(self, data: LabelledCollection):
        return self

    def quantify(self, instances):
        return np.asarray([1,0])

    def set_params(self, **parameters):
        pass

    def get_params(self, deep=True):
        pass

    @property
    def classes_(self):
        return np.asarray([0,1])


class TrivialAcceptorQuantifier(BinaryQuantifier):
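    """Trivial quantifier that assigns all prevalence mass to the positive class."""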
    def fit(self, data: LabelledCollection):
        return self

    def quantify(self, instances):
        return np.asarray([0,1])

    def set_params(self, **parameters):
        pass

    def get_params(self, deep=True):
        pass

    @property
    def classes_(self):
        return np.asarray([0,1])


class ClassWeightPCC(BaseQuantifier):
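    """
    First guesses the test prevalence with PACC, then retrains the base classifier with class
    weights set according to the estimated shift between training and test prevalence, and
    finally quantifies with PCC on the reweighted classifier.
    """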

    def __init__(self, estimator=LogisticRegression):
        self.estimator = estimator
        self.learner = PACC(self.estimator())
        self.deployed = False

    def fit(self, data: LabelledCollection, fit_learner=True):
        self.train = data
        self.learner.fit(self.train)
        return self

    def quantify(self, instances):
        guessed_prevalence = self.learner.quantify(instances)
        class_weight = self._get_class_weight(guessed_prevalence)
        base_estimator = clone(self.learner.learner)
        base_estimator.set_params(class_weight=class_weight)
        pcc = PCC(base_estimator)
        return pcc.fit(self.train).quantify(instances)

    def _get_class_weight(self, prevalence):
        # reweight each class by the ratio between the estimated test prevalence and the training prevalence
        weights = prevalence / self.train.prevalence()
        normfactor = weights.min()
        if normfactor <= 0:
            normfactor = 1E-3
        weights /= normfactor
        return {0:weights[0], 1:weights[1]}

    def set_params(self, **parameters):
        self.learner.set_params(**parameters)

    def get_params(self, deep=True):
        return self.learner.get_params()

    @property
    def classes_(self):
        return self.train.classes_


class PosteriorConditionalAdjustment(BaseQuantifier):
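    """
    Sorts validation instances by their posterior probability, splits them into k buckets, learns a
    bucket-specific misclassification matrix, and at test time applies an ACC-style adjustment within
    each bucket, combining the bucket estimates in proportion to their size.
    """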

    def __init__(self):
        self.estimator = LogisticRegression()
        self.k = 3

    def get_adjustment_matrix(self, y, prob):
        n_classes = 2
        classes = [0, 1]
        confusion = np.empty(shape=(n_classes, n_classes))
        for i, class_ in enumerate(classes):
            index = y == class_
            if any(index):
                confusion[i] = prob[index].mean(axis=0)
            else:
                if i == 0:
                    confusion[i] = np.asarray([1,0])
                else:
                    confusion[i] = np.asarray([0, 1])

        confusion = confusion.T
        return confusion

    def fit(self, data: LabelledCollection, fit_learner=True):
        self.train = data  # kept for the classes_ property
        X, y = data.Xy
        proba = cross_val_predict(self.estimator, X, y, n_jobs=-1, method='predict_proba')

        order = np.argsort(proba[:,1])
        proba = proba[order]
        y = y[order]
        X = X[order]  # to keep the alignment for the final classifier
        n = len(data)
        bucket_size = n // self.k
        bucket_remainder = n % self.k  # leftover instances are appended to the last bucket
        self.buckets = {}
        self.prob_separations = []
        for bucket in range(self.k):
            from_pos = bucket*bucket_size
            to_pos = (bucket+1)*bucket_size + (bucket_remainder if bucket==self.k-1 else 0)
            slice_b = slice(from_pos, to_pos)
            y_b = y[slice_b]
            proba_b = proba[slice_b]
            self.buckets[bucket] = self.get_adjustment_matrix(y_b, proba_b)
            self.prob_separations.append(proba_b[-1,1])
        self.prob_separations[-1] = 1  # the last one should account for the entire prob

        self.estimator.fit(X,y)
        return self

    def quantify(self, instances):
        proba = self.estimator.predict_proba(instances)

        prev = np.zeros(shape=2, dtype=float)
        n = proba.shape[0]
        last_prob_sep = 0
        for b, prob_sep in enumerate(self.prob_separations):
            # the last bucket is right-closed so that posteriors exactly equal to 1 are not dropped
            upper = (proba[:, 1] <= prob_sep) if b == self.k - 1 else (proba[:, 1] < prob_sep)
            proba_b = proba[np.logical_and(proba[:, 1] >= last_prob_sep, upper)]
            last_prob_sep = prob_sep
            if proba_b.size > 0:
                pcc_b = F.prevalence_from_probabilities(proba_b, binarize=False)
                adj_matrix = self.buckets[b]
                pacc_b = ACC.solve_adjustment(adj_matrix, pcc_b)
                bucket_prev = proba_b.shape[0] / n
                print(f'bucket {b} -> {F.strprev(pacc_b)} with prop {bucket_prev:.4f}')
                prev += (pacc_b*bucket_prev)

        print(F.strprev(prev))
        return prev

    def set_params(self, **parameters):
        self.estimator.set_params(**parameters)

    def get_params(self, deep=True):
        return self.estimator.get_params()

    @property
    def classes_(self):
        return self.train.classes_
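

if __name__ == '__main__':
    # Minimal usage sketch on synthetic data. This is illustrative only: the dataset,
    # the train/test split sizes, and the choice of ClassWeightPCC are arbitrary
    # assumptions, not part of the methods defined above.
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
    train = LabelledCollection(X[:700], y[:700])
    test = LabelledCollection(X[700:], y[700:])

    model = ClassWeightPCC().fit(train)
    estim_prevalence = model.quantify(test.instances)
    print(f'true prevalence:      {F.strprev(test.prevalence())}')
    print(f'estimated prevalence: {F.strprev(estim_prevalence)}')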