import numpy as np
import scipy.sparse
from scipy.sparse import issparse, csr_matrix
from abc import ABC, abstractmethod
from sklearn.base import clone
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict

from quapy.data import LabelledCollection


class ConfusionMatrixPredictor(ABC):
    """
    Abstract class for predictors of the confusion matrix that a classifier will produce on a test set.
    For the binary case, this amounts to predicting the 4-cell contingency table consisting of the
    true positives (TP), true negatives (TN), false positives (FP), and false negatives (FN) that
    most evaluation metrics make use of.
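
    For instance (illustrative only), given a binary confusion matrix in sklearn layout
    (``[[tn, fp], [fn, tp]]`` for labels ``[0, 1]``), common metrics follow directly from the four counters::

        >>> tn, fp, fn, tp = conf_matrix.ravel()
        >>> accuracy = (tp + tn) / (tp + tn + fp + fn)
        >>> f1 = 2 * tp / (2 * tp + fp + fn)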
    """
    @abstractmethod
    def fit(self, train: LabelledCollection):
        pass

    @abstractmethod
    def predict(self, test):
        pass


class MLCMEstimator(ConfusionMatrixPredictor):
    """
    The Maximum Likelihood Confusion Matrix Estimator relies on the IID assumption: it computes, via k-fold
    cross-validation (k-FCV) or any other technique, the counters of the confusion matrix on the training
    data, and assumes these to be good estimates for the test case.
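
    Example (an illustrative sketch, not part of the original module; assumes `train` and `test` are
    :class:`LabelledCollection` instances and uses `LogisticRegression` as an arbitrary base classifier)::

        >>> from sklearn.linear_model import LogisticRegression
        >>> cm_estimator = MLCMEstimator(LogisticRegression(), strategy='kfcv', k=5)
        >>> cm_estimator.fit(train)
        >>> conf_matrix = cm_estimator.predict(test)  # test is ignored under the IID assumption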
    """
    def __init__(self, classifier, strategy='kfcv', **kwargs):
        assert strategy in ['kfcv'], 'unknown strategy'
        if strategy == 'kfcv':
            assert 'k' in kwargs, 'strategy "kfcv" requires "k" to be passed as an argument'
        self.classifier = classifier
        self.strategy = strategy
        self.kwargs = kwargs

    def sout(self, msg):
        if self.kwargs.get('verbose', False):
            print(msg)

    def fit(self, train: LabelledCollection):
        X, y = train.Xy
        if self.strategy == 'kfcv':
            k = self.kwargs['k']
            n_jobs = self.kwargs.get('n_jobs', 1)
            predict = self.kwargs.get('predict', 'predict')
            self.sout(f'{self.__class__.__name__}: '
                      f'running cross_val_predict with k={k} n_jobs={n_jobs} predict={predict}')
            predictions = cross_val_predict(self.classifier, X, y, cv=k, n_jobs=n_jobs, method=predict)
            self.conf_matrix = confusion_matrix(y, predictions, labels=train.classes_)
        return self

    def predict(self, test):
        """
        This method disregards the test set, under the assumption that it is IID with respect to the training
        data. This means that the confusion matrix estimated for the training data (using any cross-validation
        strategy) should coincide with that of the test data.

        :param test: test collection (ignored)
        :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
        """
        return self.conf_matrix


class UpperBound(ConfusionMatrixPredictor):
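    """
    An oracle-like predictor that trains the classifier and then computes the actual confusion matrix on the
    test set using the true test labels (which would be unavailable in a real deployment). It thus provides
    an upper bound against which any confusion matrix predictor can be compared.
    """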
    def __init__(self, classifier, y_test):
        self.classifier = classifier
        self.y_test = y_test

    def fit(self, train: LabelledCollection):
        self.classifier.fit(*train.Xy)
        self.classes = train.classes_
        return self

    def show_true_labels(self, y_test):
        self.y_test = y_test

    def predict(self, test):
        predictions = self.classifier.predict(test)
        return confusion_matrix(self.y_test, predictions, labels=self.classes)


def get_counters(y_true, y_pred):
    """Maps each (true, predicted) binary label pair onto the contingency-table cell (tp/fn/fp/tn) it falls in."""
    counters = np.full(shape=y_true.shape, fill_value=-1)
    counters[np.logical_and(y_true == 1, y_pred == 1)] = 0
    counters[np.logical_and(y_true == 1, y_pred == 0)] = 1
    counters[np.logical_and(y_true == 0, y_pred == 1)] = 2
    counters[np.logical_and(y_true == 0, y_pred == 0)] = 3
    class_map = {
        0:'tp',
        1:'fn',
        2:'fp',
        3:'tn'
    }
    return counters, class_map
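
# For instance (illustrative): get_counters(np.array([1, 1, 0, 0]), np.array([1, 0, 1, 0]))
# returns (array([0, 1, 2, 3]), {0: 'tp', 1: 'fn', 2: 'fp', 3: 'tn'}), i.e., each position is
# labelled with the contingency-table cell its (true, predicted) pair falls in.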


def safehstack(matrix, posteriors):
    """Horizontally stacks the posterior probabilities next to the covariates, for dense or sparse inputs."""
    if issparse(matrix):
        instances = csr_matrix(scipy.sparse.hstack([matrix, posteriors]))
    else:
        instances = np.hstack([matrix, posteriors])
    return instances


class QuantificationCMPredictor(ConfusionMatrixPredictor):
    """
    """
    def __init__(self, classifier, quantifier, strategy='kfcv', **kwargs):
        assert strategy in ['kfcv'], 'unknown strategy'
        if strategy == 'kfcv':
            assert 'k' in kwargs, 'strategy "kfcv" requires "k" to be passed as an argument'
        self.classifier = clone(classifier)
        self.quantifier = quantifier
        self.strategy = strategy
        self.kwargs = kwargs

    def sout(self, msg):
        if self.kwargs.get('verbose', False):
            print(msg)

    def fit(self, train: LabelledCollection):
        X, y = train.Xy
        if self.strategy == 'kfcv':
            k = self.kwargs['k']
            n_jobs = self.kwargs.get('n_jobs', 1)
            self.sout(f'{self.__class__.__name__}: '
                      f'running cross_val_predict with k={k} n_jobs={n_jobs}')
            predictions = cross_val_predict(self.classifier, X, y, cv=k, n_jobs=n_jobs, method='predict')
            posteriors  = cross_val_predict(self.classifier, X, y, cv=k, n_jobs=n_jobs, method='predict_proba')
            self.classifier.fit(X, y)
            instances = safehstack(train.instances, posteriors)
            counters, class_map = get_counters(train.labels, predictions)
            q_data = LabelledCollection(instances=instances, labels=counters, classes_=[0,1,2,3])
            self.sout(f'counters prevalence: {q_data.counts()}')
            self.quantifier.fit(q_data)
        return self

    def predict(self, test):
        """

        :param test: test collection (ignored)
        :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
        """
        posteriors = self.classifier.predict_proba(test)
        instances = safehstack(test, posteriors)
        counters = self.quantifier.quantify(instances)
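        # the order follows the classes [0, 1, 2, 3] = [tp, fn, fp, tn] the quantifier was trained with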
        tp, fn, fp, tn = counters
        conf_matrix = np.asarray([[tn, fp], [fn, tp]])
        return conf_matrix

    def quantify(self, test):
        """
        Returns the estimated prevalence of the positive class in the test set, derived from the quantified
        cell prevalences as tp+fn. A PACC-style correction is also computed below, but it is currently not
        the value returned.

        :param test: test instances
        :return: the estimated prevalence of the positive class
        """
        posteriors = self.classifier.predict_proba(test)
        instances = safehstack(test, posteriors)
        counters = self.quantifier.quantify(instances)
        tp, fn, fp, tn = counters

        # estimated true positive rate and false positive rate of the classifier
        den_tpr = tp + fn
        tpr = tp / den_tpr if den_tpr > 0 else 1

        den_fpr = fp + tn
        fpr = fp / den_fpr if den_fpr > 0 else 0

        # PACC-style adjustment of the expected proportion of positive predictions
        # (note: pcc must be a prevalence, hence the mean, not the sum, of the posteriors)
        pcc = posteriors.mean(axis=0)[1]
        pacc = (pcc - fpr) / (tpr - fpr) if tpr != fpr else pcc
        pacc = np.clip(pacc, 0, 1)

        # the prevalence of the positive class is the joint prevalence of the tp and fn cells
        return tp + fn
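
# Illustrative usage sketch (an assumption, not part of the original module): any quapy quantifier
# able to handle the 4 counter classes can be plugged in; PACC is used here as an arbitrary choice.
#
#   from sklearn.linear_model import LogisticRegression
#   from quapy.method.aggregative import PACC
#
#   cm_predictor = QuantificationCMPredictor(
#       classifier=LogisticRegression(),
#       quantifier=PACC(LogisticRegression()),
#       strategy='kfcv', k=5)
#   cm_predictor.fit(train)                             # train: a LabelledCollection
#   conf_matrix = cm_predictor.predict(test.instances)  # test: a LabelledCollection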