import os
import pickle
import shutil
import numpy as np
from sklearn.linear_model import LogisticRegression
from os.path import join
import quapy as qp
from quapy.method.aggregative import KDEyML
from quapy.protocol import UPP
from kdey_devel import KDEyMLauto
from utils import *
from constants import *
import quapy.functional as F


# Experiment configuration. DEBUG is expected to come from one of the star
# imports above; it switches the whole script to a cheaper setup.
if DEBUG:
    qp.environ["SAMPLE_SIZE"] = 100
    val_repeats = 100
    test_repeats = 100
    # debug runs also pin the default classifier in the quapy environment
    qp.environ["DEFAULT_CLS"] = LogisticRegression()
    # coarse grid: 5 evenly spaced candidate bandwidths in [0.01, 0.20]
    bandwidth_range = np.linspace(0.01, 0.20, 5)
else:
    qp.environ["SAMPLE_SIZE"] = 500
    val_repeats = 500
    test_repeats = 500
    # full grid: 20 evenly spaced candidate bandwidths in [0.01, 0.20]
    bandwidth_range = np.linspace(0.01, 0.20, 20)

# module-level containers (not populated anywhere in the visible code)
test_results = {}
val_choice = {}


def datasets():
    """Yield, one at a time, the datasets used in this experiment.

    The names are the first four entries of
    ``qp.datasets.UCI_MULTICLASS_DATASETS``; each one is fetched lazily with
    ``qp.datasets.fetch_UCIMulticlassDataset``. When ``DEBUG`` is set, every
    fetched dataset is passed through ``reduce(random_state=0)`` before being
    yielded (presumably to shrink it for quick runs).
    """
    selected = qp.datasets.UCI_MULTICLASS_DATASETS[:4]
    if DEBUG:
        # debug knob: currently the same cap as the regular run
        selected = selected[:4]
    for name in selected:
        fetched = qp.datasets.fetch_UCIMulticlassDataset(name)
        yield fetched.reduce(random_state=0) if DEBUG else fetched


@measuretime
def predict_b_modsel(dataset):
    """Choose the KDEyML bandwidth by grid search on a validation split.

    The training set of ``dataset`` is split in a stratified way
    (``random_state=0``). A ``GridSearchQ`` over ``bandwidth_range`` is fitted
    on the first part, using a ``UPP`` protocol with ``val_repeats``
    repetitions built on the second part. ``refit=False`` is passed, and only
    the selected bandwidth is returned.

    Returns:
        The selected bandwidth as a plain ``float``. Callers in this file
        unpack the decorated call as a ``(value, time)`` pair, so
        ``measuretime`` presumably adds the elapsed time.
    """
    fit_part, held_out = dataset.training.split_stratified(random_state=0)
    grid_search = qp.model_selection.GridSearchQ(
        model=KDEyML(random_state=0),
        param_grid={'bandwidth': bandwidth_range},
        protocol=UPP(held_out, repeats=val_repeats),
        refit=False,
        n_jobs=-1,
        verbose=True,
    )
    fitted = grid_search.fit(fit_part)
    return float(fitted.best_params_['bandwidth'])

@measuretime
def predict_b_kdeymlauto(dataset):
    """Choose the bandwidth with ``KDEyMLauto.chose_bandwidth``.

    Unlike the grid-search variant, no validation split is made here: the
    method receives the full training set together with the test covariates
    (``test.X``). The true test prevalence is printed for reference only.

    Returns:
        The first element returned by ``chose_bandwidth``, cast to ``float``.
        Callers in this file unpack the decorated call as a ``(value, time)``
        pair, so ``measuretime`` presumably adds the elapsed time.
    """
    train, test = dataset.train_test
    auto_kdey = KDEyMLauto(random_state=0)
    true_prev = F.strprev(test.prevalence())
    print(f'true-prevalence: {true_prev}')
    # second element of the returned pair is not needed here
    bandwidth, _ = auto_kdey.chose_bandwidth(train, test.X)
    return float(bandwidth)


def in_test_search(dataset, n_jobs=-1):
    """Score every candidate bandwidth directly on the test set.

    For each value in the module-level ``bandwidth_range``, a ``KDEyML`` is
    fitted on the training set and evaluated with ``qp.evaluation.evaluate``
    (``error_metric='mae'``) under a ``UPP`` protocol built on the test set
    with ``test_repeats`` repetitions. The per-bandwidth jobs are dispatched
    through ``qp.util.parallel``.

    Args:
        dataset: object exposing ``train_test`` and ``name``.
        n_jobs: forwarded to ``qp.util.parallel``.

    Returns:
        A pair ``(scores, bandwidth_range)``: whatever ``qp.util.parallel``
        returns for the per-bandwidth jobs, and the grid that was searched.
    """
    train, test = dataset.train_test

    print(f"generating true tests scores using KDEy in {dataset.name}")

    def evaluate_bandwidth(bandwidth):
        # NOTE: the parameter must stay named `bandwidth`; the f-string below
        # prints the variable name via the `=` specifier.
        quantifier = KDEyML(bandwidth=bandwidth, random_state=0)
        quantifier.fit(train)
        mae = qp.evaluation.evaluate(
            quantifier,
            protocol=UPP(test, repeats=test_repeats),
            error_metric='mae',
            verbose=True,
        )
        print(f'{bandwidth=}: {mae:.5f}')
        return float(mae)

    scores = qp.util.parallel(evaluate_bandwidth, bandwidth_range, n_jobs=n_jobs)
    return scores, bandwidth_range






# Driver: for every dataset, compute (or load from cache) the test-set error of
# each candidate bandwidth and the bandwidth picked by each selection strategy,
# then hand everything to the plotting / table helpers.
for dataset in datasets():
    print('NAME', dataset.name)
    print(len(dataset.training))
    print(len(dataset.test))

    # Debug runs use their own results tree, wiped on every run so that no
    # cached pickle is reused. The root is chosen explicitly rather than via
    # str.replace, which would also rewrite a 'results' substring occurring
    # inside the dataset name.
    results_root = 'results_debug' if DEBUG else 'results'
    result_path = f'./{results_root}/{dataset.name}/'
    if DEBUG and os.path.exists(result_path):
        shutil.rmtree(result_path)

    # The grid stored next to the scores is bound to a LOCAL name. Unpacking
    # into `bandwidth_range` would rebind the module-level grid read by
    # predict_b_modsel / in_test_search, so a pickle produced with a different
    # grid would silently change the search space for this and all later
    # datasets.
    dataset_results, tested_bandwidths = qp.util.pickled_resource(
        join(result_path, 'test.pkl'), in_test_search, dataset
    )

    # (method name, chosen bandwidth, elapsed time) for each selection strategy
    triplet_list_results = []
    modsel_choice, modsel_time = qp.util.pickled_resource(
        join(result_path, 'modsel.pkl'), predict_b_modsel, dataset
    )
    triplet_list_results.append(('modsel', modsel_choice, modsel_time,))
    auto_choice, auto_time = qp.util.pickled_resource(
        join(result_path, 'auto.pkl'), predict_b_kdeymlauto, dataset
    )
    triplet_list_results.append(('auto', auto_choice, auto_time,))

    print(f'Dataset = {dataset.name}')
    print(modsel_choice)
    print(dataset_results)

    plot_bandwidth(dataset.name, dataset_results, tested_bandwidths, triplet_list_results)
    error_table(dataset.name, dataset_results, tested_bandwidths, triplet_list_results)
    # time_table(dataset.name, dataset_results, tested_bandwidths, triplet_list_results)