from collections import defaultdict

from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import LinearSVC
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
import os
import quapy as qp
from method.aggregative import PACC, EMQ, PCC, CC, ACC, HDy
from models import *
import matplotlib.pyplot as plt
from pathlib import Path


def clf():
    """Return a new, unfitted LogisticRegression built with class_weight=None."""
    # Alternative base learner that was tried here:
    #   CalibratedClassifierCV(LinearSVC(class_weight=None))
    classifier = LogisticRegression(class_weight=None)
    return classifier


def F1(contingency_table):
    """Compute F1 from a table indexable as ``table[row, col]``.

    The cells are read as ``[1, 1]`` = tp, ``[0, 1]`` = fp and ``[1, 0]`` = fn;
    the ``[0, 0]`` cell (tn) is not used. The score is ``2*tp / (2*tp + fp + fn)``.
    When that denominator is not strictly positive, 1 is returned instead.
    """
    true_pos = contingency_table[1, 1]
    false_pos = contingency_table[0, 1]
    false_neg = contingency_table[1, 0]

    denominator = (2*true_pos+false_pos+false_neg)
    # Guard clause: nothing to divide by, fall back to the constant 1.
    if not denominator > 0:
        return 1
    return 2*true_pos/denominator


def accuracy(contingency_table):
    """Compute accuracy from a table indexable as ``table[row, col]``.

    The cells are read as ``[0, 0]`` = tn, ``[0, 1]`` = fp, ``[1, 0]`` = fn and
    ``[1, 1]`` = tp; the result is ``(tp + tn) / (tp + fp + fn + tn)``. No guard
    is applied for a table whose cells sum to zero.
    """
    tn, fp = contingency_table[0, 0], contingency_table[0, 1]
    fn, tp = contingency_table[1, 0], contingency_table[1, 1]

    hits = tp+tn
    total = tp+fp+fn+tn
    return hits/total


def plot_series(series, repeats, metric_name, train_prev=None, savepath=None):
    """Plot each series against the binned 'prev' series, with a +/- 1 std band.

    Every sequence in `series` is reshaped to ``(-1, repeats)``: each run of
    `repeats` consecutive values becomes one plotted point (the row mean), and
    the row standard deviation gives the shaded band around it.

    Args:
        series: mapping from a label to a sequence of numbers. It must contain
            the key ``'prev'``, whose binned means are used as x coordinates;
            every other key is drawn as one labelled curve. The length of each
            sequence must be a multiple of `repeats`.
        repeats: number of consecutive values that are averaged into one point.
        metric_name: name of the metric, used in the y-axis label and the title.
        train_prev: if not None, a dashed vertical line labelled 'tr-prev' is
            drawn at this x position.
        savepath: if None the figure is shown interactively; otherwise it is
            saved to this path (parent directories are created if needed) and
            the figure is closed afterwards.
    """
    # Diagnostic dump of the raw values (kept so the console output is unchanged).
    for key in series:
        print(series[key])

    fig, ax = plt.subplots()

    def mean_and_std(values):
        # One row per group of `repeats` consecutive values.
        grouped = np.asarray(values).reshape(-1, repeats)
        return grouped.mean(axis=1), grouped.std(axis=1)

    x, _ = mean_and_std(series['prev'])

    for name in series:
        if name == 'prev':
            continue
        values = series[name]
        print(name, values)
        val_mean, val_std = mean_and_std(values)
        ax.errorbar(x, val_mean, label=name, fmt='-', marker='o')
        ax.fill_between(x, val_mean - val_std, val_mean + val_std, alpha=0.25)

    if train_prev is not None:
        ax.axvline(x=train_prev, label='tr-prev', color='k', linestyle='--')

    # Legend is anchored outside the axes, to the right of the plot area.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax.grid()

    # Raw string: '\o' is not a valid escape sequence in a normal string literal.
    # The title is set only here; a separate set_title call would be overridden.
    ax.set(xlabel=r'$p_U(\oplus)$', ylabel='estimated '+metric_name,
           title='Classifier accuracy in terms of '+metric_name)

    if savepath is None:
        plt.show()
    else:
        os.makedirs(Path(savepath).parent, exist_ok=True)
        fig.savefig(savepath, bbox_inches='tight')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
dataset = 'imdb'
data = qp.datasets.fetch_reviews(dataset, tfidf=True, min_df=5, pickle=True)

# qp.data.preprocessing.reduce_columns(data, min_df=5, inplace=True)
# print('num_features', data.training.instances.shape[1])

train, test = data.training, data.test

# ---------------------------------------------------------------------------
# Models: all of them are fitted on the training split
# ---------------------------------------------------------------------------
# Settings shared by every k-fold based estimator below.
kfcv_args = dict(strategy='kfcv', k=5, n_jobs=-1)

upper = UpperBound(clf(), y_test=None).fit(train)

mlcfe = MLCMEstimator(clf(), **kfcv_args).fit(train)

emq_quant = QuantificationCMPredictor(clf(), EMQ(LogisticRegression()), **kfcv_args).fit(train)
# cc_quant = QuantificationCMPredictor(clf(), CC(clf()), **kfcv_args).fit(train)
# pcc_quant = QuantificationCMPredictor(clf(), PCC(clf()), **kfcv_args).fit(train)
# acc_quant = QuantificationCMPredictor(clf(), ACC(clf()), **kfcv_args).fit(train)
pacc_quant = QuantificationCMPredictor(clf(), PACC(clf()), **kfcv_args).fit(train)
# hdy_quant = QuantificationCMPredictor(clf(), HDy(clf()), **kfcv_args).fit(train)

# Plain quantifiers, used only to record their prevalence estimates.
sld = EMQ(LogisticRegression()).fit(train)
pacc = PACC(clf()).fit(train)

# (label shown in the plot, fitted confusion-matrix predictor)
contenders = [
    ('kFCV+MLPE', mlcfe),
    ('SLD', emq_quant),
    # ('CC', cc_quant),
    # ('PCC', pcc_quant),
    # ('ACC', acc_quant),
    ('PACC', pacc_quant),
    # ('HDy', hdy_quant)
]

metric = F1
# metric = accuracy

# ---------------------------------------------------------------------------
# Test samples: indexes are drawn once, under a fixed seed
# ---------------------------------------------------------------------------
repeats = 10
with qp.util.temp_seed(42):
    samples_idx = list(
        test.artificial_sampling_index_generator(sample_size=500, n_prevalences=21, repeats=repeats)
    )

# ---------------------------------------------------------------------------
# Evaluation: one value per sample is appended to each series
# ---------------------------------------------------------------------------
series = defaultdict(list)
for idx in tqdm(samples_idx, desc='generating predictions'):
    sample = test.sampling_from_index(idx)
    instances = sample.instances

    # The upper bound is handed the true labels of the sample before predicting.
    upper.show_true_labels(sample.labels)
    upper_conf_matrix = upper.predict(instances)
    series['Upper'].append(metric(upper_conf_matrix))

    for mname, method in contenders:
        conf_matrix = method.predict(instances)
        series[mname].append(metric(conf_matrix))
        if hasattr(method, 'quantify'):
            # NOTE(review): the whole quantify() result is stored here, whereas the
            # 'binsld-prev' / 'binpacc-prev' series below keep only element [1] --
            # confirm this is intended before plotting these series.
            series[mname+'-prev'].append(method.quantify(instances))

    series['binsld-prev'].append(sld.quantify(instances)[1])
    series['binpacc-prev'].append(pacc.quantify(instances)[1])

    # Element [1] of the sample prevalence feeds both the reference curve and the x axis.
    sample_prev = sample.prevalence()[1]
    series['optimal-prev'].append(sample_prev)
    series['prev'].append(sample_prev)

# ---------------------------------------------------------------------------
# Plot
# ---------------------------------------------------------------------------
metricname = metric.__name__
# NOTE(review): the file name says 'LinearSVC' although clf() currently returns a
# LogisticRegression -- confirm which label is wanted.
plot_path = './plots/'+dataset+'_LinearSVC_'+metricname+'.pdf'
plot_series(
    series,
    repeats,
    metric_name=metricname,
    train_prev=train.prevalence()[1],
    savepath=plot_path,
)