import numpy as np
import matplotlib.pyplot as plt
import sklearn.preprocessing
from matplotlib import cm
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import normalize

import quapy as qp
import quapy.functional as F
from quapy.data import LabelledCollection
from quapy.method.aggregative import CC, ACC, PCC, PACC, EMQ
import os
from scipy.stats import ttest_rel


# Plotting window: both axes span [0, x_max].
x_min, x_max = 0, 11
y_min, y_max = 0, x_max

# Two Gaussian blobs centred on the diagonal, at 2/5 and 3/5 of the window.
center0 = (2 * x_max / 5,) * 2
center1 = (3 * x_max / 5,) * 2

# 100000 points per blob, 2 features; Y holds the blob index (0 or 1) as class label.
X, Y = make_blobs(
    n_samples=[100000] * 2,
    n_features=2,
    centers=[center0, center1],
)

# Wrap the points as a quapy collection and split it (stratified, train_prop=0.5)
# into the pools from which training and test samples are drawn later on.
data = LabelledCollection(X, Y)
train_pool, test_pool = data.split_stratified(train_prop=0.5)



def plot(fignum, title, savepath=None, *, quantifier=None, test_sample=None,
         reference=None):
    """Scatter a test sample and draw the weighted and the reference decision boundaries.

    The solid line (legend 'modified') is the separating hyperplane of the classifier
    held by ``quantifier``; the dashed line (legend 'original') is the one of
    ``reference``.

    :param fignum: figure counter; the plot is drawn on ``plt.figure(fignum + 1)``
    :param title: plot title
    :param savepath: optional path; when given the figure is saved there and then
        closed, so long experiment loops do not accumulate open figures
    :param quantifier: fitted quantifier whose ``.learner`` is a linear sklearn
        classifier (``coef_`` / ``intercept_``); defaults to the module-level ``q``
    :param test_sample: LabelledCollection to scatter; defaults to the module-level ``test``
    :param reference: linear sklearn classifier drawn as the dashed line; defaults
        to the module-level ``reference_hyperplane``
    """
    # Fall back to the module-level objects set by the experiment loop (original behaviour).
    if quantifier is None:
        quantifier = q
    if test_sample is None:
        test_sample = test
    if reference is None:
        reference = reference_hyperplane
    clf = quantifier.learner

    xx = np.linspace(0, x_max)

    def boundary(linear_clf):
        # Solve w0*x + w1*y + b = 0 for y (undefined if w1 == 0, i.e. a vertical boundary).
        w = linear_clf.coef_[0]
        return -w[0] / w[1] * xx - linear_clf.intercept_[0] / w[1]

    fig = plt.figure(fignum + 1, figsize=(10, 10))
    plt.clf()

    Xte, yte = test_sample.Xy
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
    # pyplot.get_cmap works on both old and new releases.
    cmap = plt.get_cmap("RdBu")
    for cls, color, label in ((0, cmap(0), '-'), (1, cmap(cmap.N - 1), '+')):
        points = Xte[yte == cls]
        plt.scatter(points[:, 0], points[:, 1], color=color, zorder=10, alpha=0.4,
                    label=label)

    plt.axis("tight")

    plt.plot(xx, boundary(clf), 'k-', label='modified')
    plt.plot(xx, boundary(reference), 'k--', label='original')

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xticks(())
    plt.yticks(())

    plt.title(title)
    plt.legend()

    if savepath:
        plt.savefig(savepath)
        # release the figure: this function is meant to be called once per iteration
        plt.close(fig)


def mock_y(prev, n=10000):
    """Build a synthetic binary label vector matching the prevalence ``prev``.

    :param prev: pair ``(p_neg, p_pos)`` of class prevalences
    :param n: nominal number of labels (default 10000); each class count is
        truncated with ``int``, so the result may be slightly shorter than ``n``
    :return: int array with ``int(n*prev[0])`` zeros followed by ``int(n*prev[1])`` ones
    """
    nneg = int(n * prev[0])
    npos = int(n * prev[1])
    return np.asarray([0] * nneg + [1] * npos, dtype=int)


def get_class_weight(prevalence, train_prevalence=None):
    """Compute class weights that shift the training prior towards ``prevalence``.

    Each class is weighted by the ratio between the target (guessed) prevalence and
    its prevalence in the training set. The weights are then rescaled so that the
    smallest one equals 1. If the smallest ratio is not positive (a class guessed to
    be absent), a floor of 1e-3 is used as the divisor instead.

    :param prevalence: array-like of target class prevalences, indexed by class
    :param train_prevalence: array-like of training prevalences; defaults to the
        prevalence of the module-level ``train`` collection
    :return: dict ``{0: w0, 1: w1}`` usable as ``LogisticRegression(class_weight=...)``
    """
    if train_prevalence is None:
        train_prevalence = train.prevalence()
    weights = np.asarray(prevalence, dtype=float) / np.asarray(train_prevalence, dtype=float)
    normfactor = weights.min()
    if normfactor <= 0:
        # avoid dividing by zero; a zero-weight class keeps weight 0
        normfactor = 1E-3
    weights = weights / normfactor
    return {0: weights[0], 1: weights[1]}


def train_eval(class_weight, test):
    """Fit a class-weighted ``Method`` quantifier on the global ``train`` and score it on ``test``.

    :param class_weight: class weights forwarded to the underlying LogisticRegression
    :param test: LabelledCollection whose class prevalence is estimated
    :return: tuple ``(quantifier, estimated_prevalence, absolute_error)``
    """
    weighted_learner = LogisticRegression(class_weight=class_weight)
    quantifier = Method(weighted_learner)
    quantifier.fit(train)

    estimated = quantifier.quantify(test.instances)
    error = qp.error.ae(test.prevalence(), estimated)
    return quantifier, estimated, error


# Use the probabilistic variants (PCC / PACC) instead of the crisp ones (CC / ACC).
probabilistic = True

# Quantifier producing the very first prevalence guess for each test sample.
Prompter = PACC

# Baseline to beat, and the method that gets retrained with class weights.
if probabilistic:
    Baseline, Method = PACC, PCC
else:
    Baseline, Method = ACC, CC
bname, mname = Baseline.__name__, Method.__name__

plotdir = f'./plots/{mname}_vs_{bname}'
os.makedirs(plotdir, exist_ok=True)

# Grid of test prevalences (including the extremes) and training prevalences.
test_prevs = np.linspace(0, 1, 20)
train_prevs = np.linspace(0.05, 0.95, 20)

# Counters and per-run error accumulators filled by the experiment loop.
fignum = 0
wins = total = 0
merrors, berrors = [], []

# Main experiment: for every (train prevalence, test prevalence) pair, compare the
# Baseline quantifier against Method retrained with class weights that are iteratively
# re-estimated from its own prevalence guesses.
for ptr in train_prevs:
    # 10000-instance training sample drawn with prevalence ptr (quapy's sampling(size, *prevs);
    # presumably ptr is the prevalence of the first class -- TODO confirm)
    train = train_pool.sampling(10000, ptr)

    # unweighted classifier; only read by plot() to draw the 'original' boundary
    reference_hyperplane = LogisticRegression().fit(*train.Xy)
    baseline = Baseline(LogisticRegression()).fit(train)
    # reuse the fitted baseline as prompter when both are the same quantifier class
    if Baseline != Prompter:
        prompter = Prompter(LogisticRegression()).fit(train)
    else:
        prompter = baseline

    for pte in test_prevs:
        test = test_pool.sampling(10000, pte)

        # some baseline results
        # (prev_estim_acc holds the Baseline's estimate, which is PACC when probabilistic=True)
        prev_estim_acc = baseline.quantify(test.instances)
        ae_baseline = qp.error.ae(test.prevalence(), prev_estim_acc)
        berrors.append(ae_baseline)

        # initial guess of the test prevalence, used to derive the first class weights
        # guessed_prevalence = train.prevalence()
        guessed_prevalence = prompter.quantify(test.instances)

        # at most niter rounds of: weights from guess -> retrain Method -> new estimate
        niter=10
        last_prev = None
        for i in range(niter):
            class_weight = get_class_weight(guessed_prevalence)

            # q is left as a module-level name so that plot() can read it
            q, prev_estim, ae = train_eval(class_weight, test)

            # stop on the last round, or once the estimate moves by less than 0.001 (AE)
            # with respect to the previous round
            stop = (i == niter-1) or (last_prev is not None and qp.error.ae(prev_estim, last_prev) < 0.001)
            if stop:
                # exactly one merrors entry per test sample, paired with the berrors entry above
                merrors.append(ae)
                win = ae < ae_baseline
                if win: wins+=1

                print(f'{i}: tr_prev={F.strprev(train.prevalence())} te_prev={F.strprev(test.prevalence())}, {mname}+ estim_prev={F.strprev(prev_estim)} AE={ae:.5f} '
                      f'using class_weight [{class_weight[0]:.3f}, {class_weight[1]:.3f}] '
                      f'({bname} prev={F.strprev(prev_estim_acc)} AE={ae_baseline:.5f}) '
                      f'{"WIN" if win else "LOSE"}')
                break
            else:
                last_prev = prev_estim


            # Disabled per-iteration plotting (kept for reference):
            # title='$\hat{{p}}^{{{}}}={:.3f}$, $p={:.3f}$, $\hat{{p}}={:.3f}$, AE$_{{{}}}={:.3f}$, AE$_{{{}}}={:.3f}$'.format(
            #     i, guessed_prevalence[0], pte, prev_estim[0], mname, ae, bname, ae_baseline
            # )
            # savepath=os.path.join(plotdir, f'tr_{ptr}_te{pte}_{i}.png')
            # plot(fignum, title, savepath)

            fignum+=1

            # feed the new estimate back as the guess for the next round's class weights
            guessed_prevalence = prev_estim
        total += 1


# Summarise the experiment: win rate of the reweighted method over the baseline,
# mean absolute errors, and a paired t-test between the two error series
# (entry k of merrors and berrors refers to the same train/test pair).
merrors = np.asarray(merrors)
berrors = np.asarray(berrors)
mean_merrors = merrors.mean()
mean_berrors = berrors.mean()

print(f'WINS={wins}/{total}={100*wins/total:.2f}%')

# related-samples t-test, since both errors were measured on the same test samples
_,p_val = ttest_rel(merrors,berrors)
print(f'{mname}-ave={mean_merrors:.5f} {bname}-ave={mean_berrors:.5f}')
# '.4g' (a precision, not a field width) keeps tiny p-values readable in scientific
# notation instead of collapsing them to 0.000000
print(f'ttest p-value={p_val:.4g} significant={p_val<0.05}')