from typing import Callable, Union
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.cluster import KMeans, OPTICS
from sklearn.decomposition import TruncatedSVD
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import KNeighborsClassifier
from quapy.method.base import BaseQuantifier, BinaryQuantifier
from quapy.data import LabelledCollection
from quapy.method.aggregative import ACC, PACC


class RegionAdjustment(ACC):
    """ACC variant that quantifies over regions (k-means clusters found separately within
    each class) and then collapses the region prevalences back into a binary prevalence."""

    def __init__(self, learner: BaseEstimator, val_split=0.4, k=2):
        self.learner = learner
        self.val_split = val_split
        # k is the number of regions (k-means clusters) per class
        self.k = k

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
        X, y = data.Xy
        Xp, Xn = X[y == 1], X[y == 0]
        # cluster each class separately; positive region ids are offset by k, so regions
        # 0..k-1 belong to the negative class and regions k..2k-1 to the positive class
        kmeans = KMeans(n_clusters=self.k)
        rn = kmeans.fit_predict(Xn)           # regions of the negative class
        rp = kmeans.fit_predict(Xp) + self.k  # regions of the positive class (refits kmeans)
        classes = np.arange(self.k * 2)
        pos = LabelledCollection(Xp, rp, classes_=classes)
        neg = LabelledCollection(Xn, rn, classes_=classes)
        region_data = pos + neg
        super(RegionAdjustment, self).fit(region_data, fit_learner, val_split)
        # map each region id back to its class, e.g., k=2 gives {0: 0, 1: 0, 2: 1, 3: 1}
        self.reg2class = {r: (0 if r < self.k else 1) for r in range(2 * self.k)}
        return self

    def classify(self, data):
        # returns region ids (0..2k-1), not binary class labels
        regions = super(RegionAdjustment, self).classify(data)
        return regions

    def aggregate(self, classif_predictions):
        # aggregate at the region level, then collapse region prevalences into the binary
        # prevalence via reg2class
        region_prevalence = super(RegionAdjustment, self).aggregate(classif_predictions)
        bin_prevalence = np.zeros(shape=2, dtype=float)
        for r, prev in enumerate(region_prevalence):
            bin_prevalence[self.reg2class[r]] += prev
        return bin_prevalence
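
# Usage sketch for RegionAdjustment (illustrative only; assumes quapy's standard
# fit/quantify protocol; LogisticRegression is an arbitrary choice of base learner):
#
#   from sklearn.linear_model import LogisticRegression
#   ra = RegionAdjustment(LogisticRegression(), val_split=0.4, k=3)
#   ra.fit(train)                              # train: a binary LabelledCollection
#   estim_prev = ra.quantify(test.instances)   # 2-dim prevalence (negative, positive)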


class RegionProbAdjustment(PACC):
    """PACC variant that quantifies over regions; the k clusters are split between the two
    classes in proportion to the training prevalence."""

    def __init__(self, learner: BaseEstimator, val_split=0.4, k=2):
        self.learner = learner
        self.val_split = val_split
        # k is the total number of regions (k-means clusters) across both classes
        self.k = k

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
        X, y = data.Xy
        Xp, Xn = X[y == 1], X[y == 0]
        # split the k clusters between the classes in proportion to the training prevalence;
        # note that rounding may make the total number of regions differ slightly from k
        nk_per_class = (data.prevalence() * self.k).round().astype(int)
        print(f'number of clusters per class {nk_per_class}')

        kmeans_neg = KMeans(n_clusters=nk_per_class[0])
        rn = kmeans_neg.fit_predict(Xn)  # regions of the negative class

        kmeans_pos = KMeans(n_clusters=nk_per_class[1])
        rp = kmeans_pos.fit_predict(Xp) + nk_per_class[0]  # positive regions, ids offset past the negative ones

        # the actual number of regions is nk_per_class.sum(), not necessarily k (rounding)
        n_regions = nk_per_class.sum()
        classes = np.arange(n_regions)
        pos = LabelledCollection(Xp, rp, classes_=classes)
        neg = LabelledCollection(Xn, rn, classes_=classes)

        region_data = pos + neg
        super(RegionProbAdjustment, self).fit(region_data, fit_learner, val_split)

        # regions 0..nk_per_class[0]-1 belong to class 0, the remaining ones to class 1
        self.reg2class = {r: (0 if r < nk_per_class[0] else 1) for r in range(n_regions)}

        return self

    def classify(self, data):
        # returns region ids, not binary class labels
        regions = super(RegionProbAdjustment, self).classify(data)
        return regions

    def aggregate(self, classif_predictions):
        # aggregate at the region level, then collapse region prevalences into the binary
        # prevalence via reg2class
        region_prevalence = super(RegionProbAdjustment, self).aggregate(classif_predictions)
        bin_prevalence = np.zeros(shape=2, dtype=float)
        for r, prev in enumerate(region_prevalence):
            bin_prevalence[self.reg2class[r]] += prev
        return bin_prevalence
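
# Worked example of the cluster budget in RegionProbAdjustment.fit (comment-only):
# with k=10 and training prevalence (0.3, 0.7), nk_per_class = round([3.0, 7.0]) = [3, 7],
# so regions 0-2 map to class 0 and regions 3-9 map to class 1 in reg2class.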


class RegionProbAdjustmentGlobal(BaseQuantifier):
    """Clusters all instances (irrespective of class) into regions, fits one quantifier per
    region, and combines the region-wise prevalence estimates weighted by the proportion of
    test instances that fall in each region."""

    def __init__(self, quantifier_fn: Callable[[], BaseQuantifier], k=5, clustering='gmm'):
        # quantifier_fn is a zero-argument factory returning a fresh binary quantifier
        self.quantifier_fn = quantifier_fn
        self.k = k
        self.clustering = clustering

    def _find_regions(self, X):
        if self.clustering == 'gmm':
            # reduce dimensionality first (assumes X has more than 500 features)
            self.svd = TruncatedSVD(n_components=500)
            X = self.svd.fit_transform(X)

            # choose the number of mixture components by BIC (lower is better)
            lowest_bic = np.inf
            bic = []
            for n_components in range(3, 8):
                # fit a Gaussian mixture with EM
                gmm = GaussianMixture(n_components).fit(X)
                bic.append(gmm.bic(X))
                print(bic)
                if bic[-1] < lowest_bic:
                    lowest_bic = bic[-1]
                    best_gmm = gmm
            print(f'chosen GMM with {len(best_gmm.weights_)} components')
            self.cluster = best_gmm
            regions = self.cluster.predict(X)
        elif self.clustering == 'kmeans':
            print(f'kmeans with k={self.k}')
            self.cluster = KMeans(n_clusters=self.k)
            regions = self.cluster.fit_predict(X)
        elif self.clustering == 'optics':
            print('optics')
            self.svd = TruncatedSVD(n_components=500)
            X = self.svd.fit_transform(X)
            self.cluster = OPTICS()
            regions = self.cluster.fit_predict(X)
            # OPTICS cannot predict on unseen data, so keep a 1-NN classifier over the
            # training assignments for _get_regions to approximate region membership
            self._optics_nn = KNeighborsClassifier(n_neighbors=1).fit(X, regions)
        else:
            raise NotImplementedError
        return regions
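    # Note on the GMM branch above: the number of mixture components is chosen by the
    # Bayesian information criterion (comment-only explanation):
    #   BIC = k * ln(n) - 2 * ln(L)
    # where k is the number of free parameters, n the sample size, and L the maximized
    # likelihood; lower BIC means a better complexity/fit trade-off.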

    def _get_regions(self, X):
        if self.clustering == 'gmm':
            return self.cluster.predict(self.svd.transform(X))
        elif self.clustering == 'kmeans':
            return self.cluster.predict(X)
        elif self.clustering == 'optics':
            # OPTICS has no predict method; approximate region membership with the 1-NN
            # classifier fitted on the training assignments in _find_regions
            return self._optics_nn.predict(self.svd.transform(X))
        else:
            raise NotImplementedError


    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
        self.classes = data.classes_

        # first cluster all instances (all classes involved), then fit a quantifier local
        # to each cluster
        g = self._find_regions(data.instances)
        X, y = data.Xy
        self.g_quantifiers = {}
        trivial = 0
        for gi in np.unique(g):
            qi_data = LabelledCollection(X[g == gi], y[g == gi], classes_=data.classes_)
            # a region is trivial when one of the classes is (almost) absent; we check for
            # counts <= 1 rather than prevalence == 0 because PACC requires at least two
            # examples per class to perform the stratified split
            if qi_data.counts()[1] <= 1:  # (almost) no positives in this region
                self.g_quantifiers[gi] = TrivialRejectorQuantifier()
                trivial += 1
            elif qi_data.counts()[0] <= 1:  # (almost) no negatives in this region
                self.g_quantifiers[gi] = TrivialAcceptorQuantifier()
                trivial += 1
            else:
                self.g_quantifiers[gi] = self.quantifier_fn().fit(qi_data)
        print(f'trivials={trivial}')

        return self

    @property
    def classes_(self):
        return self.classes

    def quantify(self, instances):
        # combine the region-wise estimates, weighted by the proportion of test instances
        # assigned to each region
        g = self._get_regions(instances)
        prevalence = np.zeros(len(self.classes_), dtype=float)
        for gi in np.unique(g):
            proportion_gi = (g == gi).mean()
            prev_gi = self.g_quantifiers[gi].quantify(instances[g == gi])
            prevalence += prev_gi * proportion_gi
        return prevalence
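    # The combination above follows the law of total probability (comment-only note):
    #   p(y) = sum_g p(g) * p(y|g)
    # where p(g) is the fraction of test instances assigned to region g, and p(y|g) is the
    # prevalence estimated by the quantifier local to region g.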


    def get_params(self, deep=True):
        # not parameterized for model selection; return an empty dict rather than None
        return {}

    def set_params(self, **parameters):
        pass
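
# Usage sketch for RegionProbAdjustmentGlobal (illustrative only; the lambda and
# LogisticRegression are arbitrary choices for the region-local quantifier factory):
#
#   from sklearn.linear_model import LogisticRegression
#   rpa = RegionProbAdjustmentGlobal(lambda: PACC(LogisticRegression()), k=5, clustering='kmeans')
#   rpa.fit(train)                              # discovers regions on the training instances
#   estim_prev = rpa.quantify(test.instances)   # reuses the fitted clusterer via _get_regions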


class TrivialRejectorQuantifier(BinaryQuantifier):
    """Degenerate quantifier that always predicts prevalence (1, 0), i.e., all negatives."""

    def fit(self, data: LabelledCollection):
        return self

    def quantify(self, instances):
        return np.asarray([1, 0])

    def set_params(self, **parameters):
        pass

    def get_params(self, deep=True):
        return {}

    @property
    def classes_(self):
        return np.asarray([0, 1])


class TrivialAcceptorQuantifier(BinaryQuantifier):
    """Degenerate quantifier that always predicts prevalence (0, 1), i.e., all positives."""

    def fit(self, data: LabelledCollection):
        return self

    def quantify(self, instances):
        return np.asarray([0, 1])

    def set_params(self, **parameters):
        pass

    def get_params(self, deep=True):
        return {}

    @property
    def classes_(self):
        return np.asarray([0, 1])
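

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original API): builds a synthetic
    # binary dataset and runs RegionAdjustment end-to-end. LogisticRegression is an
    # arbitrary choice of base learner; any sklearn probabilistic classifier would do,
    # and quapy's LabelledCollection.split_stratified is assumed for the train/test split.
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
    data = LabelledCollection(X, y)
    train, test = data.split_stratified(train_prop=0.6)

    ra = RegionAdjustment(LogisticRegression(), val_split=0.4, k=3)
    ra.fit(train)
    print(f'true prevalence: {test.prevalence()}')
    print(f'estimated prevalence: {ra.quantify(test.instances)}')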