import time
from functools import wraps
import os
from os.path import join
from result_table.src.table import Table
import numpy as np
from constants import *

def measuretime(func):
    """Decorate ``func`` so each call also returns how long it took, in seconds.

    If ``func`` returns a tuple, the elapsed time is appended as a new last
    element, giving ``(*result, elapsed)``. Any other return value is wrapped
    as ``(result, elapsed)``.

    The duration comes from ``time.perf_counter()``. That clock is monotonic
    and high-resolution, so unlike ``time.time()`` it is unaffected by
    system clock adjustments.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        time_it_took = time.perf_counter() - start_time
        # Flatten tuple results so callers can unpack `a, b, t = f()`.
        if isinstance(result, tuple):
            return (*result, time_it_took)
        return result, time_it_took
    return wrapper


def plot_bandwidth(dataset_name, test_results, bandwidths, triplet_list_results):
    """Plot test error over the bandwidth grid and mark each method's chosen bandwidth.

    The figure is saved as ``<plotdir>/<dataset_name>.png``. ``plotdir`` is
    ``./plots``, or ``./plots_debug`` when the module-level ``DEBUG`` flag is
    truthy; the directory is created if missing.

    Args:
        dataset_name: Used as the plot title and as the output file stem.
        test_results: Y values (axis labelled 'MAE'), aligned with ``bandwidths``.
        bandwidths: X values of the grid.
        triplet_list_results: Sequence of ``(method_name, method_choice, method_time)``.
            Each ``method_choice`` is drawn as a dashed vertical line labelled
            ``method_name``; ``method_time`` is not used here.
    """
    import matplotlib.pyplot as plt

    print("PLOT", dataset_name)
    print(dataset_name)

    fig = plt.figure(figsize=(8, 6))
    try:
        # Test-error curve across the bandwidth grid.
        plt.plot(bandwidths, test_results, marker='o', color='k')

        # Colours spread over the tab10 palette (they repeat beyond 10 methods).
        colors = plt.cm.tab10(np.linspace(0, 1, len(triplet_list_results)))
        for color, (method_name, method_choice, _) in zip(colors, triplet_list_results):
            plt.axvline(x=method_choice, linestyle='--', label=method_name, color=color)

        # Axis labels and title.
        plt.xlabel('Bandwidth')
        plt.ylabel('MAE')
        plt.title(dataset_name)

        # Legend outside the axes on the right; skipped when there are no
        # labelled method lines, which would otherwise make matplotlib warn.
        if len(triplet_list_results) > 0:
            plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        plt.grid(True)

        plotdir = './plots_debug' if DEBUG else './plots'
        os.makedirs(plotdir, exist_ok=True)
        plt.tight_layout()
        plt.savefig(join(plotdir, f'{dataset_name}.png'))
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

def error_table(dataset_name, test_results, bandwidth_range, triplet_list_results,
                *, missing_score=1):
    """Tabulate how close each method's bandwidth choice is to the grid optimum.

    The best bandwidth is the grid value at ``argmin(test_results)``, and its
    score is ``min(test_results)``. Each ``(method_name, method_choice, took)``
    triplet contributes these table entries:

    - 'Choice': the chosen bandwidth.
    - 'ScoreChoice': the test score at that bandwidth.
    - 'Best' and 'ScoreBest': the optimum bandwidth and its score.
    - 'AE': the absolute difference between the best score and the method's score.
    - 'Time': ``took``.

    The table is written as a transposed PDF to ``<outpath>/<dataset_name>.pdf``.
    ``outpath`` is ``./tables``, or ``./tables_debug`` when the module-level
    ``DEBUG`` flag is truthy; the directory is created if missing.

    Args:
        dataset_name: Table name and output file stem.
        test_results: Scores aligned with ``bandwidth_range``.
        bandwidth_range: Grid of bandwidths (list or array).
        triplet_list_results: Iterable of ``(method_name, method_choice, took)``.
        missing_score: Score assigned when ``method_choice`` is not exactly one
            of the grid values. Defaults to 1, the previously hard-coded value.
    """
    # Work on an array so the equality test below is elementwise even when the
    # caller passes a plain list. With a list, `list == scalar` is just False,
    # and the original `np.where(...)[0][0]` raised IndexError.
    grid = np.asarray(bandwidth_range)
    best_bandwidth = grid[np.argmin(test_results)]
    best_score = np.min(test_results)
    # NOTE(review): only a header is printed; no per-method rows follow.
    print('Method\tChoice\tAE\tTime')
    table = Table(name=dataset_name)
    table.format.with_mean = False
    table.format.with_rank_mean = False
    table.format.show_std = False
    for method_name, method_choice, took in triplet_list_results:
        # Exact match: the chosen bandwidth is looked up on the grid itself.
        matches = np.flatnonzero(grid == method_choice)
        if matches.size > 0:
            method_score = test_results[matches[0]]
        else:
            method_score = missing_score
        error = np.abs(best_score - method_score)
        table.add(benchmark='Choice', method=method_name, v=method_choice)
        table.add(benchmark='ScoreChoice', method=method_name, v=method_score)
        table.add(benchmark='Best', method=method_name, v=best_bandwidth)
        table.add(benchmark='ScoreBest', method=method_name, v=best_score)
        table.add(benchmark='AE', method=method_name, v=error)
        table.add(benchmark='Time', method=method_name, v=took)
    outpath = './tables_debug' if DEBUG else './tables'
    # Match plot_bandwidth: make sure the output directory exists before writing.
    os.makedirs(outpath, exist_ok=True)
    table.latexPDF(join(outpath, dataset_name + '.pdf'), transpose=True)