from sklearn.calibration import CalibratedClassifierCV

import quapy as qp
from sklearn.linear_model import LogisticRegression

from class_weight_model import ClassWeightPCC
# from classification.methods import LowRankLogisticRegression
# from method.experimental import ExpMax, VarExpMax
from common import *
from method.meta import QuaNet
from quantification_stumps_model import QuantificationStumpRegressor
from quapy.method.aggregative import CC, ACC, PCC, PACC, MAX, MS, MS2, EMQ, SVMAE, HDy
from quapy.method.meta import EHDy
import numpy as np
import os
import pickle
import itertools
import argparse
import torch
import shutil


# Number of instances per sample, used both for model selection (GridSearchQ) and for the
# artificial-prevalence evaluation protocol in `run`.
SAMPLE_SIZE = 500

# Parallelism knobs:
#   N_JOBS          -> workers for the main experiment pool (qp.util.parallel)
#   CUDA_N_JOBS     -> workers for the (currently disabled) device-based model pool
#   ENSEMBLE_N_JOBS -> forwarded as `n_jobs` to the ensemble quantifiers
N_JOBS = -1
CUDA_N_JOBS = 2
ENSEMBLE_N_JOBS = -1

# Make the sample size known to QuaPy's global environment as well.
qp.environ['SAMPLE_SIZE'] = SAMPLE_SIZE

# Regularisation strengths shared by the grids below: 1e-3, 1e-2, ..., 1e3 (7 values).
__C_range = np.logspace(-3, 3, 7)
# Grid for logistic-regression-based quantifiers.
lr_params = dict(C=__C_range, class_weight=[None, 'balanced'])
# Grid for SVMperf-based quantifiers.
svmperf_params = dict(C=__C_range)


def quantification_models():
    """Yield the ``(name, quantifier, param_grid)`` triples to be evaluated.

    A ``param_grid`` of ``None`` means the quantifier is fit with its current parameters and
    no model selection is performed. Only the decision-stump regressor is enabled at the
    moment; the commented-out entries below are alternative configurations kept for reference.
    """
    stump_regressor = QuantificationStumpRegressor(SAMPLE_SIZE, 21, 10)
    yield 'ds', stump_regressor, None

    # Disabled alternatives (uncomment to include them in the experiments):
    # yield 'cc', CC(newLR()), lr_params
    # yield 'acc', ACC(newLR()), lr_params
    # yield 'pcc', PCC(newLR()), None
    # yield 'pacc', PACC(newLR()), None
    # yield 'wpacc', ClassWeightPCC(), None
    # yield 'pcc.opt', PCC(newLR()), lr_params
    # yield 'pacc.opt', PACC(newLR()), lr_params
    # yield 'wpacc.opt', ClassWeightPCC(), lr_params
    # yield 'ds.opt', QuantificationStumpRegressor(SAMPLE_SIZE), {'C': __C_range}
    # yield 'MAX', MAX(newLR()), lr_params
    # yield 'MS', MS(newLR()), lr_params
    # yield 'MS2', MS2(newLR()), lr_params
    # yield 'sldc', EMQ(calibratedLR()), lr_params
    # yield 'svmmae', SVMAE(), svmperf_params
    # yield 'hdy', HDy(newLR()), lr_params
    # yield 'EMdiag', ExpMax(cov_type='diag'), None
    # yield 'EMfull', ExpMax(cov_type='full'), None
    # yield 'EMtied', ExpMax(cov_type='tied'), None
    # yield 'EMspherical', ExpMax(cov_type='spherical'), None
    # yield 'VEMdiag', VarExpMax(cov_type='diag'), None
    # yield 'VEMfull', VarExpMax(cov_type='full'), None
    # yield 'VEMtied', VarExpMax(cov_type='tied'), None
    # yield 'VEMspherical', VarExpMax(cov_type='spherical'), None


# def quantification_cuda_models():
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     print(f'Running QuaNet in {device}')
#     learner = LowRankLogisticRegression(**newLR().get_params())
#     yield 'quanet', QuaNet(learner, SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params


def quantification_ensembles():
    """Yield ``(name, ensemble_quantifier, param_grid)`` triples for ensemble methods.

    Hyperparameters are evaluated within each quantifier of the ensemble (via ``param_grid``
    and ``param_mod_sel``). The grid handed back to the caller is therefore ``None``, which
    makes ``run`` skip its own ``GridSearchQ`` step.
    """
    # Settings for the model selection performed inside each ensemble member.
    inner_selection = dict(
        sample_size=SAMPLE_SIZE,
        n_prevpoints=21,
        n_repetitions=5,
        refit=True,
        verbose=False,
    )
    ensemble_kwargs = dict(
        size=30,
        red_size=15,
        max_sample_size=None,  # same as training set
        n_jobs=ENSEMBLE_N_JOBS,
        param_grid=lr_params,
        param_mod_sel=inner_selection,
        val_split=0.4,
        min_pos=5,
    )

    ehdy = EHDy(newLR(), optim='mae', policy='ds', **ensemble_kwargs)
    yield 'ehdymaeds', ehdy, None



def run(experiment):
    """Optionally tune, then fit and evaluate one quantifier on one reviews dataset.

    :param experiment: tuple ``(optim_loss, dataset_name, (model_name, model, hyperparams))``,
        i.e. one element of ``itertools.product(optim_losses, datasets, models)``.
        ``optim_loss`` is the error passed to ``GridSearchQ``. ``hyperparams`` is the grid
        to search, or ``None`` to fit ``model`` with its current parameters.

    Returns nothing. Results are persisted with ``save_results`` under ``args.results``.
    The ``'imdb'`` dataset and configurations already computed are skipped. The function
    reads the module-level ``args`` parsed in the ``__main__`` section.
    """
    optim_loss, dataset_name, (model_name, model, hyperparams) = experiment
    if dataset_name == 'imdb':
        return

    # Index of this run. Only a single run per configuration is performed. (Named run_idx
    # so that it does not shadow this function's own name.)
    run_idx = 0

    # Look for a stored result *before* fetching the dataset, so that finished
    # configurations do not pay for loading/vectorising the data (tfidf=True) for nothing.
    if is_already_computed(args.results, dataset_name, model_name, run=run_idx, optim_loss=optim_loss):
        print(f'result for dataset={dataset_name} model={model_name} loss={optim_loss} already computed.')
        return

    data = qp.datasets.fetch_reviews(dataset_name, tfidf=True, min_df=5)

    print(f'running dataset={dataset_name} model={model_name} loss={optim_loss}')
    # model selection (hyperparameter optimization for a quantification-oriented loss)
    if hyperparams is not None:
        model_selection = qp.model_selection.GridSearchQ(
            model,
            param_grid=hyperparams,
            sample_size=SAMPLE_SIZE,
            n_prevpoints=21,
            n_repetitions=100,
            error=optim_loss,
            refit=True,
            timeout=60 * 60,  # presumably seconds, i.e. a one-hour budget
            verbose=True
        )
        model_selection.fit(data.training)
        model = model_selection.best_model()
        best_params = model_selection.best_params_
    else:
        model.fit(data.training)
        best_params = {}

    # model evaluation
    true_prevalences, estim_prevalences = qp.evaluation.artificial_prevalence_prediction(
        model,
        test=data.test,
        sample_size=SAMPLE_SIZE,
        n_prevpoints=21,
        n_repetitions=10,  # NOTE(review): the original comment read "100"; confirm 10 is intended
        # Ensembles use all cores here. Other models run sequentially, presumably because
        # the outer experiment pool is already parallel.
        n_jobs=-1 if isinstance(model, qp.method.meta.Ensemble) else 1,
        verbose=True
    )
    test_true_prevalence = data.test.prevalence()

    evaluate_experiment(true_prevalences, estim_prevalences)
    save_results(args.results, dataset_name, model_name, run_idx, optim_loss,
                 true_prevalences, estim_prevalences,
                 data.training.prevalence(), test_true_prevalence,
                 best_params)


if __name__ == '__main__':
    # The experiments run on QuaPy's reviews sentiment datasets (see `datasets` below). The
    # description previously said "Tweeter", which was both misspelled and the wrong corpus.
    parser = argparse.ArgumentParser(description='Run experiments for Reviews Sentiment Quantification')
    parser.add_argument('results', metavar='RESULT_PATH', type=str,
                        help='path to the directory where to store the results')
    parser.add_argument('--svmperfpath', metavar='SVMPERF_PATH', type=str, default='./svm_perf_quantification',
                        help='path to the directory with svmperf')
    parser.add_argument('--checkpointdir', metavar='PATH', type=str, default='./checkpoint',
                        help='path to the directory where to dump QuaNet checkpoints')
    # `args` is deliberately module-global: `run` reads `args.results`.
    args = parser.parse_args()

    print(f'Result folder: {args.results}')
    # Seeds NumPy's global RNG only; other libraries (e.g. torch) are not seeded here.
    np.random.seed(0)

    qp.environ['SVMPERF_HOME'] = args.svmperfpath

    optim_losses = ['mae']
    datasets = qp.datasets.REVIEWS_SENTIMENT_DATASETS

    # Every (loss, dataset, model) combination is one independent call to `run`.
    models = quantification_models()
    qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=N_JOBS)

    # models = quantification_cuda_models()
    # qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=CUDA_N_JOBS)

    # models = quantification_ensembles()
    # qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=1)

    # Clean up QuaNet checkpoints. NOTE(review): this removes the directory unconditionally,
    # even if it existed before the run or QuaNet was not executed; confirm that is intended.
    shutil.rmtree(args.checkpointdir, ignore_errors=True)