import pickle
import os
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression

import quapy as qp



def newLR():
    """Build a new logistic-regression classifier with the shared settings.

    :return: a ``LogisticRegression`` instance created with
        ``max_iter=1000``, ``solver='lbfgs'`` and ``n_jobs=-1``.
    """
    settings = dict(max_iter=1000, solver='lbfgs', n_jobs=-1)
    return LogisticRegression(**settings)


def calibratedLR():
    """Build a logistic-regression classifier wrapped in a calibrator.

    :return: a ``CalibratedClassifierCV`` whose wrapped estimator is a
        ``LogisticRegression`` created with ``max_iter=1000``,
        ``solver='lbfgs'`` and ``n_jobs=-1``.
    """
    base_classifier = LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
    return CalibratedClassifierCV(base_classifier)


def save_results(result_dir, dataset_name, model_name, run, optim_loss, *results):
    """Pickle the outcome of one experiment to its result file.

    The destination is obtained from ``result_path`` using the first five
    arguments; ``qp.util.create_parent_dir`` is called on it before writing.

    :param result_dir: directory under which result files are stored.
    :param dataset_name: dataset identifier used in the file name.
    :param model_name: model/method identifier used in the file name.
    :param run: run identifier used in the file name.
    :param optim_loss: optimisation-loss identifier used in the file name.
    :param results: arbitrary objects; they are stored together as one tuple.
    """
    destination = result_path(result_dir, dataset_name, model_name, run, optim_loss)
    qp.util.create_parent_dir(destination)
    payload = tuple(results)
    with open(destination, 'wb') as sink:
        pickle.dump(payload, sink, protocol=pickle.HIGHEST_PROTOCOL)


def evaluate_experiment(true_prevalences, estim_prevalences):
    """Print the ``mae`` and ``mrae`` errors of a set of estimates.

    Each measure from ``qp.error`` is called as
    ``measure(true_prevalences, estim_prevalences)`` and its result is
    printed with four decimals, preceded by a short header and followed
    by an empty line.

    :param true_prevalences: true prevalence values.
    :param estim_prevalences: estimated prevalence values.
    """
    separator = '=' * 22
    print('\nEvaluation Metrics:\n' + separator)
    measures = (qp.error.mae, qp.error.mrae)
    for measure in measures:
        score = measure(true_prevalences, estim_prevalences)
        print(f'\t{measure.__name__}={score:.4f}')
    print()


def result_path(path, dataset_name, model_name, run, optim_loss):
    """Compose the location of the pickle file for one experiment.

    :param path: directory that holds the result files.
    :param dataset_name: dataset identifier.
    :param model_name: model/method identifier.
    :param run: run identifier.
    :param optim_loss: optimisation-loss identifier.
    :return: ``path`` joined with
        ``<dataset_name>-<model_name>-run<run>-<optim_loss>.pkl``.
    """
    filename = f'{dataset_name}-{model_name}-run{run}-{optim_loss}.pkl'
    return os.path.join(path, filename)


def is_already_computed(result_dir, dataset_name, model_name, run, optim_loss):
    """Tell whether the result file of an experiment is already on disk.

    :param result_dir: directory that holds the result files.
    :param dataset_name: dataset identifier.
    :param model_name: model/method identifier.
    :param run: run identifier.
    :param optim_loss: optimisation-loss identifier.
    :return: ``True`` if the path given by ``result_path`` exists.
    """
    expected_file = result_path(result_dir, dataset_name, model_name, run, optim_loss)
    return os.path.exists(expected_file)


# Maps internal method / loss identifiers to the labels used when rendering
# tables. Several values contain LaTeX markup.
nice = {
    'pacc.opt': 'PACC(LR)',
    'pacc.opt.svm': 'PACC(SVM)',
    'pcc.opt': 'PCC(LR)',
    'pcc.opt.svm': 'PCC(SVM)',
    'wpacc.opt': 'R-PCC(LR)',
    'wpacc.opt.svm': 'R-PCC(SVM)',
    'mae': 'AE',
    'ae': 'AE',
    'svmkld': 'SVM(KLD)',
    'svmnkld': 'SVM(NKLD)',
    'svmq': 'SVM(Q)',
    'svmae': 'SVM(AE)',
    'svmmae': 'SVM(AE)',
    'svmmrae': 'SVM(RAE)',
    'hdy': 'HDy',
    'sldc': 'SLD',
    'X': 'TSX',
    'T50': 'TS50',
    # Raw string: the backslash must reach LaTeX literally, and '\m' is an
    # invalid escape sequence in a regular string literal.
    'ehdymaeds': r'E(HDy)$_\mathrm{DS}$',
    'Average': 'Average',
    'EMdiag': 'EM$_{diag}$',
    'EMfull': 'EM$_{full}$',
    'EMtied': 'EM$_{tied}$',
    'EMspherical': 'EM$_{sph}$',
    'VEMdiag': 'VEM$_{diag}$',
    'VEMfull': 'VEM$_{full}$',
    'VEMtied': 'VEM$_{tied}$',
    'VEMspherical': 'VEM$_{sph}$',
}


def nicerm(key):
    r"""Return the display label of *key* wrapped in ``\mathrm{...}``.

    :param key: identifier that must be present in the ``nice`` mapping.
    :return: the string ``\mathrm{<label>}``.
    :raises KeyError: if *key* is not a key of ``nice``.
    """
    # Raw string: '\m' is an invalid escape sequence in a regular literal.
    return r'\mathrm{' + nice[key] + '}'


def nicename(method, eval_name=None, side=False):
    r"""Return the display label for a method identifier.

    The label is looked up in ``nice``; identifiers that are not listed
    fall back to their upper-cased form.

    :param method: method identifier (a string).
    :param eval_name: only tested against ``None``; when it is given, every
        ``$$`` substring is removed from the label.
    :param side: when true, the label is wrapped as ``\side{<label>}``
        (presumably a LaTeX macro defined by the consuming document;
        verify against the table templates).
    :return: the resulting label string.
    """
    label = nice.get(method, method.upper())
    if eval_name is not None:
        label = label.replace('$$', '')
    if side:
        # Raw string: '\s' is an invalid escape sequence in a regular literal.
        label = r'\side{' + label + '}'
    return label


def save_table(path, table):
    """Write a rendered table to a text file, announcing it on stdout.

    :param path: destination file; it is opened in text-write mode, so any
        existing content is replaced.
    :param table: the text to write.
    """
    print(f'saving results in {path}')
    with open(path, 'wt') as out_file:
        out_file.write(table)


def experiment_errors(path, dataset, method, run, eval_loss, optim_loss=None):
    """Load a stored experiment and evaluate it with the requested error.

    :param path: directory that holds the result files.
    :param dataset: dataset identifier used in the file name.
    :param method: method identifier used in the file name.
    :param run: run identifier used in the file name.
    :param eval_loss: name of the function in ``qp.error`` used to score
        the stored prevalences.
    :param optim_loss: loss identifier used in the file name; defaults to
        ``eval_loss``. An ``'m'`` is prepended when the name does not
        already start with one (e.g. ``'ae'`` becomes ``'mae'``).
    :return: the value of ``qp.error.<eval_loss>(true_prevs, estim_prevs)``,
        or ``None`` when the result file does not exist.
    """
    if optim_loss is None:
        optim_loss = eval_loss
    # Result files are looked up under the 'm'-prefixed loss name.
    if not optim_loss.startswith('m'):
        optim_loss = 'm' + optim_loss

    results_file = result_path(path, dataset, method, run, optim_loss)
    if not os.path.exists(results_file):
        return None

    # NOTE: unpickling can execute arbitrary code; only load result files
    # produced by trusted runs of these experiments.
    # The context manager guarantees the file handle is closed.
    with open(results_file, 'rb') as fin:
        # The stored tuple is expected to hold five items; only the first
        # two (true and estimated prevalences) are used here.
        true_prevs, estim_prevs, _, _, _ = pickle.load(fin)

    err_fn = getattr(qp.error, eval_loss)
    return err_fn(true_prevs, estim_prevs)