from sklearn.svm import LinearSVC

from class_weight_model import ClassWeightPCC
# from classification.methods import LowRankLogisticRegression
# from method.experimental import ExpMax, VarExpMax
from common import *
from method.meta import QuaNet
from quantification_stumps_model import QuantificationStumpRegressor
from quapy.method.aggregative import CC, ACC, PCC, PACC, MAX, MS, MS2, EMQ, SVMAE, HDy
from quapy.method.meta import EHDy
import numpy as np
import os
import pickle
import itertools
import argparse
import torch
import shutil


# Size of every sample drawn by the artificial-prevalence protocol, both during
# model selection and during evaluation in run(); it is also passed to
# QuantificationStumpRegressor and registered globally in qp.environ below.
SAMPLE_SIZE = 100

# Cross-validation setup for run(): N_FOLDS folds, repeated N_REPEATS times.
N_FOLDS = 5
N_REPEATS = 1

# Parallelism settings (n_jobs-style values; -1 presumably means "all cores" — TODO confirm
# against qp.util.parallel).
# N_JOBS: workers for the main experiment loop in __main__.
# CUDA_N_JOBS: workers for the QuaNet experiments (currently disabled code only).
# ENSEMBLE_N_JOBS: n_jobs given to each ensemble (currently disabled code only).
N_JOBS = -1
CUDA_N_JOBS = 2
ENSEMBLE_N_JOBS = -1

# Presumably the default sample size for quapy routines that read it from qp.environ
# when none is passed explicitly — TODO confirm.
qp.environ['SAMPLE_SIZE'] = SAMPLE_SIZE

# Regularization strengths explored by model selection: 7 log-spaced values, 1e-3 .. 1e3.
__C_range = np.logspace(-3, 3, 7)
# Grid for the quantifiers built on newLR()/calibratedLR() (and ClassWeightPCC) in quantification_models().
lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']}
# Grid for the SVMperf-based quantifier (SVMAE); referenced only by disabled code at the moment.
svmperf_params = {'C': __C_range}


def quantification_models():
    """Generate the quantifiers to be compared in the experiments.

    Yields ``(name, quantifier, param_grid)`` triples.  ``param_grid`` is the
    hyperparameter grid explored by model selection in ``run``; a grid of
    ``None`` means the quantifier is fitted directly, without model selection.
    Every quantifier is instantiated lazily, only when its triple is requested.
    Commented-out rows are disabled experiments kept for reference.
    """
    # (name, zero-argument factory, hyperparameter grid)
    catalogue = (
        # ('cc', lambda: CC(newLR()), lr_params),
        # ('acc', lambda: ACC(newLR()), lr_params),
        ('pcc.opt', lambda: PCC(newLR()), lr_params),
        ('pacc.opt', lambda: PACC(newLR()), lr_params),
        ('wpacc.opt', ClassWeightPCC, lr_params),
        ('ds.opt', lambda: QuantificationStumpRegressor(SAMPLE_SIZE), {'C': __C_range}),
        # ('pcc.opt.svm', lambda: PCC(LinearSVC()), lr_params),
        # ('pacc.opt.svm', lambda: PACC(LinearSVC()), lr_params),
        # ('wpacc.opt.svm', lambda: ClassWeightPCC(LinearSVC), lr_params),
        # ('wpacc.opt2', lambda: ClassWeightPCC(C=__C_range), lr_params),  # cannot work in its current version (see notes in class_weight_model.py)
        # ('MAX', lambda: MAX(newLR()), lr_params),
        # ('MS', lambda: MS(newLR()), lr_params),
        # ('MS2', lambda: MS2(newLR()), lr_params),
        ('sldc', lambda: EMQ(calibratedLR()), lr_params),
        # ('svmmae', lambda: SVMAE(), svmperf_params),
        # ('hdy', lambda: HDy(newLR()), lr_params),
        # ('EMdiag', lambda: ExpMax(cov_type='diag'), None),
        # ('EMfull', lambda: ExpMax(cov_type='full'), None),
        # ('EMtied', lambda: ExpMax(cov_type='tied'), None),
        # ('EMspherical', lambda: ExpMax(cov_type='spherical'), None),
        # ('VEMdiag', lambda: VarExpMax(cov_type='diag'), None),
        # ('VEMfull', lambda: VarExpMax(cov_type='full'), None),
        # ('VEMtied', lambda: VarExpMax(cov_type='tied'), None),
        # ('VEMspherical', lambda: VarExpMax(cov_type='spherical'), None),
    )
    for label, build, grid in catalogue:
        yield label, build(), grid


# def quantification_cuda_models():
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     print(f'Running QuaNet in {device}')
#     learner = LowRankLogisticRegression(**newLR().get_params())
#     yield 'quanet', QuaNet(learner, SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params


# def quantification_ensembles():
#     param_mod_sel = {
#         'sample_size': SAMPLE_SIZE,
#         'n_prevpoints': 21,
#         'n_repetitions': 5,
#         'refit': True,
#         'verbose': False
#     }
#     common = {
#         'size': 30,
#         'red_size': 15,
#         'max_sample_size': None,  # same as training set
#         'n_jobs': ENSEMBLE_N_JOBS,
#         'param_grid': lr_params,
#         'param_mod_sel': param_mod_sel,
#         'val_split': 0.4,
#         'min_pos': 5
#     }
#
#     # hyperparameters are tuned within each quantifier of the ensemble, so the usual model selection
#     # is skipped (by passing None as the hyperparameter grid)
#     hyper_none = None
#     yield 'ehdymaeds', EHDy(newLR(), optim='mae', policy='ds', **common), hyper_none


def run(experiment):
    """Evaluate one (loss, dataset, model) configuration on every cross-validation split.

    For each training/test split produced by ``qp.data.Dataset.kFCV`` (``N_FOLDS`` folds,
    ``N_REPEATS`` repetitions) of the UCI dataset ``dataset_name``:

    1. skip the split if ``is_already_computed`` reports a stored result for it;
    2. if ``hyperparams`` is not None, select hyperparameters with ``GridSearchQ`` using
       ``optim_loss`` as the error (refitting the best configuration); otherwise fit the
       model directly with no model selection;
    3. evaluate the quantifier on the test split with
       ``qp.evaluation.artificial_prevalence_prediction`` and persist the outcome via
       ``save_results``.

    Results are read from / written to ``args.results``, where ``args`` is the module-level
    namespace parsed in the ``__main__`` section.

    :param experiment: ``(optim_loss, dataset_name, (model_name, model, hyperparams))``, i.e. one
        element of ``itertools.product(optim_losses, datasets, models)``. ``model`` is treated as
        an unfitted template and is never modified: each split works on its own deep copy.
    :return: None
    """
    import copy

    optim_loss, dataset_name, (model_name, model_template, hyperparams) = experiment
    # NOTE(review): these datasets are excluded without a stated reason — document why.
    if dataset_name in ('acute.a', 'acute.b', 'iris.1'):
        return

    collection = qp.datasets.fetch_UCILabelledCollection(dataset_name)
    splits = qp.data.Dataset.kFCV(collection, nfolds=N_FOLDS, nrepeats=N_REPEATS)
    # `fold_idx` (instead of `run`) avoids shadowing this function's own name.
    for fold_idx, data in enumerate(splits):
        if is_already_computed(args.results, dataset_name, model_name, run=fold_idx, optim_loss=optim_loss):
            print(f'result for dataset={dataset_name} model={model_name} loss={optim_loss} already computed.')
            continue

        print(f'running dataset={dataset_name} model={model_name} loss={optim_loss}')
        # Start every split from a fresh, unfitted copy of the template, so that neither the
        # fitted state nor the hyperparameters selected on one split carry over to the next.
        model = copy.deepcopy(model_template)

        # model selection (hyperparameter optimization for a quantification-oriented loss)
        if hyperparams is not None:
            model_selection = qp.model_selection.GridSearchQ(
                model,
                param_grid=hyperparams,
                sample_size=SAMPLE_SIZE,
                n_prevpoints=21,
                n_repetitions=25,
                error=optim_loss,
                refit=True,
                timeout=60 * 60,
                verbose=True
            )
            model_selection.fit(data.training)
            model = model_selection.best_model()
            best_params = model_selection.best_params_
        else:
            model.fit(data.training)
            best_params = {}

        # model evaluation with quapy's artificial-prevalence protocol on the test split;
        # only ensembles get n_jobs=-1 (presumably because run() itself already executes in
        # parallel across experiments, see the __main__ section).
        true_prevalences, estim_prevalences = qp.evaluation.artificial_prevalence_prediction(
            model,
            test=data.test,
            sample_size=SAMPLE_SIZE,
            n_prevpoints=21,
            n_repetitions=100,
            n_jobs=-1 if isinstance(model, qp.method.meta.Ensemble) else 1
        )
        test_true_prevalence = data.test.prevalence()

        evaluate_experiment(true_prevalences, estim_prevalences)
        save_results(args.results, dataset_name, model_name, fold_idx, optim_loss,
                     true_prevalences, estim_prevalences,
                     data.training.prevalence(), test_true_prevalence,
                     best_params)


if __name__ == '__main__':
    # Command line: output folder, SVMperf location, and QuaNet checkpoint folder.
    cli = argparse.ArgumentParser(description='Run experiments for UCI ML Quantification')
    cli.add_argument('results', metavar='RESULT_PATH', type=str,
                     help='path to the directory where to store the results')
    cli.add_argument('--svmperfpath', metavar='SVMPERF_PATH', type=str, default='./svm_perf_quantification',
                     help='path to the directory with svmperf')
    cli.add_argument('--checkpointdir', metavar='PATH', type=str, default='./checkpoint',
                     help='path to the directory where to dump QuaNet checkpoints')
    # `args` must remain a module-level name: run() reads args.results from the global scope.
    args = cli.parse_args()

    print(f'Result folder: {args.results}')
    np.random.seed(0)

    qp.environ['SVMPERF_HOME'] = args.svmperfpath

    losses = ['mae']
    # every (loss, dataset, model) combination is one independent experiment
    experiments = itertools.product(losses, qp.datasets.UCI_DATASETS, quantification_models())
    # sequential alternative, convenient for debugging:
    # for experiment in experiments:
    #     run(experiment)
    qp.util.parallel(run, experiments, n_jobs=N_JOBS)

    # QuaNet experiments (disabled):
    # qp.util.parallel(run, itertools.product(losses, qp.datasets.UCI_DATASETS, quantification_cuda_models()), n_jobs=CUDA_N_JOBS)

    # ensemble experiments (disabled):
    # qp.util.parallel(run, itertools.product(losses, qp.datasets.UCI_DATASETS, quantification_ensembles()), n_jobs=1)

    # remove the QuaNet checkpoint folder, if any
    shutil.rmtree(args.checkpointdir, ignore_errors=True)