import itertools
import os.path
import pickle
from collections import defaultdict
from pathlib import Path

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

import quapy as qp
from Retrieval.commons import RetrievedSamples, load_sample
from quapy.protocol import UPP
from quapy.method.non_aggregative import MaximumLikelihoodPrevalenceEstimation as Naive
from quapy.model_selection import GridSearchQ
from quapy.method.aggregative import ClassifyAndCount, EMQ, ACC, PCC, PACC, KDEyML
from quapy.data.base import LabelledCollection

from os.path import join
from tqdm import tqdm

from result_table.src.table import Table

"""
 
"""

data_home = 'data'

datasets = ['continent', 'gender', 'years_category'] #, 'relative_pageviews_category', 'num_sitelinks_category']

for class_name in datasets:

    train_data_path = join(data_home, class_name, 'FULL', 'classifier_training.json')  # <-------- fixed classifier
    texts, labels = load_sample(train_data_path, class_name=class_name)

    classifier_path = join('classifiers', 'FULL', f'classifier_{class_name}.pkl')
    tfidf, classifier_trained = pickle.load(open(classifier_path, 'rb'))
    classifier_hyper = classifier_trained.get_params()
    print(f'{classifier_hyper=}')

    X = tfidf.transform(texts)
    print(f'Xtr shape={X.shape}')

    pool = LabelledCollection(X, labels)
    train, val = pool.split_stratified(train_prop=0.5, random_state=0)
    q = KDEyML(LogisticRegression())
    classifier_hyper = {'classifier__C':[classifier_hyper['C'], 0.00000001], 'classifier__class_weight':[classifier_hyper['class_weight']]}
    quantifier_hyper = {'bandwidth': np.linspace(0.01, 0.2, 20)}
    hyper = {**classifier_hyper, **quantifier_hyper}
    qp.environ['SAMPLE_SIZE'] = 100
    modsel = GridSearchQ(
        model=q,
        param_grid=hyper,
        protocol=UPP(val, sample_size=100),
        n_jobs=-1,
        error='mrae',
        verbose=True
    )
    modsel.fit(train)

    print(class_name)
    print(f'{modsel.best_params_}')
    print(f'{modsel.best_score_}')