import itertools
import json
import os
from collections import defaultdict
from glob import glob
from pathlib import Path
from time import time
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, f1_score

from sklearn.datasets import fetch_rcv1, fetch_20newsgroups
from sklearn.model_selection import GridSearchCV

from ClassifierAccuracy.models_multiclass import *
from ClassifierAccuracy.util.tabular import Table
from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol
from quapy.method.aggregative import EMQ, ACC, KDEyML

from quapy.data import LabelledCollection
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS, fetch_lequa2022, TWITTER_SENTIMENT_DATASETS_TEST
from quapy.data.datasets import fetch_reviews


def gen_classifiers():
    """
    Generate the ``(name, classifier)`` pairs used in the experiments.

    Only plain logistic regression is currently enabled; the commented-out entries are
    alternative configurations that can be switched on by uncommenting them.
    """
    # Hyperparameter grid for the (currently disabled) 'LR-opt' model-selection variant.
    # C spans 1e-4 ... 1e4 in 9 log-spaced steps (one per integer exponent); the previous
    # np.logspace(-4, -4, 9) collapsed to nine copies of 1e-4.
    param_grid = {
        'C': np.logspace(-4, 4, 9),
        'class_weight': ['balanced', None]
    }

    yield 'LR', LogisticRegression()
    #yield 'LR-opt', GridSearchCV(LogisticRegression(), param_grid, cv=5, n_jobs=-1)
    #yield 'NB', GaussianNB()
    #yield 'SVM(rbf)', SVC()
    #yield 'SVM(linear)', LinearSVC()


def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
    """
    Generate the multiclass benchmark datasets as ``(name, (train, val, test))`` pairs.

    Covers every UCI multiclass dataset except 'wine-quality' (split with ``split``),
    20 Newsgroups, and the training set of LeQua2022 T1B (also split with ``split``).
    For 20 Newsgroups, the documents are TF-IDF vectorized (``min_df=5``, sublinear tf,
    headers/footers/quotes removed). The official training set is halved (stratified,
    ``random_state=0``) into train/val, and the official test set is the test part.

    :param only_names: if True, nothing is loaded and ``None`` replaces the splits
    """
    for uci_name in UCI_MULTICLASS_DATASETS:
        if uci_name == 'wine-quality':
            continue
        # the conditional expression keeps the fetch lazy when only names are requested
        yield uci_name, (None if only_names else split(fetch_UCIMulticlassLabelledCollection(uci_name)))

    # 20 Newsgroups
    if only_names:
        yield "20news", None
    else:
        to_strip = ('headers', 'footers', 'quotes')
        news_train = fetch_20newsgroups(subset='train', remove=to_strip)
        news_test = fetch_20newsgroups(subset='test', remove=to_strip)
        vectorizer = TfidfVectorizer(min_df=5, sublinear_tf=True)
        X_train = vectorizer.fit_transform(news_train.data)
        X_test = vectorizer.transform(news_test.data)
        full_train = LabelledCollection(instances=X_train, labels=news_train.target)
        U = LabelledCollection(instances=X_test, labels=news_test.target)
        T, V = full_train.split_stratified(train_prop=0.5, random_state=0)
        yield "20news", (T, V, U)

    # training set of the T1B task of LeQua2022
    if only_names:
        yield "T1B-LeQua2022", None
    else:
        lequa_train, _, _ = fetch_lequa2022(task='T1B')
        yield "T1B-LeQua2022", split(lequa_train)


def gen_tweet_datasets(only_names=False)-> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
    """
    Generate the Twitter sentiment datasets as ``(name, (T, V, U))`` pairs.

    The training set of each dataset is halved (stratified, ``random_state=0``) into ``T``
    and ``V``; the provided test set is ``U``.

    :param only_names: if True, nothing is loaded and ``None`` replaces the splits
    """
    for name in TWITTER_SENTIMENT_DATASETS_TEST:
        if only_names:
            yield name, None
            continue
        twitter = qp.datasets.fetch_twitter(name, min_df=3, pickle=True)
        T, V = twitter.training.split_stratified(0.5, random_state=0)
        yield name, (T, V, twitter.test)


def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
    """
    Generate the binary benchmark datasets as ``(name, (L, V, U))`` pairs.

    Includes IMDb reviews and three RCV1 problems (CCAT, GCAT, MCAT), each of which takes
    as binary label one column of the RCV1 target matrix. The training part of each dataset
    is halved (stratified, ``random_state=0``) into ``L`` and ``V``. The provided test part
    is ``U``.

    :param only_names: if True, nothing is loaded and ``None`` replaces the splits
    """
    rcv1_categories = ['CCAT', 'GCAT', 'MCAT']

    if only_names:
        for name in ['imdb', *rcv1_categories]:
            yield name, None
        return

    imdb_train, imdb_test = fetch_reviews('imdb', tfidf=True, min_df=10, pickle=True).train_test
    imdb_L, imdb_V = imdb_train.split_stratified(0.5, random_state=0)
    yield 'imdb', (imdb_L, imdb_V, imdb_test)

    rcv1_train = fetch_rcv1(subset='train')
    rcv1_test = fetch_rcv1(subset='test')
    target_names = rcv1_train.target_names.tolist()
    for category in rcv1_categories:
        column = target_names.index(category)
        # densify the selected column of the (sparse) target matrix into a flat label vector
        y_train = rcv1_train.target[:, column].toarray().flatten()
        y_test = rcv1_test.target[:, column].toarray().flatten()
        train_coll = LabelledCollection(rcv1_train.data, y_train)
        U = LabelledCollection(rcv1_test.data, y_test)
        L, V = train_coll.split_stratified(train_prop=0.5, random_state=0)
        yield category, (L, V, U)


def gen_CAP(h, acc_fn, with_oracle=False)->[str, ClassifierAccuracyPrediction]:
    """
    Generate the enabled CAP methods that predict an accuracy value directly.

    Each item is a ``(name, method)`` pair. The commented-out lines are alternative methods
    that can be re-enabled.

    :param h: the classifier whose accuracy is to be predicted
    :param acc_fn: the accuracy measure handed to every method
    :param with_oracle: only referenced by a disabled alternative below
    """
    #yield 'SebCAP', SebastianiCAP(h, acc_fn, ACC)
    # yield 'SebCAP-SLD', SebastianiCAP(h, acc_fn, EMQ, predict_train_prev=not with_oracle)
    #yield 'SebCAP-KDE', SebastianiCAP(h, acc_fn, KDEyML)
    #yield 'SebCAPweight', SebastianiCAP(h, acc_fn, ACC, alpha=0)
    #yield 'PabCAP', PabloCAP(h, acc_fn, ACC)
    # yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')
    atc_maxconf = ATC(h, acc_fn, scoring_fn='maxconf')
    yield 'ATC-MC', atc_maxconf
    # yield 'ATC-NE', ATC(h, acc_fn, scoring_fn='neg_entropy')
    # the sample size is read from quapy's global environment at generation time
    doc_sample_size = qp.environ['SAMPLE_SIZE']
    yield 'DoC', DoC(h, acc_fn, sample_size=doc_sample_size)


def gen_CAP_cont_table(h)->[str,CAPContingencyTable]:
    """
    Generate the enabled CAP methods that predict a contingency table.

    Each item is a ``(name, method)`` pair. All methods are built with ``acc_fn=None``. The
    commented-out lines are alternative methods that can be re-enabled.

    :param h: the classifier whose contingency table is to be predicted
    """
    acc_fn = None  # no accuracy function is bound to these methods
    yield 'Naive', NaiveCAP(h, acc_fn)
    emq_quantifier = EMQ(LogisticRegression())
    yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, emq_quantifier)
    # yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
    # yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
    #yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
    #yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
    #yield 'QuAcc(EMQ)nxn-MC', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxconf=True)
    # yield 'QuAcc(EMQ)nxn-NE', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_negentropy=True)
    #yield 'QuAcc(EMQ)nxn-MIS', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxinfsoft=True)
    #yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
    #yield 'CT-PPSh-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()), reuse_h=True)
    #yield 'Equations-ACCh', NsquaredEquationsCAP(h, acc_fn, ACC, reuse_h=True)
    # yield 'Equations-ACC', NsquaredEquationsCAP(h, acc_fn, ACC)
    #yield 'Equations-SLD', NsquaredEquationsCAP(h, acc_fn, EMQ)


def get_method_names():
    """Return the names of all enabled CAP methods: first ``gen_CAP``'s, then ``gen_CAP_cont_table``'s."""
    # a fresh, unfitted classifier is enough: only the method names are collected
    placeholder_h = LogisticRegression()
    names = [name for name, _ in gen_CAP(placeholder_h, None)]
    names.extend(name for name, _ in gen_CAP_cont_table(placeholder_h))
    return names


def gen_acc_measure():
    """Generate the ``(name, function)`` pairs of the accuracy measures under evaluation."""
    measures = (
        ('vanilla_accuracy', vanilla_acc_fn),
        ('macro-F1', macrof1_fn),
    )
    yield from measures


def split(data: LabelledCollection):
    """
    Split ``data`` into ``(train, val, test)`` partitions.

    A stratified 66/34 split separates train+val from test. Train+val is then halved
    (stratified) into train and val. Both splits use ``random_state=0``.
    """
    trval, te = data.split_stratified(train_prop=0.66, random_state=0)
    tr, va = trval.split_stratified(train_prop=0.5, random_state=0)
    return tr, va, te


def fit_method(method, V):
    """Fit ``method`` on ``V`` and return ``(method, wall-clock training time in seconds)``."""
    started = time()
    method.fit(V)
    elapsed = time() - started
    return method, elapsed


def predictionsCAP(method, test_prot, oracle=False):
    """
    Estimate the accuracy of every sample produced by ``test_prot``.

    :param method: fitted CAP method exposing ``predict(X)`` (and ``predict(X, oracle_prev=...)``)
    :param test_prot: callable protocol yielding samples with ``.X`` and ``.prevalence()``,
        and exposing ``total()``
    :param oracle: if True, each sample's true prevalence is passed as ``oracle_prev``
    :return: ``(list of estimated accuracies, average prediction time per sample)``
    """
    start = time()
    estim_accs = []
    for sample in test_prot():
        if oracle:
            estim_accs.append(method.predict(sample.X, oracle_prev=sample.prevalence()))
        else:
            estim_accs.append(method.predict(sample.X))
    elapsed = time() - start
    return estim_accs, elapsed / test_prot.total()


def predictionsCAPcont_table(method, test_prot, gen_acc_measure, oracle=False):
    """
    Predict a contingency table for every sample of ``test_prot`` and score it with every measure.

    :param method: fitted method exposing ``predict_ct(X)`` (and ``predict_ct(X, oracle_prev=...)``)
    :param test_prot: callable protocol yielding samples with ``.X`` and ``.prevalence()``,
        and exposing ``total()``
    :param gen_acc_measure: callable returning ``(name, fn)`` pairs; ``fn`` maps a table to a score
    :param oracle: if True, each sample's true prevalence is passed as ``oracle_prev``
    :return: ``({measure name: list of scores}, average time per sample)``. The time covers
        both the table predictions and the scoring.
    """
    start = time()

    def _predict(sample):
        if oracle:
            return method.predict_ct(sample.X, oracle_prev=sample.prevalence())
        return method.predict_ct(sample.X)

    estim_tables = [_predict(sample) for sample in test_prot()]
    estim_accs_dict = {
        acc_name: [acc_fn(table) for table in estim_tables]
        for acc_name, acc_fn in gen_acc_measure()
    }
    t_test_ave = (time() - start) / test_prot.total()
    return estim_accs_dict, t_test_ave


def any_missing(basedir, cls_name, dataset_name, method_name):
    """Return True if, for at least one accuracy measure, the result file of this configuration does not exist."""
    return any(
        not os.path.exists(getpath(basedir, cls_name, acc_name, dataset_name, method_name))
        for acc_name, _ in gen_acc_measure()
    )


def true_acc(h:BaseEstimator, acc_fn: callable, U: LabelledCollection):
    """
    Compute the actual accuracy of ``h`` on ``U`` according to ``acc_fn``.

    ``acc_fn`` receives the confusion matrix of ``h``'s predictions, computed with
    ``U.y`` as ``y_true`` and ``labels=U.classes_``.
    """
    predicted = h.predict(U.X)
    table = confusion_matrix(U.y, y_pred=predicted, labels=U.classes_)
    return acc_fn(table)


def from_contingency_table(param1, param2):
    """
    Decide how the arguments of an evaluation function should be interpreted.

    :return: True if ``param1`` is a square 2-D numpy array and ``param2`` is None (a
        contingency table). False if both are numpy arrays of identical shape (true and
        predicted labels).
    :raises ValueError: for any other combination
    """
    first_is_array = isinstance(param1, np.ndarray)
    is_square_matrix = first_is_array and param1.ndim == 2 and param1.shape[0] == param1.shape[1]
    if param2 is None and is_square_matrix:
        return True
    if first_is_array and isinstance(param2, np.ndarray) and param1.shape == param2.shape:
        return False
    raise ValueError('parameters for evaluation function not understood')


def vanilla_acc_fn(param1, param2=None):
    """
    Compute the standard accuracy.

    Accepts either a square contingency table (``param2`` left as None) or a pair of
    label arrays ``(y_true, y_pred)``; see ``from_contingency_table``.
    """
    if not from_contingency_table(param1, param2):
        return accuracy_score(param1, param2)
    return _vanilla_acc_from_ct(param1)


def macrof1_fn(param1, param2=None):
    """
    Compute the macro-F1 score.

    Accepts either a square contingency table (``param2`` left as None) or a pair of
    label arrays ``(y_true, y_pred)``; see ``from_contingency_table``.

    NOTE(review): for two classes the two branches disagree. ``macro_f1_from_ct`` returns
    the F1 of class 1 only, while ``f1_score(..., average='macro')`` averages the F1 of both
    classes. Confirm which convention is intended.
    """
    if not from_contingency_table(param1, param2):
        return f1_score(param1, param2, average='macro')
    return macro_f1_from_ct(param1)


def _vanilla_acc_from_ct(cont_table):
    return np.diag(cont_table).sum() / cont_table.sum()


def _f1_bin(tp, fp, fn):
    if tp + fp + fn == 0:
        return 1
    else:
        return (2 * tp) / (2 * tp + fp + fn)


def macro_f1_from_ct(cont_table):
    """
    Compute the F1 score encoded by a square contingency table.

    Rows are treated as true classes and columns as predicted classes: fn comes from row
    sums, fp from column sums. With more than two classes this returns the unweighted mean
    of the per-class F1 scores.

    NOTE(review): with exactly two classes only the F1 of class 1 (index 1) is returned,
    which is not a macro average over both classes. Confirm this is the intended convention.
    """
    n_classes = cont_table.shape[0]

    if n_classes == 2:
        return _f1_bin(cont_table[1, 1], cont_table[0, 1], cont_table[1, 0])

    per_class = [
        _f1_bin(
            cont_table[c, c],
            cont_table[:, c].sum() - cont_table[c, c],  # false positives of class c
            cont_table[c, :].sum() - cont_table[c, c],  # false negatives of class c
        )
        for c in range(n_classes)
    ]
    return np.mean(per_class)


def microf1(cont_table):
    """
    Compute the micro-averaged F1 score encoded by a square contingency table.

    Rows are treated as true classes and columns as predicted classes. Micro-averaging pools
    the counts of all classes. tp is the trace. Every off-diagonal cell is a false positive
    for its predicted class and a false negative for its true class, so pooled fp and pooled
    fn both equal ``total - trace``.

    With exactly two classes the F1 of class 1 (index 1) is returned instead, consistent
    with ``macro_f1_from_ct``.
    """
    n = cont_table.shape[0]

    if n == 2:
        tp = cont_table[1, 1]
        fp = cont_table[0, 1]
        fn = cont_table[1, 0]
        return _f1_bin(tp, fp, fn)

    # The previous loop added whole column/row arrays (no .sum()) and subtracted the
    # running tp, producing an array instead of a score. Pooled counts in closed form:
    # sum_i (col_i - tp_i) == sum_i (row_i - tp_i) == total - trace
    tp = np.trace(cont_table)
    misclassified = cont_table.sum() - tp
    fp = misclassified
    fn = misclassified
    return _f1_bin(tp, fp, fn)


def cap_errors(true_acc, estim_acc):
    """Return the element-wise absolute errors ``|true_acc - estim_acc|`` as a numpy array."""
    difference = np.subtract(np.asarray(true_acc), np.asarray(estim_acc))
    # a squared-error alternative would be: difference ** 2
    return np.abs(difference)


def getpath(basedir, cls_name, acc_name, dataset_name, method_name):
    """Return the relative path of the JSON result file for one experimental configuration."""
    filename = f"{method_name}.json"
    return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{filename}"


def open_results(basedir, cls_name, acc_name, dataset_name='*', method_name='*'):
    """
    Collect stored results grouped by method name.

    ``dataset_name`` and ``method_name`` may each be a single string or an iterable of
    strings. Every value is used inside a glob pattern, so wildcards are allowed (the
    default ``'*'`` matches everything). Results of the same method across all matched
    files are concatenated.

    :return: ``defaultdict`` mapping method name -> ``{'true_acc': [...], 'estim_acc': [...]}``
    """
    results = defaultdict(lambda: {'true_acc': [], 'estim_acc': []})
    if isinstance(method_name, str):
        method_name = [method_name]
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
    for dataset_, method_ in itertools.product(dataset_name, method_name):
        path = getpath(basedir, cls_name, acc_name, dataset_, method_)
        for file in glob(path):
            method = Path(file).name.replace('.json', '')
            # close the handle deterministically instead of leaking it until garbage collection
            with open(file, 'r') as f:
                result = json.load(f)
            results[method]['true_acc'].extend(result['true_acc'])
            results[method]['estim_acc'].extend(result['estim_acc'])
    return results


def save_json_file(path, data):
    """Serialize ``data`` as JSON into ``path``, creating missing parent directories first."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'w') as fh:
        json.dump(data, fh)


def save_json_result(path, true_accs, estim_accs, t_train, t_test):
    """Store the timings and the true/estimated accuracies of one experiment as JSON at ``path``."""
    save_json_file(path, dict(
        t_train=t_train,
        t_test_ave=t_test,
        true_acc=true_accs,
        estim_acc=estim_accs,
    ))


def get_dataset_stats(path, test_prot, L, V):
    """
    Save, as JSON at ``path``, descriptive statistics of a split and its test protocol.

    The stats include:
    - the number of classes, the train/val sizes and their prevalences;
    - the prevalence of every test sample;
    - the shift of each sample w.r.t. the training prevalence, computed with ``qp.error.ae``;
    - the protocol's sample size and number of samples.
    """
    sample_prevs = [sample.prevalence() for sample in test_prot()]
    sample_shifts = [qp.error.ae(L.prevalence(), prev) for prev in sample_prevs]
    save_json_file(path, {
        'n_classes': L.n_classes,
        'n_train': len(L),
        'n_val': len(V),
        'train_prev': L.prevalence().tolist(),
        'val_prev': V.prevalence().tolist(),
        'test_prevs': [prev.tolist() for prev in sample_prevs],
        'shifts': [shift.tolist() for shift in sample_shifts],
        'sample_size': test_prot.sample_size,
        'num_samples': test_prot.total()
    })


def _load_cap_errors(path, method, dataset):
    """
    Load the result file at ``path`` and return its CAP errors.

    Returns None, after printing the reason, when the file does not exist or when the
    estimated accuracies contain NaN or values outside [0, 1] (with a 1e-5 tolerance).
    """
    if not os.path.exists(path):
        print('missing ', path)
        return None
    with open(path, 'r') as f:
        results = json.load(f)
    true_accs = results['true_acc']
    estim_acc = np.asarray(results['estim_acc'])
    if any(np.isnan(estim_acc)):
        print(f'nan values found in {method=} {dataset=}')
        return None
    if any(estim_acc > 1.00001):
        print(f'values >1 found in {method=} {dataset=} [max={estim_acc.max()}]')
        return None
    if any(estim_acc < -0.00001):
        print(f'values <0 found in {method=} {dataset=} [min={estim_acc.min()}]')
        return None
    return cap_errors(true_accs, estim_acc)


def _write_latex_table(path, tabular, n_methods, classifier, metric, with_oracle):
    """Write ``tabular`` (the rows of a Table) wrapped in a complete LaTeX table environment to ``path``."""
    with open(path, 'wt') as foo:
        foo.write('\\begin{table}[h]\n')
        foo.write('\\centering\n')
        foo.write('\\resizebox{\\textwidth}{!}{%\n')
        foo.write('\\begin{tabular}{c|' + ('c' * n_methods) + '}\n')
        foo.write(tabular)
        foo.write('\\end{tabular}%\n')
        foo.write('}\n')
        foo.write('\\caption{Classifier ' + classifier.replace('_', ' ') + ('(oracle)' if with_oracle else '') +
                  ' evaluated in terms of ' + metric.replace('_', ' ') + '}\n')
        foo.write('\\end{table}\n')


def _compile_latex(tables_dir):
    """
    Compile ``main.tex`` with pdflatex inside ``tables_dir`` and delete the .aux/.log by-products.

    pdflatex runs with ``cwd=tables_dir`` rather than after ``os.chdir``, so the working
    directory of the process (and with it the relative ``results/...`` paths) stays unchanged.
    This step is best-effort: a non-zero pdflatex exit status or a missing pdflatex binary
    does not raise.
    """
    import subprocess
    from contextlib import suppress

    try:
        subprocess.run(['pdflatex', 'main.tex'], cwd=tables_dir)
    except FileNotFoundError:
        print('pdflatex not found, the tables were not compiled')
        return
    for byproduct in ('main.aux', 'main.log'):
        with suppress(FileNotFoundError):
            os.remove(os.path.join(tables_dir, byproduct))


def gen_tables(basedir, datasets):
    """
    Build one LaTeX table per (classifier, accuracy measure) from stored results and compile them.

    For every (method, dataset) pair, the absolute CAP errors are loaded from the result file
    given by ``getpath``. Pairs whose file is missing or whose estimates are NaN or out of
    [0, 1] are reported and skipped. Each table is written to ``./tables/``, and all tables
    are ``\\input`` into ``./tables/main.tex``, which is then compiled with pdflatex. Unlike
    earlier versions, the process working directory is not changed.

    :param basedir: results sub-directory (also used in the table file names). If it contains
        'oracle', the table captions are marked as oracle results.
    :param datasets: names of the datasets to include in the tables
    """
    # get_method_names builds its own placeholder classifier (the previous inline version
    # accidentally created a 1-tuple via a trailing comma)
    methods = get_method_names()
    classifiers = [classifier for classifier, _ in gen_classifiers()]
    metrics = [measure for measure, _ in gen_acc_measure()]

    tables_dir = './tables'
    os.makedirs(tables_dir, exist_ok=True)

    with_oracle = 'oracle' in basedir

    tex_doc = """
    \\documentclass[10pt,a4paper]{article}
    \\usepackage[utf8]{inputenc}
    \\usepackage{amsmath}
    \\usepackage{amsfonts}
    \\usepackage{amssymb}
    \\usepackage{graphicx}
    \\usepackage{tabularx}
    \\usepackage{color}
    \\usepackage{colortbl}
    \\usepackage{xcolor}
    \\begin{document}
    """

    for classifier in classifiers:
        for metric in metrics:
            table = Table(datasets, methods, prec_mean=5, clean_zero=True)
            for method, dataset in itertools.product(methods, datasets):
                path = getpath(basedir, classifier, metric, dataset, method)
                errors = _load_cap_errors(path, method, dataset)
                if errors is not None:
                    table.add(dataset, method, errors)

            tex = table.latexTabular()
            table_name = f'{basedir}_{classifier}_{metric}.tex'.replace('/', '_')
            _write_latex_table(f'{tables_dir}/{table_name}', tex, len(methods), classifier, metric, with_oracle)

            # escaped backslash: "\i" is an invalid escape sequence in a normal string literal
            tex_doc += "\\input{" + table_name + "}\n\n"

    tex_doc += """
    \\end{document}
    """
    with open(f'{tables_dir}/main.tex', 'wt') as foo:
        foo.write(tex_doc)

    print("[Tables Done] runing latex")
    _compile_latex(tables_dir)


class ArtificialAccuracyProtocol(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
    """
    Sampling protocol that controls the proportion of instances ``h`` classifies correctly.

    Every instance of ``data`` is labelled according to whether ``h``'s prediction matches
    its true label (``classes=[0, 1]``). ``repeats`` index sets of size ``sample_size`` are
    then drawn for each of the ``n_prevalences`` equally spaced values in [0, 1]. Samples
    are taken from the original ``data``, so they keep their true labels.

    NOTE(review): each grid value is passed as the first prevalence argument of
    ``sampling_index``. With ``classes=[0, 1]`` this presumably fixes the proportion of class
    0, i.e. of *misclassified* instances, so the sample accuracy would be ``1 - value``.
    Because the grid is symmetric over [0, 1], the set of generated accuracy levels is the
    same either way. Confirm against ``LabelledCollection.sampling_index``.
    """

    def __init__(self, data: LabelledCollection, h: BaseEstimator, sample_size=None, n_prevalences=101, repeats=10, random_state=0):
        """
        :param data: collection to sample from
        :param h: classifier whose correct/incorrect predictions drive the sampling
        :param sample_size: resolved through ``qp._get_sample_size`` (presumably falls back to
            quapy's global SAMPLE_SIZE when None; TODO confirm)
        :param n_prevalences: number of equally spaced grid values in [0, 1]
        :param repeats: number of samples drawn per grid value
        :param random_state: seed handed to the stochastic base protocol
        """
        super().__init__(random_state)
        self.data = data
        self.h = h
        self.sample_size = qp._get_sample_size(sample_size)
        self.n_prevalences = n_prevalences
        self.repeats = repeats
        self.collator = OnLabelledCollectionProtocol.get_collator('labelled_collection')

    def accuracy_grid(self):
        """Return the ``n_prevalences`` grid values in [0, 1], each repeated ``repeats`` times consecutively."""
        levels = np.linspace(0, 1, self.n_prevalences)
        return np.repeat(levels, self.repeats, axis=0)

    def samples_parameters(self):
        """
        Return one index array per grid value.

        As a side effect, stores in ``self.data_evaluated`` the collection labelled by
        prediction correctness.
        """
        hits = self.h.predict(self.data.X) == self.data.y
        self.data_evaluated = LabelledCollection(self.data.X, labels=hits, classes=[0, 1])
        return [self.data_evaluated.sampling_index(self.sample_size, level) for level in self.accuracy_grid()]

    def sample(self, index):
        """Return the sub-collection of the original ``data`` selected by ``index``."""
        return self.data.sampling_from_index(index)

    def total(self):
        """Return the number of samples the protocol generates (matches ``len(self.accuracy_grid())``)."""
        return self.n_prevalences * self.repeats