import os.path
import pathlib

from sklearn.metrics import f1_score
import functions as fn
import quapy as qp
import argparse
from quapy.data import LabelledCollection
from plot import eDiscoveryPlot



def main(args):
    """Run the simulated e-Discovery active-learning loop and log per-iteration metrics.

    A small labelled seed (``train``) is drawn from the dataset and the rest forms
    the unlabelled ``pool``. Each iteration does the following:

    * trains a classifier and estimates the pool prevalence via ``fn.estimate_prev_CC``;
    * after a warm-up of ``skip_first_steps`` iterations, also estimates it with the
      requested quantifier (``fn.estimate_prev_Q``), logs one tab-separated row of
      metrics, and plots the pool posteriors;
    * moves ``args.k`` documents chosen by the sampling function from the pool
      into the training set.

    Rows are written both to ``args.output`` and to stdout.

    The loop stops when fewer than ``args.k`` documents remain in the pool, or
    when ``args.iter`` iterations have run (``-1`` never matches, i.e. no limit).

    :param args: parsed command-line namespace. It must provide ``dataset``, ``k``,
        ``initsize``, ``initprev``, ``sampling``, ``iter``, ``classifier``,
        ``quantifier``, ``seed`` and ``output``.
    """
    datasetname = args.dataset
    k = args.k
    init_nD = args.initsize
    sampling_fn = getattr(fn, args.sampling)
    max_iterations = args.iter
    clf_name = args.classifier
    q_name = args.quantifier

    collection = qp.util.pickled_resource(f'./dataset/{datasetname}.pkl', fn.create_dataset, datasetname)
    nD = len(collection)

    fig = eDiscoveryPlot(args.output)

    # iterations with index below this value are not evaluated, logged, or plotted (warm-up)
    skip_first_steps = 20

    with qp.util.temp_seed(args.seed):
        # initial labelled data selection
        if args.initprev == -1:
            idx = collection.sampling_index(init_nD)
        else:
            # prevalence-controlled sample: (1-initprev) negatives, initprev positives
            idx = collection.sampling_index(init_nD, 1 - args.initprev, args.initprev)
        train, pool = fn.split_from_index(collection, idx)

        i = 0

        with open(args.output, 'wt') as fout:
            def tee(msg):
                # write to the results file (flushed so partial runs are preserved) and echo to stdout
                fout.write(msg + '\n')
                fout.flush()
                print(msg)

            tee('it\t%\ttr-size\tte-size\ttr-prev\tte-prev\tte-estim\tte-estimCC\tR\tRhat\tRhatCC\tShift\tAE\tAE_CC\tMF1_Q\tMF1_Clf\tICost\tremaining')

            while True:

                pool_p_hat_cc, classifier = fn.estimate_prev_CC(train, pool, clf_name)

                nDtr = len(train)
                nDte = len(pool)
                progress = 100 * nDtr / nD

                if i >= skip_first_steps:
                    # only needed for the logged row, so not computed during warm-up
                    ideal_cost = fn.ideal_cost(classifier, pool)
                    pool_p_hat_q, _ = fn.estimate_prev_Q(train, pool, q_name, clf_name)

                    # F1 evaluation is currently disabled; the columns are kept for a stable file format
                    f1_clf = 0
                    f1_q = 0

                    tr_p = train.prevalence()
                    te_p = pool.prevalence()

                    # NOTE: an alternative based on an observation by D. Lewis ("it is convenient to have the
                    # same kind of systematic error both in the numerator and in the denominator") would use
                    # the quantifier's estimate of the training prevalence instead of the true tr_p.
                    r_hat_cc = fn.recall(tr_p, pool_p_hat_cc, nDtr, nDte)
                    r_hat_q = fn.recall(tr_p, pool_p_hat_q, nDtr, nDte)
                    r = fn.recall(tr_p, te_p, nDtr, nDte)
                    tr_te_shift = qp.error.ae(tr_p, te_p)

                    ae_q = qp.error.ae(te_p, pool_p_hat_q)
                    ae_cc = qp.error.ae(te_p, pool_p_hat_cc)

                    tee(f'{i}\t{progress:.2f}\t{nDtr}\t{nDte}\t{tr_p[1]:.3f}\t{te_p[1]:.3f}\t{pool_p_hat_q[1]:.3f}\t{pool_p_hat_cc[1]:.3f}'
                        f'\t{r:.3f}\t{r_hat_q:.3f}\t{r_hat_cc:.3f}\t{tr_te_shift:.5f}\t{ae_q:.4f}\t{ae_cc:.4f}\t{f1_q:.3f}\t{f1_clf:.3f}'
                        f'\t{ideal_cost}\t{pool.labels.sum()}')

                    posteriors = classifier.predict_proba(pool.instances)
                    fig.plot(posteriors, pool.labels)

                # Stopping criteria are checked on EVERY iteration (not only after warm-up); otherwise
                # --iter values <= skip_first_steps would be ignored and the pool could be exhausted.
                if nDte < k:
                    print('[stop] too few documents remaining')
                    break
                elif i + 1 == max_iterations:
                    print('[stop] maximum number of iterations reached')
                    break

                top_relevant_idx = sampling_fn(pool, classifier, k, progress)
                train, pool = fn.move_documents(train, pool, top_relevant_idx)

                i += 1


if __name__ == '__main__':
    # command-line entry point: parse and validate options, prepare the output folder, run the experiment
    parser = argparse.ArgumentParser(description='e-Discovery')
    parser.add_argument('--dataset', metavar='DATASET', type=str, help='Dataset name',
                        default='RCV1.C4')
    parser.add_argument('--quantifier', metavar='METHOD', type=str, help='Quantification method',
                        default='EMQ')
    parser.add_argument('--sampling', metavar='SAMPLING', type=str, help='Sampling criterion',
                        default='relevance_sampling')
    parser.add_argument('--iter', metavar='INT', type=int, help='number of iterations (-1 to set no limit)',
                        default=-1)
    parser.add_argument('--k', metavar='BATCH', type=int, help='number of documents in a batch',
                        default=100)
    parser.add_argument('--initsize', metavar='SIZE', type=int, help='number of labelled documents at the beginning',
                        default=2)
    parser.add_argument('--initprev', metavar='PREV', type=float,
                        help='prevalence of the initial sample (-1 for uniform sampling)',
                        default=0.5)
    parser.add_argument('--seed', metavar='SEED', type=int,
                        help='random seed',
                        default=1)
    parser.add_argument('--classifier', metavar='CLS', type=str,
                        help='classifier type (svm, lr)',
                        default='lr')
    parser.add_argument('--output', metavar='OUT', type=str,
                        help="name of the file containing the results of the experiment (default is an automatic "
                             "filename based on the model's parameters in the folder './results/')",
                        default=None)
    args = parser.parse_args()

    # validate via parser.error (exits with a usage message); an assert would be stripped under `python -O`
    if not (args.initprev == -1 or 0 < args.initprev < 1):
        parser.error('wrong value for initprev; should be -1 or in (0., 1.)')
    if args.initprev == -1:  # this is to clean the path, to show initprev:-1 and not initprev:-1.0
        args.initprev = int(args.initprev)
    if args.output is None:
        args.output = os.path.join('./results', fn.experiment_name(args))

    # create the full parent directory of the output file; dirname is '' for a bare filename
    # (Path(...).parent.name would keep only the last path component, e.g. 'sub' for 'results/sub/out.txt')
    outputdir = os.path.dirname(args.output)
    if outputdir:
        qp.util.create_if_not_exist(outputdir)

    for name, value in vars(args).items():
        print(f'{name}={value}')

    main(args)