import pickle
import os
from time import time
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from sklearn.linear_model import LogisticRegression

import quapy as qp
from KDEy.kdey_devel import KDEyMLauto, optim_minimize
from method._kdey import KDEBase
from quapy.method.aggregative import PACC, EMQ, KDEyML
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP
from pathlib import Path
from quapy import functional as F
import matplotlib.pyplot as plt


# Global seed constant.
# NOTE(review): not referenced anywhere in the visible part of this file — confirm whether it is still needed.
SEED = 1


def newLR():
    """Return a fresh, unfitted scikit-learn logistic-regression classifier.

    The solver's iteration budget is set explicitly to 1000.
    """
    iteration_budget = 1000
    classifier = LogisticRegression(max_iter=iteration_budget)
    return classifier


def plot(xaxis, metrics_measurements, metrics_names, suffix, dataset_name=None):
    """Plot one curve per metric against the bandwidth and save it as a PNG.

    Args:
        xaxis: bandwidth values (x coordinates); drawn on a log scale.
        metrics_measurements: one entry per metric. A 2-D entry is treated as
            (repeats, len(xaxis)): its column-wise mean is drawn together with a
            +/- 1 std band. A 1-D entry is drawn as-is, without a band.
        metrics_names: legend label for each entry of ``metrics_measurements``.
        suffix: appended to the output file name.
        dataset_name: plot title and file-name prefix. Defaults to the
            module-level global ``dataset`` set by the script loop.

    The figure is written to ``./plots/likelihood/{dataset_name}-fig{suffix}.png``
    and is always closed, even if saving fails.
    """
    if dataset_name is None:
        dataset_name = dataset

    colors = ['b', 'g', 'r', 'c', 'purple']

    def get_mean_std(measurement):
        """Return (curve, std-or-None) for one metric's measurements."""
        measurement = np.asarray(measurement)
        if measurement.ndim == 1:
            # Already a single curve: averaging over axis 0 would collapse it to a scalar.
            return measurement, None
        measurement_mean = np.mean(measurement, axis=0)
        measurement_std = np.std(measurement, axis=0) if measurement.ndim == 2 else None
        return measurement_mean, measurement_std

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for idx, (measurement, name) in enumerate(zip(metrics_measurements, metrics_names)):
            mean_error, std_error = get_mean_std(measurement)
            color = colors[idx % len(colors)]
            ax.plot(xaxis, mean_error, label=name, marker='o', color=color)
            if std_error is not None:
                ax.fill_between(xaxis, mean_error - std_error, mean_error + std_error, color=color, alpha=0.2)

        ax.set_xscale('log')

        # Axis labels, grid and legend.
        ax.set_xlabel('Bandwidth')
        ax.set_ylabel('Normalized value')
        ax.grid(True)
        ax.legend(loc='upper left')
        ax.set_title(dataset_name)

        os.makedirs('./plots/likelihood/', exist_ok=True)
        fig.savefig(f'./plots/likelihood/{dataset_name}-fig{suffix}.png')
    finally:
        # This is called once per dataset; close explicitly so figures never accumulate.
        plt.close(fig)


def generate_data(from_train=False):
    """Sweep the KDEy bandwidth and record quantification errors and the NLL.

    Reads the module-level globals ``dataset`` (UCI dataset name), ``i`` (only
    printed), ``SAMPLE_SIZE`` and ``epsilon``. The script loop at the bottom of
    the file sets these before calling this function.

    A ``KDEyMLauto`` model wrapping a logistic-regression classifier is fitted on
    the training split. ``repeats`` UPP samples are then drawn from the
    evaluation split. For each of 50 log-spaced bandwidths in [1e-5, 0.2], the
    class prevalence of every sample is estimated with ``optim_minimize`` by
    minimising the KDE-mixture negative log-likelihood of the sample's
    posteriors.

    Args:
        from_train: if True, the training set is split 50/50 (stratified) and
            its second half is used as evaluation data instead of the real
            test set.

    Returns:
        Tuple ``(xaxis, AE, RAE, MSE, KLD, NLL)``. ``xaxis`` is the list of
        bandwidths. Every other entry is a list with one inner list per sample,
        holding one value per bandwidth in ``xaxis`` order.
    """
    data = qp.datasets.fetch_UCIMulticlassDataset(dataset)
    n_classes = data.n_classes
    print(f'{i=}')
    print(f'{dataset=}')
    print(f'{n_classes=}')
    print(len(data.training))
    print(len(data.test))

    train, test = data.train_test
    if from_train:
        train, test = train.split_stratified(0.5)
    train_prev = train.prevalence()
    test_prev = test.prevalence()

    print(f'train-prev = {F.strprev(train_prev)}')
    print(f'test-prev = {F.strprev(test_prev)}')

    repeats = 10
    prot = UPP(test, sample_size=SAMPLE_SIZE, repeats=repeats)
    kde = KDEyMLauto(newLR())
    kde.fit(train)
    tr_posteriors, tr_y = kde.classif_predictions.Xy
    classes = train.classes_

    # The UPP samples are few and small, so classify them once up front. This lets
    # the class-conditional KDEs (which depend only on the training posteriors and
    # the bandwidth) be fitted once per bandwidth and shared by every sample,
    # instead of being refitted for every (sample, bandwidth) pair.
    samples = [(kde.classifier.predict_proba(sample), prev) for sample, prev in prot()]

    bandwidths = np.logspace(-5, np.log10(0.2), 50)
    xaxis = list(bandwidths)

    # One inner list per sample; each is filled in bandwidth order below.
    AE_error = [[] for _ in samples]
    RAE_error = [[] for _ in samples]
    MSE_error = [[] for _ in samples]
    KLD_error = [[] for _ in samples]
    LIKE_value = [[] for _ in samples]

    for bandwidth in tqdm(bandwidths):
        mix_densities = kde.get_mixture_components(tr_posteriors, tr_y, classes, bandwidth)

        for sample_no, (te_posteriors, prev) in enumerate(samples):
            test_densities = [kde.pdf(kde_i, te_posteriors) for kde_i in mix_densities]

            def neg_loglikelihood_prev(candidate_prev):
                test_mixture_likelihood = sum(p_i * dens_i for p_i, dens_i in zip(candidate_prev, test_densities))
                # epsilon keeps log() finite where the mixture density is 0
                test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
                return -np.sum(test_loglikelihood)

            init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))
            pred_prev, likelihood = optim_minimize(neg_loglikelihood_prev, init_prev, return_loss=True)

            AE_error[sample_no].append(qp.error.ae(prev, pred_prev))
            RAE_error[sample_no].append(qp.error.rae(prev, pred_prev))
            MSE_error[sample_no].append(qp.error.mse(prev, pred_prev))
            KLD_error[sample_no].append(qp.error.kld(prev, pred_prev))
            LIKE_value[sample_no].append(likelihood)

    return xaxis, AE_error, RAE_error, MSE_error, KLD_error, LIKE_value


def normalize_metric(Error_matrix):
    """Min-max scale all values of ``Error_matrix`` into [0, 1].

    The minimum and maximum are taken over the whole array, not per row or per
    column. This keeps curves from different repeats comparable.

    Args:
        Error_matrix: numeric array-like, e.g. a list of per-sample lists.

    Returns:
        np.ndarray with the same shape. If every value is equal (zero range),
        an array of zeros is returned instead of dividing by zero and
        producing NaN/inf.

    Raises:
        ValueError: if ``Error_matrix`` is empty (no min/max exists).
    """
    values = np.asarray(Error_matrix)
    min_val = np.min(values)
    value_range = np.max(values) - min_val
    if value_range == 0:
        return np.zeros_like(values, dtype=float)
    return (values - min_val) / value_range


# Size of every UPP sample; also registered in quapy's environment.
SAMPLE_SIZE = 150
qp.environ['SAMPLE_SIZE'] = SAMPLE_SIZE

# Which error curves to draw. The two NLL curves are always drawn.
show_ae = True
show_rae = True
show_mse = False
show_kld = True
# When True, every drawn curve is additionally min-max scaled to [0, 1].
normalize = True

# Added inside log() by generate_data to avoid log(0).
epsilon = 1e-10
DATASETS = qp.datasets.UCI_MULTICLASS_DATASETS
for i, dataset in enumerate(tqdm(DATASETS, desc='processing datasets', total=len(DATASETS))):
    # NOTE: `i` and `dataset` are module globals read by generate_data() and plot().

    # Results on the real test split (cached on disk as a pickle).
    (xaxis, AE_error_te, RAE_error_te, MSE_error_te, KLD_error_te,
     LIKE_value_te) = qp.util.pickled_resource(
        f'./plots/likelihood/pickles/{dataset}.pkl', generate_data, False
    )
    # Results on a held-out half of the training split (cached on disk as a pickle).
    (xaxis, AE_error_tr, RAE_error_tr, MSE_error_tr, KLD_error_tr,
     LIKE_value_tr) = qp.util.pickled_resource(
        f'./plots/likelihood/pickles/{dataset}_tr.pkl', generate_data, True
    )

    # (enabled?, test-split values, legend label), in plotting order.
    selectable_curves = (
        (show_ae, AE_error_te, 'AE'),
        (show_rae, RAE_error_te, 'RAE'),
        (show_kld, KLD_error_te, 'KLD'),
        (show_mse, MSE_error_te, 'MSE'),
    )
    measurements = [values for enabled, values, _ in selectable_curves if enabled]
    measurement_names = [label for enabled, _, label in selectable_curves if enabled]

    # NLL curves live on a different scale, so they are always min-max scaled.
    measurements += [normalize_metric(LIKE_value_te), normalize_metric(LIKE_value_tr)]
    measurement_names += ['NLL(te)', 'NLL(tr)']

    if normalize:
        measurements = list(map(normalize_metric, measurements))

    plot(xaxis, measurements, measurement_names, suffix='AVEtr')