import argparse
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
import itertools
from sklearn.multioutput import ClassifierChain
from tqdm import tqdm
from skmultilearn.dataset import load_dataset, available_data_sets
from scipy.sparse import csr_matrix
import quapy as qp
from MultiLabel.main import load_results, SKMULTILEARN_RED_DATASETS, TC_DATASETS, sample_size
from MultiLabel.mlclassification import MLStackedClassifier
from MultiLabel.mldata import MultilabelledCollection
from MultiLabel.mlquantification import MLNaiveQuantifier, MLCC, MLPCC, MLRegressionQuantification, \
    MLACC, \
    MLPACC, MLNaiveAggregativeQuantifier
from MultiLabel.tabular import Table
from method.aggregative import PACC, CC, EMQ, PCC, ACC, HDy
import numpy as np
from data.dataset  import Dataset
from mlevaluation import ml_natural_prevalence_prediction, ml_artificial_prevalence_prediction, check_error_str
import sys
import os
import pickle

# Names of the methods reported in the tables, in display order. Each name is
# also the `<model>` part of the result file `<dataset>_<model>.pkl` that
# generate_table looks up. Entries are grouped by name prefix; commented-out
# names are currently excluded from the tables.
models = [
    # 'MLPE',
    # Naive*
    'NaiveCC', 'NaivePCC', 'NaivePCCcal', 'NaiveACC',
    'NaivePACC', 'NaivePACCcal', 'NaiveACCit', 'NaivePACCit',
    # 'NaiveHDy', 'NaiveSLD',
    # Chain*
    'ChainCC', 'ChainPCC', 'ChainACC', 'ChainPACC',
    # Stack*
    'StackCC', 'StackPCC', 'StackPCCcal', 'StackACC',
    'StackPACC', 'StackPACCcal', 'StackACCit', 'StackPACCit',
    # MRQ-*
    'MRQ-CC', 'MRQ-PCC', 'MRQ-ACC', 'MRQ-PACC', 'MRQ-ACCit', 'MRQ-PACCit',
    # StackMRQ-*
    'StackMRQ-CC', 'StackMRQ-PCC', 'StackMRQ-ACC', 'StackMRQ-PACC',
    # MRQ-Stack*
    'MRQ-StackCC', 'MRQ-StackPCC', 'MRQ-StackACC', 'MRQ-StackPACC',
    # StackMRQ-Stack*
    'StackMRQ-StackCC', 'StackMRQ-StackPCC', 'StackMRQ-StackACC', 'StackMRQ-StackPACC',
    # MRQ-Stack*-app
    'MRQ-StackCC-app', 'MRQ-StackPCC-app', 'MRQ-StackACC-app', 'MRQ-StackPACC-app',
    # StackMRQ-Stack*-app
    'StackMRQ-StackCC-app', 'StackMRQ-StackPCC-app',
    'StackMRQ-StackACC-app', 'StackMRQ-StackPACC-app',
    # LSP-* / MLKNN-*
    'LSP-CC', 'LSP-ACC', 'MLKNN-CC', 'MLKNN-ACC',
    # ML*AdjustedC
    'MLAdjustedC', 'MLStackAdjustedC', 'MLprobAdjustedC', 'MLStackProbAdjustedC',
]

# Datasets reported in the tables. Each name is also the `<dataset>` part of the
# result file `<dataset>_<model>.pkl` that generate_table looks up.
# Disabled alternative: every dataset name listed by skmultilearn's
# available_data_sets():
# datasets = sorted(set([x[0] for x in available_data_sets().keys()]))
# Active selection: the TC_DATASETS collection imported from MultiLabel.main.
datasets = TC_DATASETS




def generate_table(path, protocol, error, results_dir=None):
    """Build and save the score table (and its rank table) for one protocol/error pair.

    For every (dataset, model) combination of the module-level ``datasets`` and
    ``models`` lists, the result file ``<results_dir>/<dataset>_<model>.pkl`` is
    loaded when it exists and one score is computed per pair of true/estimated
    prevalences stored under ``protocol``. Combinations without a result file
    are skipped. A progress mark is printed for each combination: ``+`` when
    the file was found, ``-`` when it was missing.

    Args:
        path: destination file of the main table. The rank table is written
            next to it, with ``.rank`` inserted before the file extension
            (e.g. ``npp.ae.tex`` -> ``npp.ae.rank.tex``).
        protocol: key used to index the loaded results; the indexed entry is
            unpacked as ``(true_prevs, estim_prevs)``.
        error: callable invoked as ``error(true_prev, estim_prev)`` for each
            pair; the collected outputs are flattened into a 1-D array.
        results_dir: directory containing the result files. When ``None``
            (the default) the ``--results`` command-line option stored in the
            module-level ``opt`` is used, which only exists when this file is
            run as a script.
    """
    if results_dir is None:
        # Backward-compatible default: the directory parsed in the __main__ block.
        results_dir = opt.results

    def compute_score_job(args):
        """Return (dataset, model, scores) for one combination, or None if no result file exists."""
        dataset, model = args
        result_path = f'{results_dir}/{dataset}_{model}.pkl'
        if not os.path.exists(result_path):
            print('-', end='')
            sys.stdout.flush()
            return None
        print('+', end='')
        sys.stdout.flush()
        result = load_results(result_path)
        true_prevs, estim_prevs = result[protocol]
        scores = np.asarray([error(trues, estims) for trues, estims in zip(true_prevs, estim_prevs)]).flatten()
        return dataset, model, scores

    print(f'\ngenerating {path}')
    table = Table(datasets, models, prec_mean=4, significance_test='wilcoxon')
    results = qp.util.parallel(compute_score_job, list(itertools.product(datasets, models)), n_jobs=-1)
    print()  # terminate the line of +/- progress marks

    for r in results:
        if r is not None:
            dataset, model, scores = r
            table.add(dataset, model, scores)

    save_table(table, path)
    # Insert ".rank" before the extension only. A plain str.replace('.tex', ...)
    # would also rewrite '.tex' inside directory names and, for a path without
    # '.tex', would leave the name unchanged so the rank table would overwrite
    # the main table.
    root, ext = os.path.splitext(path)
    save_table(table.getRankTable(), f'{root}.rank{ext}')



def save_table(table, path):
    """Write ``table`` to ``path`` as a LaTeX tabular wrapped in a resizebox.

    The column specification has one leading ``c`` column plus one ``c`` column
    per entry of the module-level ``models`` list. The body is produced by
    ``table.latexTabularT(...)`` with ``side=True``; underscores in the
    ``tmc2007_500`` dataset names are escaped for LaTeX through the
    ``benchmark_replace`` mapping.

    Args:
        table: object exposing ``latexTabularT(benchmark_replace=...,
            method_replace=..., side=...)`` that returns the tabular body as a string.
        path: file to (over)write with the generated LaTeX code.
    """
    # NOTE(review): the column count is derived from `models`; confirm it matches
    # the number of columns that latexTabularT actually emits.
    # Backslashes are escaped explicitly: sequences such as "\h", "\_" and "\e"
    # are invalid string escapes (SyntaxWarning on Python 3.12+).
    tabular = """
    \\resizebox{\\textwidth}{!}{%
            \\begin{tabular}{|c||""" + ('c|' * len(models)) + """} \\hline
            """
    dataset_replace = {'tmc2007_500': 'tmc2007\\_500', 'tmc2007_500-red': 'tmc2007\\_500-red'}
    method_replace = {}

    tabular += table.latexTabularT(benchmark_replace=dataset_replace, method_replace=method_replace, side=True)
    tabular += """
        \\end{tabular}%
        }
    """
    with open(path, 'wt') as foo:
        foo.write(tabular)

if __name__ == '__main__':
    # Script entry point: read the experiment results found in --results and
    # write one LaTeX table (plus its rank table) per protocol/error pair
    # into --tablepath.
    parser = argparse.ArgumentParser(description='Experiments for multi-label quantification')
    parser.add_argument('--results', type=str, default='./results', metavar='str',
                        help='path where the results are stored')
    parser.add_argument('--tablepath', type=str, default='./tables', metavar='str',
                        help='path where to store the tables')
    opt = parser.parse_args()

    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    if not os.path.isdir(opt.results):
        parser.error(f'result directory {opt.results} does not exist')
    os.makedirs(opt.tablepath, exist_ok=True)

    qp.environ["SAMPLE_SIZE"] = sample_size

    # Same tables, in the same order, as before:
    # npp.ae, app.ae, npp.rae, app.rae.
    for error_name, error_fn in (('ae', qp.error.ae), ('rae', qp.error.rae)):
        for protocol in ('npp', 'app'):
            generate_table(f'{opt.tablepath}/{protocol}.{error_name}.tex', protocol=protocol, error=error_fn)