import numpy as np
import quapy as qp
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR

from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier
from quapy.method.aggregative import AggregativeSoftQuantifier


class LocalStackingQuantification(BaseQuantifier):
    """
    Quantifier that corrects the estimate of a surrogate quantifier by means of a locally-trained regressor.

    At fit time, the data is split into a training part (used to fit the surrogate quantifier) and a
    held-out validation part. At quantification time, the surrogate's prevalence estimate for the test
    instances is used to draw ``n_samples_gen`` samples from the validation data at that prevalence (each
    sample having the same size as the test set). The surrogate is applied to every sample. The
    ``n_samples_sel`` samples whose predicted prevalence is closest (according to ``comparison_measure``)
    to the test prediction are kept. A multi-output SVR is then trained to map the predicted prevalence
    of a sample to its true prevalence, and is finally applied to the test prediction.

    :param surrogate_quantifier: an :class:`AggregativeSoftQuantifier`, fitted inside :meth:`fit`
    :param n_samples_gen: number of samples drawn from the validation data in each call to :meth:`quantify`
    :param n_samples_sel: number of generated samples (the closest ones) used to train the regressor
    :param comparison_measure: name of a quapy error measure (resolved with ``qp.error.from_name``) used
        to compare the test prediction with the prediction on each generated sample
    :param random_state: integer seed controlling the train/validation split and the sample generation;
        None for non-deterministic behaviour
    """

    def __init__(self, surrogate_quantifier, n_samples_gen=200, n_samples_sel=50, comparison_measure='ae', random_state=None):
        # `AggregativeSoftQuantifier.__name__` (not `.__class__.__name__`, which would be the metaclass name)
        assert isinstance(surrogate_quantifier, AggregativeSoftQuantifier), \
            f'the surrogate quantifier must be of type {AggregativeSoftQuantifier.__name__}'
        self.surrogate_quantifier = surrogate_quantifier
        self.n_samples_gen = n_samples_gen
        self.n_samples_sel = n_samples_sel
        self.comparison_measure = qp.error.from_name(comparison_measure)
        self.random_state = random_state

    def fit(self, data: LabelledCollection):
        """
        Split `data` (stratified) into train/validation parts, fit the surrogate quantifier on the training
        part and keep the validation part for generating samples in :meth:`quantify`.

        :param data: the labelled training collection
        :return: self
        """
        # seed the split too, so that a fixed random_state yields a reproducible model
        train, val = data.split_stratified(random_state=self.random_state)
        self.surrogate_quantifier.fit(train)
        self.val_data = val
        return self

    def normalize(self, out_simplex: np.ndarray):
        """
        Map the (possibly invalid) regressor output onto the probability simplex.

        Negative entries (the SVR output is unbounded) are clipped to 0 before dividing by the sum; if
        nothing positive remains, the uniform prevalence vector is returned instead of dividing by zero.

        :param out_simplex: vector of raw prevalence values
        :return: a non-negative vector summing to 1
        """
        clipped = np.clip(np.asarray(out_simplex, dtype=float), 0, None)
        total = clipped.sum()
        if total <= 0:
            return np.full(clipped.shape, 1. / clipped.size)
        return clipped / total

    def _sampling_seeds(self):
        """Return one seed per sample to generate (all None when no random_state was given)."""
        if self.random_state is None:
            return [None] * self.n_samples_gen
        # reusing the same fixed seed in every call to `sampling` would produce n_samples_gen identical
        # samples; instead derive a distinct, but reproducible, seed for each sample
        rng = np.random.default_rng(self.random_state)
        return rng.integers(0, np.iinfo(np.int32).max, size=self.n_samples_gen).tolist()

    def _generate_samples(self, pred_prevs, test_size):
        """
        Draw `n_samples_gen` samples of size `test_size` from the validation data at prevalence `pred_prevs`.

        :return: a pair of arrays with one row per generated sample: the surrogate's predicted prevalence
            for the sample, and the sample's true prevalence
        """
        gen_pred_prevs, gen_true_prevs = [], []
        for seed in self._sampling_seeds():
            sample = self.val_data.sampling(test_size, *pred_prevs, random_state=seed)
            gen_pred_prevs.append(self.surrogate_quantifier.quantify(sample.X))
            # only the prevalence of each sample is needed later, so the samples themselves are not kept
            gen_true_prevs.append(sample.prevalence())
        return np.asarray(gen_pred_prevs), np.asarray(gen_true_prevs)

    def quantify(self, instances: np.ndarray):
        """
        Estimate the class prevalence of `instances`.

        :param instances: the test instances (the number of rows, ``instances.shape[0]``, is used as the
            size of the generated samples)
        :return: the corrected prevalence vector (non-negative, summing to 1)
        """
        assert hasattr(self, 'val_data'), 'quantify called before fit'
        pred_prevs = self.surrogate_quantifier.quantify(instances)
        test_size = instances.shape[0]

        gen_pred_prevs, gen_true_prevs = self._generate_samples(pred_prevs, test_size)

        # keep the samples whose predicted prevalence is most similar to the test prediction
        distances = [self.comparison_measure(pred_prevs, sample_pred) for sample_pred in gen_pred_prevs]
        closest = np.argsort(distances)[:self.n_samples_sel]

        # learn the mapping "predicted prevalence -> true prevalence" locally, around the test prediction
        reg = MultiOutputRegressor(SVR())
        reg.fit(gen_pred_prevs[closest], gen_true_prevs[closest])

        corrected_prev = reg.predict([pred_prevs])[0]
        return self.normalize(corrected_prev)



class LocalStackingQuantification2(BaseQuantifier):
    """
    Variant of the local-stacking quantifier without the sample-selection step.

    Instead of selecting, among the generated samples, those whose predicted prevalence is closest to the
    prevalence predicted on the test set, this variant trains the correcting regressor on all the
    ``n_samples_gen`` samples drawn from the validation data at the prevalence predicted on the test set.

    At fit time, the data is split into a training part (used to fit the surrogate quantifier) and a
    held-out validation part. At quantification time, a multi-output SVR is trained to map the surrogate's
    predicted prevalence of each generated sample to its true prevalence, and is applied to the test
    prediction.

    :param surrogate_quantifier: an :class:`AggregativeSoftQuantifier`, fitted inside :meth:`fit`
    :param n_samples_gen: number of samples drawn from the validation data in each call to :meth:`quantify`
    :param n_samples_sel: accepted for interface parity; not used by this class
    :param comparison_measure: name of a quapy error measure (resolved with ``qp.error.from_name``);
        accepted for interface parity, not used by this class
    :param random_state: integer seed controlling the train/validation split and the sample generation;
        None for non-deterministic behaviour
    """

    def __init__(self, surrogate_quantifier, n_samples_gen=200, n_samples_sel=50, comparison_measure='ae', random_state=None):
        # `AggregativeSoftQuantifier.__name__` (not `.__class__.__name__`, which would be the metaclass name)
        assert isinstance(surrogate_quantifier, AggregativeSoftQuantifier), \
            f'the surrogate quantifier must be of type {AggregativeSoftQuantifier.__name__}'
        self.surrogate_quantifier = surrogate_quantifier
        self.n_samples_gen = n_samples_gen
        self.n_samples_sel = n_samples_sel
        self.comparison_measure = qp.error.from_name(comparison_measure)
        self.random_state = random_state

    def fit(self, data: LabelledCollection):
        """
        Split `data` (stratified) into train/validation parts, fit the surrogate quantifier on the training
        part and keep the validation part for generating samples in :meth:`quantify`.

        :param data: the labelled training collection
        :return: self
        """
        # seed the split too, so that a fixed random_state yields a reproducible model
        train, val = data.split_stratified(random_state=self.random_state)
        self.surrogate_quantifier.fit(train)
        self.val_data = val
        return self

    def normalize(self, out_simplex: np.ndarray):
        """
        Map the (possibly invalid) regressor output onto the probability simplex.

        Negative entries (the SVR output is unbounded) are clipped to 0 before dividing by the sum; if
        nothing positive remains, the uniform prevalence vector is returned instead of dividing by zero.

        :param out_simplex: vector of raw prevalence values
        :return: a non-negative vector summing to 1
        """
        clipped = np.clip(np.asarray(out_simplex, dtype=float), 0, None)
        total = clipped.sum()
        if total <= 0:
            return np.full(clipped.shape, 1. / clipped.size)
        return clipped / total

    def _sampling_seeds(self):
        """Return one seed per sample to generate (all None when no random_state was given)."""
        if self.random_state is None:
            return [None] * self.n_samples_gen
        # reusing the same fixed seed in every call to `sampling` would produce n_samples_gen identical
        # samples; instead derive a distinct, but reproducible, seed for each sample
        rng = np.random.default_rng(self.random_state)
        return rng.integers(0, np.iinfo(np.int32).max, size=self.n_samples_gen).tolist()

    def quantify(self, instances: np.ndarray):
        """
        Estimate the class prevalence of `instances`.

        :param instances: the test instances (the number of rows, ``instances.shape[0]``, is used as the
            size of the generated samples)
        :return: the corrected prevalence vector (non-negative, summing to 1)
        """
        assert hasattr(self, 'val_data'), 'quantify called before fit'
        pred_prevs = self.surrogate_quantifier.quantify(instances)
        test_size = instances.shape[0]

        gen_pred_prevs, gen_true_prevs = [], []
        for seed in self._sampling_seeds():
            sample = self.val_data.sampling(test_size, *pred_prevs, random_state=seed)
            gen_pred_prevs.append(self.surrogate_quantifier.quantify(sample.X))
            # only the prevalence of each sample is needed, so the samples themselves are not kept
            gen_true_prevs.append(sample.prevalence())

        # learn the mapping "predicted prevalence -> true prevalence" from all generated samples
        reg = MultiOutputRegressor(SVR())
        reg.fit(np.asarray(gen_pred_prevs), np.asarray(gen_true_prevs))

        corrected_prev = reg.predict([pred_prevs])[0]
        return self.normalize(corrected_prev)