import itertools
import json
import os
from collections import defaultdict
from glob import glob
from os import makedirs
from os.path import join
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_rcv1

import numpy as np
import quapy as qp
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from typing import Iterable, Tuple

from quapy.method.aggregative import EMQ, ACC
from models_multiclass import *  # provides the CAP methods used below (SebastianiCAP, PabloCAP, NaiveCAP, QuAcc*, ...)
from quapy.data import LabelledCollection
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS
from quapy.data.datasets import fetch_reviews


def gen_classifiers():
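    """Yields (name, classifier) pairs to be evaluated; alternative classifiers are left commented out."""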
    yield 'LR', LogisticRegression()
    #yield 'NB', GaussianNB()
    #yield 'SVM(rbf)', SVC()
    #yield 'SVM(linear)', LinearSVC()


def gen_multi_datasets(only_names=False) -> Iterable[Tuple[str, Tuple[LabelledCollection, LabelledCollection, LabelledCollection]]]:
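    """Yields (dataset_name, (train, val, test)) for every UCI multiclass dataset; with only_names=True
    the splits are not loaded and None is yielded in their place."""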
    for dataset_name in UCI_MULTICLASS_DATASETS:
        if only_names:
            yield dataset_name, None
        else:
            dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
            yield dataset_name, split(dataset)


def gen_bin_datasets(only_names=False) -> Iterable[Tuple[str, Tuple[LabelledCollection, LabelledCollection, LabelledCollection]]]:
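    """Yields binary datasets: IMDB reviews plus three RCV1 categories (CCAT, GCAT, MCAT) binarized
    one-vs-rest. Each dataset comes as (L, V, U), i.e., labelled training, validation, and test collections;
    with only_names=True only the names are generated."""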
    if only_names:
        for dataset_name in ['imdb', 'CCAT', 'GCAT', 'MCAT']:
            yield dataset_name, None
    else:
        train, U = fetch_reviews('imdb', tfidf=True, min_df=10, pickle=True).train_test
        L, V = train.split_stratified(0.5, random_state=0)
        yield 'imdb', (L, V, U)

        training = fetch_rcv1(subset='train')
        test = fetch_rcv1(subset='test')
        class_names = training.target_names.tolist()
        for cat in ['CCAT', 'GCAT', 'MCAT']:
            class_idx = class_names.index(cat)
            tr_labels = training.target[:,class_idx].toarray().flatten()
            te_labels = test.target[:,class_idx].toarray().flatten()
            tr = LabelledCollection(training.data, tr_labels)
            U = LabelledCollection(test.data, te_labels)
            L, V = tr.split_stratified(train_prop=0.5, random_state=0)
            yield cat, (L, V, U)


def gen_CAP(h, acc_fn) -> Iterable[Tuple[str, ClassifierAccuracyPrediction]]:
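    """Yields (name, method) pairs of quantification-based classifier accuracy predictors built on top of
    classifier h; the active variants rely on the EMQ (SLD) quantifier."""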
    #yield 'SebCAP', SebastianiCAP(h, acc_fn, ACC)
    yield 'SebCAP-SLD', SebastianiCAP(h, acc_fn, EMQ)
    #yield 'SebCAPweight', SebastianiCAP(h, acc_fn, ACC, alpha=0)
    #yield 'PabCAP', PabloCAP(h, acc_fn, ACC)
    yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')


def gen_CAP_cont_table(h) -> Iterable[Tuple[str, CAPContingencyTable]]:
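    """Yields (name, method) pairs of accuracy predictors that estimate the whole contingency table,
    from which any accuracy measure can then be derived."""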
    acc_fn = None
    yield 'Naive', NaiveCAP(h, acc_fn)
    yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
    yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
    yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
    #yield 'CT-PPSh-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()), reuse_h=True)
    #yield 'Equations-ACCh', NsquaredEquationsCAP(h, acc_fn, ACC, reuse_h=True)
    # yield 'Equations-ACC', NsquaredEquationsCAP(h, acc_fn, ACC)
    #yield 'Equations-SLD', NsquaredEquationsCAP(h, acc_fn, EMQ)


def get_method_names():
    mock_h = LogisticRegression()
    return [m for m, _ in gen_CAP(mock_h, None)] + [m for m, _ in gen_CAP_cont_table(mock_h)]


def gen_acc_measure():
    yield 'vanilla_accuracy', vanilla_acc_fn
    #yield 'macro-F1', macrof1


def split(data: LabelledCollection):
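    """Splits a LabelledCollection into training (~33%), validation (~33%), and test (~34%) collections
    via two consecutive stratified splits."""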
    train_val, test = data.split_stratified(train_prop=0.66, random_state=0)
    train, val = train_val.split_stratified(train_prop=0.5, random_state=0)
    return train, val, test


def fit_method(method, V):
    tinit = time()
    method.fit(V)
    t_train = time() - tinit
    return method, t_train


def predictionsCAP(method, test_prot):
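    """Applies a CAP method to every sample generated by the test protocol, returning the list of
    estimated accuracies and the average prediction time per sample."""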
    tinit = time()
    estim_accs = [method.predict(Ui.X) for Ui in test_prot()]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs, t_test_ave


def predictionsCAPcont_table(method, test_prot, gen_acc_measure):
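    """Applies a contingency-table-based CAP method to every sample of the test protocol and derives,
    for each accuracy measure, the list of estimated accuracy values; also returns the average prediction
    time per sample."""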
    estim_accs_dict = {}
    tinit = time()
    estim_tables = [method.predict_ct(Ui.X) for Ui in test_prot()]
    for acc_name, acc_fn in gen_acc_measure():
        estim_accs_dict[acc_name] = [acc_fn(cont_table) for cont_table in estim_tables]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs_dict, t_test_ave


def any_missing(basedir, cls_name, dataset_name, method_name):
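    """Returns True if the result file of any accuracy measure is missing for the given configuration."""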
    for acc_name, _ in gen_acc_measure():
        if not os.path.exists(getpath(basedir, cls_name, acc_name, dataset_name, method_name)):
            return True
    return False


def true_acc(h:BaseEstimator, acc_fn: callable, U: LabelledCollection):
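    """Computes the true value of an accuracy measure by applying acc_fn to the confusion matrix of
    classifier h on the labelled collection U."""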
    y_pred = h.predict(U.X)
    y_true = U.y
    conf_table = confusion_matrix(y_true, y_pred=y_pred, labels=U.classes_)
    return acc_fn(conf_table)


def vanilla_acc_fn(cont_table):
    return np.diag(cont_table).sum() / cont_table.sum()


def _f1_bin(tp, fp, fn):
    if tp + fp + fn == 0:
        return 1
    else:
        return (2 * tp) / (2 * tp + fp + fn)


def macrof1(cont_table):
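    """Macro-averaged F1 computed from a contingency table; in the binary case only the positive class
    (assumed to be at index 1) is considered."""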
    n = cont_table.shape[0]

    if n==2:
        tp = cont_table[1,1]
        fp = cont_table[0,1]
        fn = cont_table[1,0]
        return _f1_bin(tp, fp, fn)

    f1_per_class = []
    for i in range(n):
        tp = cont_table[i,i]
        fp = cont_table[:,i].sum() - tp
        fn = cont_table[i,:].sum() - tp
        f1_per_class.append(_f1_bin(tp, fp, fn))
    return np.mean(f1_per_class)


def microf1(cont_table):
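    """Micro-averaged F1 computed from a contingency table by pooling TP, FP, and FN over all classes;
    in the binary case only the positive class (index 1) is considered."""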
    n = cont_table.shape[0]

    if n == 2:
        tp = cont_table[1, 1]
        fp = cont_table[0, 1]
        fn = cont_table[1, 0]
        return _f1_bin(tp, fp, fn)

    tp, fp, fn = 0, 0, 0
    for i in range(n):
        tp_i = cont_table[i, i]
        tp += tp_i
        fp += cont_table[:, i].sum() - tp_i
        fn += cont_table[i, :].sum() - tp_i
    return _f1_bin(tp, fp, fn)


def cap_errors(true_acc, estim_acc):
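    """Absolute error between true and estimated accuracy values (the squared variant is left commented)."""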
    true_acc = np.asarray(true_acc)
    estim_acc = np.asarray(estim_acc)
    #return (true_acc - estim_acc)**2
    return np.abs(true_acc - estim_acc)


def plot_diagonal(cls_name, measure_name, results, base_dir='plots'):
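    """Scatters estimated vs. true accuracy for every method and saves the diagonal plot under
    <base_dir>/<measure_name>/diagonal_<cls_name>.png; each legend entry reports the method's mean
    absolute error."""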

    makedirs(base_dir, exist_ok=True)
    makedirs(join(base_dir, measure_name), exist_ok=True)

    # Create scatter plot
    plt.figure(figsize=(10, 10))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.plot([0, 1], [0, 1], color='black', linestyle='--')

    for method_name in results.keys():
        xs = results[method_name]['true_acc']
        ys = results[method_name]['estim_acc']
        err = cap_errors(xs, ys).mean()
        #pear_cor, _ = 0, 0  #pearsonr(xs, ys)
        plt.scatter(xs, ys, label=f'{method_name} {err:.3f}', alpha=0.6)

    plt.legend()

    # Add labels and title
    plt.xlabel(f'True {measure_name}')
    plt.ylabel(f'Estimated {measure_name}')

    # Display the plot
    # plt.show()
    plt.savefig(join(base_dir, measure_name, 'diagonal_'+cls_name+'.png'))


def getpath(basedir, cls_name, acc_name, dataset_name, method_name):
    return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json"


def open_results(basedir, cls_name, acc_name, dataset_name='*', method_name='*'):
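    """Collects all result files matching the given pattern (dataset_name and method_name accept glob
    wildcards or lists of names) and aggregates true and estimated accuracies per method."""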
    results = defaultdict(lambda : {'true_acc':[], 'estim_acc':[]})
    if isinstance(method_name, str):
        method_name = [method_name]
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
    for dataset_, method_ in itertools.product(dataset_name, method_name):
        path = getpath(basedir, cls_name, acc_name, dataset_, method_)
        for file in glob(path):
            #print(file)
            method = Path(file).name.replace('.json', '')
            with open(file, 'r') as f:
                result = json.load(f)
            results[method]['true_acc'].extend(result['true_acc'])
            results[method]['estim_acc'].extend(result['estim_acc'])
    return results


def save_json_file(path, data):
    os.makedirs(Path(path).parent, exist_ok=True)
    with open(path, 'w') as f:
        json.dump(data, f)


def save_json_result(path, true_accs, estim_accs, t_train, t_test):
    result = {
        't_train': t_train,
        't_test_ave': t_test,
        'true_acc': true_accs,
        'estim_acc': estim_accs
    }
    save_json_file(path, result)


def get_dataset_stats(path, test_prot, L, V):
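    """Saves a JSON summary of the experimental setting: class counts, split sizes, prevalence vectors,
    the prior-probability shift of each test sample w.r.t. the training prevalence, and protocol details."""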
    test_prevs = [Ui.prevalence() for Ui in test_prot()]
    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
    info = {
        'n_classes': L.n_classes,
        'n_train': len(L),
        'n_val': len(V),
        'train_prev': L.prevalence().tolist(),
        'val_prev': V.prevalence().tolist(),
        'test_prevs': [x.tolist() for x in test_prevs],
        'shifts': [x.tolist() for x in shifts],
        'sample_size': test_prot.sample_size,
        'num_samples': test_prot.total()
    }
    save_json_file(path, info)


def gen_tables(basedir, datasets):
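    """Generates a LaTeX table of CAP errors (one entry per dataset-method pair) for the first classifier
    and vanilla accuracy, compiles it with pdflatex under ./tables/, and reports missing or ill-formed
    results along the way."""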
    from tabular import Table

    mock_h = LogisticRegression()
    methods = [method for method, _ in gen_CAP(mock_h, None)] + [method for method, _ in gen_CAP_cont_table(mock_h)]
    classifiers = [classifier for classifier, _ in gen_classifiers()]
    measures = [measure for measure, _ in gen_acc_measure()]

    os.makedirs('tables', exist_ok=True)

    tex_doc = """
    \\documentclass[10pt,a4paper]{article}
    \\usepackage[utf8]{inputenc}
    \\usepackage{amsmath}
    \\usepackage{amsfonts}
    \\usepackage{amssymb}
    \\usepackage{graphicx}
    \\usepackage{tabularx}
    \\usepackage{color}
    \\usepackage{colortbl}
    \\usepackage{xcolor}
    \\begin{document}
    """

    classifier = classifiers[0]
    metric = "vanilla_accuracy"

    table = Table(datasets, methods)
    for method, dataset in itertools.product(methods, datasets):
        path = getpath(basedir, classifier, metric, dataset, method)
        if not os.path.exists(path):
            print('missing ', path)
            continue
        with open(path, 'r') as f:
            results = json.load(f)
        true_acc = results['true_acc']
        estim_acc = np.asarray(results['estim_acc'])
        if any(np.isnan(estim_acc)):
            print(f'nan values found in {method=} {dataset=}')
            continue
        if any(estim_acc>1.00001):
            print(f'values >1 found in {method=} {dataset=} [max={estim_acc.max()}]')
            continue
        if any(estim_acc<-0.00001):
            print(f'values <0 found in {method=} {dataset=} [min={estim_acc.min()}]')
            continue
        errors = cap_errors(true_acc, estim_acc)
        table.add(dataset, method, errors)

    tex = table.latexTabular()
    table_name = f'{basedir}_{classifier}_{metric}.tex'
    with open(f'./tables/{table_name}', 'wt') as foo:
        foo.write('\\resizebox{\\textwidth}{!}{%\n')
        foo.write('\\begin{tabular}{c|'+('c'*len(methods))+'}\n')
        foo.write(tex)
        foo.write('\\end{tabular}%\n')
        foo.write('}\n')

    tex_doc += "\\input{" + table_name + "}\n"

    tex_doc += """
    \\end{document}
    """
    with open(f'./tables/main.tex', 'wt') as foo:
        foo.write(tex_doc)

    print("[Tables Done] runing latex")
    os.chdir('./tables/')
    os.system('pdflatex main.tex')
    os.system('rm main.aux main.log')
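

if __name__ == '__main__':
    # Minimal sanity check for the accuracy measures defined above, run only when this module is executed
    # directly (it is not part of the experimental pipeline). The contingency table below is a made-up
    # 3-class example; for single-label multiclass data, micro-F1 should coincide with vanilla accuracy.
    toy_ct = np.array([
        [5, 1, 0],
        [2, 6, 1],
        [0, 1, 4]
    ])
    print(f'vanilla accuracy: {vanilla_acc_fn(toy_ct):.3f}')   # 15/20 = 0.750
    print(f'macro-F1        : {macrof1(toy_ct):.3f}')
    print(f'micro-F1        : {microf1(toy_ct):.3f}')
    assert np.isclose(microf1(toy_ct), vanilla_acc_fn(toy_ct)), 'micro-F1 should equal accuracy here'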