import numpy as np
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression

from Census.methods import AreaQuantifier, AggregationRule, optimize_ensemble
from quapy.data import LabelledCollection
from quapy.method.non_aggregative import MaximumLikelihoodPrevalenceEstimation as MLPE
from quapy.method.aggregative import CC, PCC, ACC, PACC, EMQ, MS, MS2
from commons import *
from table import Table
from tqdm import tqdm
import quapy as qp
from copy import deepcopy


# Never wrap printed numpy arrays: an infinite line width keeps each array on one line.
np.set_printoptions(linewidth=float('inf'))


def classifier():
    """Build the base classifier used by the aggregative quantifiers.

    Every call returns a new, unfitted ``LogisticRegression`` constructed with
    scikit-learn's default hyperparameters.
    """
    base_estimator = LogisticRegression()
    return base_estimator


def quantifiers():
    """Yield ``(name, quantifier)`` pairs for every quantification method under evaluation.

    Each aggregative quantifier receives its own freshly built classifier (from
    ``classifier()``), so no two quantifiers share -- and can refit -- the same
    estimator object.

    Yields:
        tuple[str, object]: the method name used in result tables, and an unfitted quantifier.
    """
    # yield 'MLPE', MLPE()  # disabled: non-aggregative method, takes no classifier
    yield 'CC', CC(classifier())
    yield 'PCC', PCC(classifier())
    yield 'ACC', ACC(classifier())
    yield 'PACC', PACC(classifier())
    yield 'MS', MS(classifier())
    # yield 'MS2', MS2(classifier())  # disabled
    yield 'SLD', EMQ(classifier())


import os  # needed to create the results directory before writing the PDF

survey_y = './data/survey_y.csv'

# Load the survey, fit the preprocessor on it, and split the data by area.
Atr, Xtr, ytr = load_csv(survey_y, use_yhat=True)

preprocessor = Preprocessor()
Xtr = preprocessor.fit_transform(Xtr)

data = get_dataset_by_area(Atr, Xtr, ytr)
n_areas = len(data)

areas = [Ai for Ai, _, _ in data]
q_names = [q_name for q_name, _ in quantifiers()]

Madj = AdjMatrix('./data/matrice_adiacenza.csv')

tables = []
text_outputs = []  # only consumed by the disabled summary code below

# Leave-one-area-out protocol: each area is used exactly once as the test set.
benchmarks = [f'te-{Ai}' for Ai in areas]

for aggr in ['median', 'mean']:

    # One column per quantifier, e.g. 'PACC-median' is the AggregationRule built from
    # PACC area quantifiers with aggr='median'.
    methods = [f'{q_name}-{aggr}' for q_name in q_names]

    table = Table(name=f'adjacent{aggr}optim', benchmarks=benchmarks, methods=methods, stat_test=None, color_mode='local')
    table.format.mean_prec = 4
    table.format.show_std = False
    table.format.sta = False  # NOTE(review): confirm `sta` is a real Table.format option (typo?)
    table.format.remove_zero = True

    for q_name, q in quantifiers():
        method_name = f'{q_name}-{aggr}'  # invariant across areas
        for Ai, Xi, yi in tqdm(data, total=n_areas):
            # training: optimize the area quantifiers on every area except the held-out Ai.
            # A copy of the template `q` is passed so every fold starts from the same
            # untouched object, whatever optimize_ensemble does with what it receives.
            other_area = [(Aj, Xj, yj) for Aj, Xj, yj in data if Aj != Ai]
            area_quantifiers = optimize_ensemble(other_area, deepcopy(q), Madj)
            rule = AggregationRule(area_quantifiers, adjacent_matrix=Madj, aggr=aggr)

            # test: predict the prevalence of the held-out area and score it with MAE
            te = LabelledCollection(Xi, yi)
            # presumably read by quapy routines that depend on the sample size
            qp.environ["SAMPLE_SIZE"] = len(te)
            pred_prev = rule.predict(Ai, te.X)
            true_prev = te.prevalence()
            err = qp.error.mae(true_prev, pred_prev)

            table.add(benchmark=f'te-{Ai}', method=method_name, v=err)

        # text_outputs.append(f'{q_name} got mean {table.all_mean():.5f}, best mean {table.get_method_values("Best").mean():.5f}')

    tables.append(table)

output_dir = './results/adjacentaggregationoptim'
os.makedirs(output_dir, exist_ok=True)  # the nested results folder may not exist yet
Table.LatexPDF(f'{output_dir}/doc.pdf', tables)

# with open(f'./results/classifier/output.txt', 'tw') as foo:
#     foo.write('\n'.join(text_outputs))