import quapy as qp
import os
import pathlib
import pickle
from glob import glob
import sys

from plot_driftbox import brokenbar_supremacy_by_drift
from uci_experiments import *
from uci_tables import METHODS
from os.path import join


# Module-level configuration for the UCI plotting script.
# SAMPLE_SIZE, N_FOLDS and N_REPEATS are not defined in this file; presumably
# they are provided by the star-import of `uci_experiments` -- TODO confirm.
qp.environ['SAMPLE_SIZE'] = SAMPLE_SIZE
# File extension (without the dot) appended to the plot file names built below.
plotext='png'

# Directory scanned for pickled experiment results, and directory receiving plots.
resultdir = './results_uci'
plotdir = './plots_uci'
# Create the plot directory at import time; an existing directory is not an error.
os.makedirs(plotdir, exist_ok=True)

# Number of run indices iterated by gather_results (run0 .. run{N_RUNS-1}).
N_RUNS = N_FOLDS * N_REPEATS


def gather_results(methods, error_name, resultdir):
    """Load the pickled experiment results of every requested method.

    For each method and each run index in ``range(N_RUNS)``, every file
    matching ``{resultdir}/*-{method}-run{run}-m{error_name}.pkl`` is loaded.
    Each pickle is unpacked as a 5-tuple ``(true_prevalences,
    estim_prevalences, tr_prev, te_prev, best_params)``; only the first three
    items are kept.

    :param methods: iterable of method identifiers as they appear in file names
    :param error_name: error tag used in the ``-m{error_name}.pkl`` suffix
    :param resultdir: directory containing the result files
    :return: four parallel lists ``(method_names, true_prevs, estim_prevs,
        tr_prevs)`` with one entry per loaded file; ``method_names`` holds
        ``nicename(method)`` for the method the file belongs to
    """
    method_names, true_prevs, estim_prevs, tr_prevs = [], [], [], []
    for method in methods:
        for run in range(N_RUNS):
            for experiment in glob(f'{resultdir}/*-{method}-run{run}-m{error_name}.pkl'):
                # NOTE(review): unpickling can execute arbitrary code; resultdir
                # must only contain files written by trusted experiment runs.
                # The context manager closes the handle even if loading fails.
                with open(experiment, 'rb') as result_file:
                    true_prevalences, estim_prevalences, tr_prev, _te_prev, _best_params = pickle.load(result_file)
                method_names.append(nicename(method))
                true_prevs.append(true_prevalences)
                estim_prevs.append(estim_prevalences)
                tr_prevs.append(tr_prev)
    return method_names, true_prevs, estim_prevs, tr_prevs


def plot_error_by_drift(methods, error_name, logscale=False, path=None):
    """Draw quapy's error-by-drift plot for the given methods.

    :param methods: method identifiers forwarded to ``gather_results``
    :param error_name: error tag used both to select result files and as the
        ``error_name`` argument of ``qp.plot.error_by_drift``
    :param logscale: forwarded unchanged to ``qp.plot.error_by_drift``
    :param path: output directory; when given, the plot file name is
        ``error_by_drift_{error_name}.{plotext}`` inside it. When ``None``,
        ``savepath=None`` is handed to quapy.
    """
    print('plotting error by drift')
    savepath = None if path is None else join(path, f'error_by_drift_{error_name}.{plotext}')
    names, test_prevs, predicted_prevs, train_prevs = gather_results(methods, error_name, resultdir)
    plot_options = dict(
        n_bins=20,
        error_name=error_name,
        show_std=True,
        logscale=logscale,
        title='Quantification error as a function of distribution shift',
        savepath=savepath,
    )
    qp.plot.error_by_drift(names, test_prevs, predicted_prevs, train_prevs, **plot_options)


def diagonal_plot(methods, error_name, path=None):
    """Draw quapy's binary diagonal plot for the positive class (``pos_class=1``).

    :param methods: method identifiers forwarded to ``gather_results``
    :param error_name: error tag used to select result files and to name the
        output file
    :param path: output directory; when given, the plot is saved as
        ``diag_{error_name}_pos.{plotext}`` inside it. When ``None``,
        ``savepath=None`` is passed on to quapy instead of a file name.
    """
    print('plotting diagonal plots')
    # A missing path must stay None: formatting it into the file name would
    # yield the literal string 'None_pos.<ext>' and write a stray file.
    savepath = None
    if path is not None:
        savepath = join(path, f'diag_{error_name}_pos.{plotext}')
    method_names, true_prevs, estim_prevs, _tr_prevs = gather_results(methods, error_name, resultdir)
    qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title='Positive', legend=True, show_std=True, savepath=savepath)


def binary_bias_global(methods, error_name, path=None):
    """Draw quapy's global bias plot for the positive class (``pos_class=1``).

    :param methods: method identifiers forwarded to ``gather_results``
    :param error_name: error tag used to select result files and to name the
        output file
    :param path: output directory; when given, the plot is saved as
        ``globalbias_{error_name}_pos.{plotext}`` inside it. When ``None``,
        ``savepath=None`` is passed on to quapy instead of a file name.
    """
    print('plotting bias global')
    # A missing path must stay None: formatting it into the file name would
    # yield the literal string 'None_pos.<ext>' and write a stray file.
    savepath = None
    if path is not None:
        savepath = join(path, f'globalbias_{error_name}_pos.{plotext}')
    method_names, true_prevs, estim_prevs, _tr_prevs = gather_results(methods, error_name, resultdir)
    qp.plot.binary_bias_global(method_names, true_prevs, estim_prevs, pos_class=1, title='Positive', savepath=savepath)


def binary_bias_bins(methods, error_name, path=None):
    """Draw quapy's binned (local) bias plot for the positive class (``pos_class=1``).

    :param methods: method identifiers forwarded to ``gather_results``
    :param error_name: error tag used to select result files and to name the
        output file
    :param path: output directory; when given, the plot is saved as
        ``localbias_{error_name}_pos.{plotext}`` inside it. When ``None``,
        ``savepath=None`` is passed on to quapy instead of a file name.
    """
    print('plotting bias local')
    # A missing path must stay None: formatting it into the file name would
    # yield the literal string 'None_pos.<ext>' and write a stray file.
    savepath = None
    if path is not None:
        savepath = join(path, f'localbias_{error_name}_pos.{plotext}')
    method_names, true_prevs, estim_prevs, _tr_prevs = gather_results(methods, error_name, resultdir)
    qp.plot.binary_bias_bins(method_names, true_prevs, estim_prevs, pos_class=1, title='Positive', legend=True, savepath=savepath)


def brokenbar_supr(methods, error_name, path=None):
    """Draw the broken-bar "supremacy by drift" plot for the given methods.

    The x/y error measures passed to ``brokenbar_supremacy_by_drift`` are fixed
    to ``'ae'``; ``error_name`` only selects the result files and names the
    output.

    :param methods: method identifiers forwarded to ``gather_results``
    :param error_name: error tag used to select result files and in the output name
    :param path: output directory; when given, ``broken_{error_name}`` inside it
        is passed as ``savepath`` (no extension is appended here). When
        ``None``, ``savepath=None`` is passed.
    """
    print('plotting brokenbar_supr')
    savepath = path if path is None else join(path, f'broken_{error_name}')
    names, test_prevs, predicted_prevs, train_prevs = gather_results(methods, error_name, resultdir)
    brokenbar_supremacy_by_drift(
        names,
        test_prevs,
        predicted_prevs,
        train_prevs,
        n_bins=10,
        binning='isometric',
        x_error='ae',
        y_error='ae',
        ttest_alpha=0.005,
        tail_density_threshold=0.005,
        savepath=savepath,
    )


if __name__ == '__main__':
    # Only the broken-bar plot is currently enabled. Uncomment any of the
    # calls below to regenerate the corresponding figure as well.
    # plot_error_by_drift(METHODS, error_name='ae', path=plotdir)
    # diagonal_plot(METHODS, error_name='ae', path=plotdir)
    # binary_bias_global(METHODS, error_name='ae', path=plotdir)
    # binary_bias_bins(METHODS, error_name='ae', path=plotdir)
    brokenbar_supr(methods=METHODS, error_name='ae', path=plotdir)