import itertools
import json
import os
from collections import defaultdict
from glob import glob
from pathlib import Path
from time import time

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol

from ClassifierAccuracy.util.tabular import Table


def split(data: LabelledCollection):
    """Partitions a labelled collection into stratified training, validation, and test sets."""
    train_val, test = data.split_stratified(train_prop=0.66, random_state=0)
    train, val = train_val.split_stratified(train_prop=0.5, random_state=0)
    return train, val, test
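
# Usage sketch: given some LabelledCollection `data` (hypothetical here), this
# yields a stratified partition of roughly equal thirds (66% is split off first
# and then halved, so the proportions are about 33/33/34):
#   train, val, test = split(data)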


def fit_method(method, V):
    """Fits a CAP method on the validation set V and returns it along with the training time."""
    tinit = time()
    method.fit(V)
    t_train = time() - tinit
    return method, t_train


def predictionsCAP(method, test_prot, oracle=False):
    """Estimates the classifier's accuracy on every sample generated by the test
    protocol; with oracle=True, the true prevalence of each sample is disclosed to
    the method. Returns the estimates and the average per-sample prediction time."""
    tinit = time()
    if not oracle:
        estim_accs = [method.predict(Ui.X) for Ui in test_prot()]
    else:
        estim_accs = [method.predict(Ui.X, oracle_prev=Ui.prevalence()) for Ui in test_prot()]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs, t_test_ave
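
# Usage sketch (assumes `method` is a fitted CAP method and `test_prot` a quapy
# sample-generation protocol, as in the functions above):
#   estim_accs, t_ave = predictionsCAP(method, test_prot)
#   estim_accs_oracle, _ = predictionsCAP(method, test_prot, oracle=True)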


def predictionsCAPcont_table(method, test_prot, gen_acc_measure, oracle=False):
    """Estimates a contingency table for every sample generated by the test protocol,
    and derives from it an accuracy estimate for each of the given accuracy measures."""
    estim_accs_dict = {}
    tinit = time()
    if not oracle:
        estim_tables = [method.predict_ct(Ui.X) for Ui in test_prot()]
    else:
        estim_tables = [method.predict_ct(Ui.X, oracle_prev=Ui.prevalence()) for Ui in test_prot()]
    for acc_name, acc_fn in gen_acc_measure():
        estim_accs_dict[acc_name] = [acc_fn(cont_table) for cont_table in estim_tables]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs_dict, t_test_ave
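
# The `gen_acc_measure` argument is expected to be a generator function yielding
# (name, function) pairs, where each function maps a contingency table to a score.
# A minimal sketch using the measures defined below (the names are illustrative):
#
#   def gen_acc_measure():
#       yield 'vanilla_accuracy', vanilla_acc_fn
#       yield 'macro-F1', macrof1_fn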


def any_missing(basedir, cls_name, dataset_name, method_name, acc_measures):
    """Returns True if any result file is missing for the given configuration."""
    for acc_name, _ in acc_measures():
        if not os.path.exists(getpath(basedir, cls_name, acc_name, dataset_name, method_name)):
            return True
    return False


def true_acc(h: BaseEstimator, acc_fn: callable, U: LabelledCollection):
    """Computes the true value of an accuracy measure for classifier h on the labelled set U."""
    y_pred = h.predict(U.X)
    y_true = U.y
    conf_table = confusion_matrix(y_true, y_pred=y_pred, labels=U.classes_)
    return acc_fn(conf_table)
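
# E.g., true_acc(h, vanilla_acc_fn, test) gives the standard accuracy of h on
# `test` (with h a fitted classifier and `test` a LabelledCollection, both
# hypothetical here).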


def from_contingency_table(param1, param2):
    """Discerns the calling convention of the accuracy functions below: returns True
    if param1 is a contingency table (a square 2D array, with param2=None), and False
    if (param1, param2) are (y_true, y_pred) arrays of equal shape."""
    if param2 is None and isinstance(param1, np.ndarray) and param1.ndim == 2 and (param1.shape[0] == param1.shape[1]):
        return True
    elif isinstance(param1, np.ndarray) and isinstance(param2, np.ndarray) and param1.shape == param2.shape:
        return False
    else:
        raise ValueError('parameters for evaluation function not understood')


def vanilla_acc_fn(param1, param2=None):
    if from_contingency_table(param1, param2):
        return _vanilla_acc_from_ct(param1)
    else:
        return accuracy_score(param1, param2)
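
# Both calling conventions at a glance (note that array inputs are required):
#   vanilla_acc_fn(np.array([[3, 1], [0, 4]]))                # contingency table -> 7/8
#   vanilla_acc_fn(np.array([0, 1, 1]), np.array([0, 1, 0]))  # y_true, y_pred -> 2/3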


def macrof1_fn(param1, param2=None):
    if from_contingency_table(param1, param2):
        return macro_f1_from_ct(param1)
    else:
        return f1_score(param1, param2, average='macro')


def _vanilla_acc_from_ct(cont_table):
    return np.diag(cont_table).sum() / cont_table.sum()


def _f1_bin(tp, fp, fn):
    if tp + fp + fn == 0:
        # no positive instances and no positive predictions: F1 is taken to be 1 by convention
        return 1
    else:
        return (2 * tp) / (2 * tp + fp + fn)


def macro_f1_from_ct(cont_table):
    """Computes macro-averaged F1 from a contingency table (true classes in rows,
    predicted classes in columns); the binary case scores the positive class only."""
    n = cont_table.shape[0]

    if n == 2:
        tp = cont_table[1, 1]
        fp = cont_table[0, 1]
        fn = cont_table[1, 0]
        return _f1_bin(tp, fp, fn)

    f1_per_class = []
    for i in range(n):
        tp = cont_table[i, i]
        fp = cont_table[:, i].sum() - tp
        fn = cont_table[i, :].sum() - tp
        f1_per_class.append(_f1_bin(tp, fp, fn))

    return np.mean(f1_per_class)
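
# Worked example: for the 3-class table [[2,0,0],[0,2,1],[0,1,2]] the per-class
# F1 scores are 1, 2/3, and 2/3, so the macro-averaged F1 is 7/9 ~= 0.778:
#   macro_f1_from_ct(np.array([[2, 0, 0], [0, 2, 1], [0, 1, 2]]))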


def microf1(cont_table):
    """Computes micro-averaged F1 from a contingency table by pooling the per-class counts."""
    n = cont_table.shape[0]

    if n == 2:
        tp = cont_table[1, 1]
        fp = cont_table[0, 1]
        fn = cont_table[1, 0]
        return _f1_bin(tp, fp, fn)

    tp, fp, fn = 0, 0, 0
    for i in range(n):
        tp_i = cont_table[i, i]
        tp += tp_i
        fp += cont_table[:, i].sum() - tp_i
        fn += cont_table[i, :].sum() - tp_i
    return _f1_bin(tp, fp, fn)
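
# In the single-label multiclass case the pooled fp and fn totals coincide, so
# micro-F1 reduces to vanilla accuracy; for the 3-class table above, both equal
# 6/8 = 0.75.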


def cap_errors(true_acc, estim_acc):
    """Returns the absolute error between true and estimated accuracy values
    (squared error, (true_acc - estim_acc)**2, is a possible alternative)."""
    true_acc = np.asarray(true_acc)
    estim_acc = np.asarray(estim_acc)
    return np.abs(true_acc - estim_acc)
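
# E.g., cap_errors([0.90, 0.80], [0.85, 0.90]) -> array([0.05, 0.10])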


def getpath(basedir, cls_name, acc_name, dataset_name, method_name):
    return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json"


def open_results(basedir, cls_name, acc_name, dataset_name='*', method_name='*'):
    """Collects all stored results matching the given configuration ('*' acts as a
    wildcard), aggregating the true and estimated accuracy values per method."""
    results = defaultdict(lambda: {'true_acc': [], 'estim_acc': []})
    if isinstance(method_name, str):
        method_name = [method_name]
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
    for dataset_, method_ in itertools.product(dataset_name, method_name):
        path = getpath(basedir, cls_name, acc_name, dataset_, method_)
        for file in glob(path):
            method = Path(file).stem
            with open(file, 'r') as f:
                result = json.load(f)
            results[method]['true_acc'].extend(result['true_acc'])
            results[method]['estim_acc'].extend(result['estim_acc'])
    return results
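
# Usage sketch (the basedir, classifier, and measure names are hypothetical):
#   results = open_results('uci', 'LR', 'vanilla_accuracy')
#   for method, res in results.items():
#       print(method, cap_errors(res['true_acc'], res['estim_acc']).mean())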


def save_json_file(path, data):
    os.makedirs(Path(path).parent, exist_ok=True)
    with open(path, 'w') as f:
        json.dump(data, f)


def save_json_result(path, true_accs, estim_accs, t_train, t_test):
    result = {
        't_train': t_train,
        't_test_ave': t_test,
        'true_acc': true_accs,
        'estim_acc': estim_accs
    }
    save_json_file(path, result)


def get_dataset_stats(path, test_prot, L, V):
    """Saves, as a json file, a summary of the dataset and the test protocol: class
    counts, training/validation/test prevalences, and the prior-probability shift
    (absolute error between training and test prevalences) of each test sample."""
    test_prevs = [Ui.prevalence() for Ui in test_prot()]
    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
    info = {
        'n_classes': L.n_classes,
        'n_train': len(L),
        'n_val': len(V),
        'train_prev': L.prevalence().tolist(),
        'val_prev': V.prevalence().tolist(),
        'test_prevs': [x.tolist() for x in test_prevs],
        'shifts': [x.tolist() for x in shifts],
        'sample_size': test_prot.sample_size,
        'num_samples': test_prot.total()
    }
    save_json_file(path, info)
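

# End-to-end sketch of how these utilities fit together (assumes `data` is a
# LabelledCollection, `h` the underlying classifier, `method` a CAP method with
# the fit/predict interface used above, `test_prot` a sample-generation protocol
# over the test set, and `path` a target result file):
#
#   train, val, test = split(data)
#   method, t_train = fit_method(method, val)
#   estim_accs, t_test = predictionsCAP(method, test_prot)
#   true_accs = [true_acc(h, vanilla_acc_fn, Ui) for Ui in test_prot()]
#   save_json_result(path, true_accs, estim_accs, t_train, t_test)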