import pickle
import os
from time import time
from collections import defaultdict

import numpy as np
from sklearn.linear_model import LogisticRegression

import quapy as qp
from KDEy.kdey_devel import KDEyMLauto, KDEyMLauto2, KDEyMLred
from quapy.method.aggregative import PACC, EMQ, KDEyML
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP
from pathlib import Path

# Random seed shared by the stratified train/validation split used for model
# selection and by the qp.util.temp_seed context wrapping each experiment in
# the __main__ block.
SEED = 1


def newLR():
    """Return a fresh, unfitted LogisticRegression classifier.

    ``max_iter`` is raised well above scikit-learn's default (100), presumably
    so the solver can converge across the whole explored ``C`` range.
    """
    classifier = LogisticRegression(max_iter=3000)
    return classifier


# Hyperparameter grid explored for the LogisticRegression classifier:
# C takes the 7 log-spaced values 1e-3, 1e-2, ..., 1e3, and class_weight is
# tried both unweighted and 'balanced'.
logreg_grid = dict(
    C=np.logspace(-3, 3, 7),
    class_weight=[None, 'balanced'],
)


def wrap_hyper(classifier_hyper_grid: dict):
    """Prefix every key of a classifier grid with ``'classifier__'``.

    The prefixed names address the parameters of the quantifier's inner
    classifier. Values are passed through unchanged (same objects).
    """
    prefixed = {}
    for param_name, candidates in classifier_hyper_grid.items():
        prefixed[f'classifier__{param_name}'] = candidates
    return prefixed


# Inductive methods under evaluation, as (name, quantifier, param_grid) tuples.
# param_grid is handed to GridSearchQ in the __main__ block; classifier
# hyperparameters are prefixed with 'classifier__' via wrap_hyper.
METHODS = [
    ('PACC',
     PACC(newLR()),
     wrap_hyper(logreg_grid)),
    ('EMQ',
     EMQ(newLR()),
     wrap_hyper(logreg_grid)),
    # also searches 20 log-spaced bandwidths between 1e-4 and 0.2
    ('KDEy-ML',
     KDEyML(newLR()),
     dict(wrap_hyper(logreg_grid), bandwidth=np.logspace(-4, np.log10(0.2), 20))),
    # ('KDEy-MLred',  KDEyMLred(newLR()), {**wrap_hyper(logreg_grid), **{'bandwidth': np.logspace(-4, np.log10(0.2), 20)}}),
    # bandwidth given by a named rule (presumably resolved inside KDEyML);
    # only the classifier is tuned
    ('KDEy-ML-scott',
     KDEyML(newLR(), bandwidth='scott'),
     wrap_hyper(logreg_grid)),
    ('KDEy-ML-silver',
     KDEyML(newLR(), bandwidth='silverman'),
     wrap_hyper(logreg_grid)),
    # bandwidth='auto' with different `target` criteria (semantics defined in
    # KDEy.kdey_devel.KDEyMLauto2 -- not visible here)
    ('KDEy-ML-autoLike',
     KDEyMLauto2(newLR(), bandwidth='auto', target='likelihood'),
     wrap_hyper(logreg_grid)),
    ('KDEy-ML-autoLike+',
     KDEyMLauto2(newLR(), bandwidth='auto', target='likelihood+'),
     wrap_hyper(logreg_grid)),
    ('KDEy-ML-autoAE',
     KDEyMLauto2(newLR(), bandwidth='auto', target='mae'),
     wrap_hyper(logreg_grid)),
    ('KDEy-ML-autoRAE',
     KDEyMLauto2(newLR(), bandwidth='auto', target='mrae'),
     wrap_hyper(logreg_grid)),
]


"""
TKDEyML optimized first the bandwidth (init 0.05), then the prevalence (init uniform)
TKDEyML2 optimized first the prevalence (init uniform), then the bandwidth (init 0.05)
TKDEyML3 optimized first the prevalence (init uniform), then the bandwidth (init 0.1)
TKDEyML4 is like TKDEyML2, but with at most 5 iterations per optimization
"""
# Transductive variants, in the same (name, quantifier, param_grid) format as
# METHODS (param_grid is None for these). The __main__ block does not run model
# selection for them: it only calls fit() on the training set before evaluating.
# Every entry is currently commented out, so this list is empty and only
# METHODS is actually run.
TRANSDUCTIVE_METHODS = [
    #('TKDEy-ML',  KDEyMLauto(newLR()), None),
    # ('TKDEy-MLboth',  KDEyMLauto(newLR(), optim='both'), None),
    # ('TKDEy-MLbothfine',  KDEyMLauto(newLR(), optim='both_fine'), None),
    # ('TKDEy-ML2',  KDEyMLauto(newLR(), optim='two_steps'), None),
    # ('TKDEy-MLike',  KDEyMLauto(newLR(), optim='max_likelihood'), None),
    # ('TKDEy-MLike2',  KDEyMLauto(newLR(), optim='max_likelihood2'), None),
    #('TKDEy-ML3',  KDEyMLauto(newLR()), None),
    #('TKDEy-ML4',  KDEyMLauto(newLR()), None),
]

def show_results(result_path):
    """Print one Dataset x Method pivot table per recorded metric.

    Reads the tab-separated file ``result_path + '.csv'``. For each of MAE,
    MRAE, KLD, TR-TIME and TE-TIME in that order, it prints a pandas
    ``pivot_table`` (default aggregation: mean) with margins enabled.
    """
    import pandas as pd

    results = pd.read_csv(result_path + '.csv', sep='\t')

    # print whole tables: no column/row truncation and a wide line
    display_options = (
        ('display.max_columns', None),
        ('display.max_rows', None),
        ('display.width', 1000),  # maximum line width
    )
    for option_name, option_value in display_options:
        pd.set_option(option_name, option_value)

    for metric in ('MAE', 'MRAE', 'KLD', 'TR-TIME', 'TE-TIME'):
        table = results.pivot_table(index='Dataset', columns="Method", values=[metric], margins=True)
        print(table)


def _fit_quantifier(quantifier, param_grid, train, n_bags_val, transductive=False):
    """Fit ``quantifier`` on ``train`` and return ``(model, train_time)``.

    - Transductive methods, and methods whose grid is empty or None, are fit
      directly on the full training set. ``model`` is then ``quantifier`` itself.
    - Otherwise ``train`` is split (stratified, seeded with SEED) into a
      training and a validation part. GridSearchQ runs with a UPP protocol of
      ``n_bags_val`` samples over the validation part, and ``model`` is its
      best model. If model selection raises, the error is printed and the
      default configuration is fit on the full training set, matching the
      no-grid path.

    The caller's ``quantifier`` is never replaced by a tuned model. Each
    dataset therefore starts from the same configured method, not from the
    previous dataset's best model.
    """
    if transductive or not param_grid:
        # for transductive methods fit() does no real training: it only
        # projects the covariates into posteriors
        t_init = time()
        quantifier.fit(train)
        return quantifier, time() - t_init

    # model selection on a held-out stratified validation split
    tr, val = train.split_stratified(random_state=SEED)
    protocol = UPP(val, repeats=n_bags_val)
    modsel = GridSearchQ(
        quantifier, param_grid, protocol, refit=True, n_jobs=-1, verbose=1, error='mae'
    )
    t_init = time()
    try:
        modsel.fit(tr)
        print(f'best params {modsel.best_params_}')
        print(f'best score {modsel.best_score_}')
        model = modsel.best_model()
    except Exception as e:
        # Best-effort fallback that keeps the benchmark running. Catch
        # Exception (not a bare except) so Ctrl-C still interrupts, and report
        # what failed instead of hiding it.
        print(f'something went wrong ({e!r})... trying to fit the default model')
        quantifier.fit(train)
        model = quantifier
    return model, time() - t_init


if __name__ == '__main__':

    qp.environ['SAMPLE_SIZE'] = 500
    qp.environ['N_JOBS'] = -1
    n_bags_val = 25
    n_bags_test = 100
    result_dir = 'results_quantification/ucimulti'

    os.makedirs(result_dir, exist_ok=True)

    # the global CSV is rewritten from scratch on every run; per-(method,
    # dataset) reports already on disk are reloaded instead of recomputed
    global_result_path = f'{result_dir}/allmethods'
    with open(global_result_path + '.csv', 'wt') as csv_file:
        csv_file.write('Method\tDataset\tMAE\tMRAE\tKLD\tTR-TIME\tTE-TIME\n')

    transductive_names = {name for (name, *_) in TRANSDUCTIVE_METHODS}

    for method_name, quantifier, param_grid in METHODS + TRANSDUCTIVE_METHODS:

        print('Init method', method_name)

        with open(global_result_path + '.csv', 'at') as csv_file:
            for dataset in qp.datasets.UCI_MULTICLASS_DATASETS:
                print('init', dataset)

                local_result_path = os.path.join(Path(global_result_path).parent, method_name + '_' + dataset + '.dataframe')

                if os.path.exists(local_result_path):
                    print(f'result file {local_result_path} already exist; skipping')
                    report = qp.util.load_report(local_result_path)

                else:
                    with qp.util.temp_seed(SEED):

                        data = qp.datasets.fetch_UCIMulticlassDataset(dataset, verbose=True)
                        train, test = data.train_test

                        # train (with model selection when a grid is given)
                        model, train_time = _fit_quantifier(
                            quantifier, param_grid, train, n_bags_val,
                            transductive=method_name in transductive_names,
                        )

                        # test
                        t_init = time()
                        protocol = UPP(test, repeats=n_bags_test)
                        report = qp.evaluation.evaluation_report(
                            model, protocol, error_metrics=['mae', 'mrae', 'kld'], verbose=True
                        )
                        test_time = time() - t_init
                        report['tr_time'] = train_time
                        report['te_time'] = test_time
                        report.to_csv(local_result_path)

                means = report.mean(numeric_only=True)
                csv_file.write(f'{method_name}\t{dataset}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{means["kld"]:.5f}\t{means["tr_time"]:.3f}\t{means["te_time"]:.3f}\n')
                csv_file.flush()

    show_results(global_result_path)