from ClassifierAccuracy.util.commons import *
from ClassifierAccuracy.util.plotting import plot_diagonal

# Task family ('binary' or 'multiclass') and whether the oracle variants of the
# methods are run. Together they name the base path (`basedir`) that is handed to
# getpath / any_missing / plot_diagonal / gen_tables further down the script.
PROBLEM = 'multiclass'
ORACLE = True
basedir = f'{PROBLEM}-oracle' if ORACLE else PROBLEM


# Per-problem settings: the SAMPLE_SIZE entry of qp.environ (presumably quapy's
# default sample size, picked up by the UPP protocol below — confirm), the number of
# test samples the protocol draws (NUM_TEST), and the dataset generator to iterate.
if PROBLEM == 'binary':
    qp.environ['SAMPLE_SIZE'] = 1000
    NUM_TEST = 1000
    gen_datasets = gen_bin_datasets
elif PROBLEM == 'multiclass':
    qp.environ['SAMPLE_SIZE'] = 250
    NUM_TEST = 1000
    gen_datasets = gen_multi_datasets
else:
    # Fail fast: without this branch NUM_TEST and gen_datasets would stay undefined
    # and the script would only crash later with an unrelated-looking NameError.
    raise ValueError(f"unknown PROBLEM {PROBLEM!r}; expected 'binary' or 'multiclass'")


# Main experiment loop. For every (classifier, dataset) pair:
#   1. train the classifier on L,
#   2. build the test protocol over U and record dataset statistics,
#   3. compute the true accuracy of the classifier on every test sample, per measure,
#   4. fit each accuracy-prediction method on V, predict on the test samples and save
#      one JSON result per (classifier, measure, dataset, method).
# Methods whose results already exist on disk are skipped.
for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifiers(), gen_datasets()):
    print(f'training {cls_name} in {dataset_name}')
    h.fit(*L.Xy)

    # NUM_TEST samples drawn from U. random_state=0 is presumably there so that every
    # call to test_prot() yields the same samples, keeping true and estimated
    # accuracies aligned -- confirm against quapy's UPP.
    test_prot = UPP(U, repeats=NUM_TEST, return_type='labelled_collection', random_state=0)

    get_dataset_stats(f'dataset_stats/{dataset_name}.json', test_prot, L, V)

    # Ground-truth accuracy of h on each test sample, keyed by measure name. Computed
    # once here and shared by every method below.
    true_accs = {
        acc_name: [true_acc(h, acc_fn, sample) for sample in test_prot()]
        for acc_name, acc_fn in gen_acc_measure()
    }

    # ClassifierAccuracyPrediction objects are tied to a single evaluation measure,
    # hence they are generated (and fitted) inside the per-measure loop.
    for acc_name, acc_fn in gen_acc_measure():
        print(f'\tfor measure {acc_name}')
        for method_name, method in gen_CAP(h, acc_fn, with_oracle=ORACLE):
            result_path = getpath(basedir, cls_name, acc_name, dataset_name, method_name)
            if os.path.exists(result_path):
                print(f'\t\t{method_name}-{acc_name} exists, skipping')
            else:
                print(f'\t\t{method_name} computing...')
                method, t_train = fit_method(method, V)
                estim_accs, t_test_ave = predictionsCAP(method, test_prot, ORACLE)
                save_json_result(result_path, true_accs[acc_name], estim_accs, t_train, t_test_ave)

    # CAPContingencyTable objects are measure-agnostic: a single fit/predict pass
    # yields estimates for all measures at once (gen_acc_measure is passed through).
    for method_name, method in gen_CAP_cont_table(h):
        if any_missing(basedir, cls_name, dataset_name, method_name):
            print(f'\t\tmethod {method_name} computing...')
            method, t_train = fit_method(method, V)
            estims_per_measure, t_test_ave = predictionsCAPcont_table(method, test_prot, gen_acc_measure, ORACLE)
            for acc_name, measure_estims in estims_per_measure.items():
                result_path = getpath(basedir, cls_name, acc_name, dataset_name, method_name)
                save_json_result(result_path, true_accs[acc_name], measure_estims, t_train, t_test_ave)
        else:
            print(f'\t\tmethod {method_name} has all results already computed. Skipping.')

    print()

# Reporting phase. For every classifier/measure pair, draw one diagonal plot without
# a dataset_name (presumably aggregating all datasets -- see plot_diagonal) plus one
# plot per dataset. Then build the summary tables over all dataset names.
print('generating plots')
cls_measure_pairs = itertools.product(gen_classifiers(), gen_acc_measure())
for (cls_name, _), (acc_name, _) in cls_measure_pairs:
    plot_diagonal(basedir, cls_name, acc_name)
    for dataset_name, _ in gen_datasets(only_names=True):
        plot_diagonal(basedir, cls_name, acc_name, dataset_name=dataset_name)

print('generating tables')
dataset_names = [name for name, _ in gen_datasets(only_names=True)]
gen_tables(basedir, datasets=dataset_names)