import sys
import sklearn.datasets
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC, SVC

import quapy as qp
from eDiscovery.method import RegionProbAdjustmentGlobal, RegionAdjustmentQ, ClassWeightPCC
from quapy.data import LabelledCollection
from quapy.method.aggregative import EMQ, CC, PACC, PCC, HDy, ACC
import numpy as np
from itertools import chain
import argparse


def NewClassifier(classifiername):
    # instantiates a fresh probabilistic binary classifier by name
    if classifiername == 'lr':
        return LogisticRegression(class_weight='balanced')
    elif classifiername == 'svm':
        # return SVC(class_weight='balanced', probability=True, kernel='linear')
        return CalibratedClassifierCV(LinearSVC(class_weight='balanced'))
    raise ValueError('unknown classifier', classifiername)


def NewQuantifier(quantifiername, classifiername):
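    # instantiates a quantification method by name, wrapping a fresh classifier of the requested type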
    if quantifiername == 'EMQ':
        return EMQ(CalibratedClassifierCV(NewClassifier(classifiername)))
        # return EMQ(NewClassifier(classifier))
    if quantifiername == 'CC':
        return CC(NewClassifier(classifiername))
    if quantifiername == 'HDy':
        return HDy(NewClassifier(classifiername))
    if quantifiername == 'PCC':
        return PCC(NewClassifier(classifiername))
    if quantifiername == 'ACC':
        return ACC(NewClassifier(classifiername), val_split=0.4)
    if quantifiername == 'PACC':
        return PACC(NewClassifier(classifiername), val_split=0.4)
    if quantifiername == 'CW':
        return ClassWeightPCC()
    if quantifiername == 'SRSQ':  # supervised regions, then single-label quantification
        #q = EMQ(CalibratedClassifierCV(NewClassifier(classifiername)))
        #q = PACC(NewClassifier(classifiername), val_split=0.4)
        q = ACC(NewClassifier(classifiername))
        return RegionAdjustmentQ(q, k=4)
    if quantifiername == 'URBQ':  # unsupervised regions, then binary quantifications
        def newQ():
            # return PACC(NewClassifier(classifiername), val_split=0.4)
            # return CC(CalibratedClassifierCV(NewClassifier(classifiername)))
            # return ClassWeightPCC()
            return CC(NewClassifier(classifiername))
        return RegionProbAdjustmentGlobal(newQ, k=20, clustering='kmeans')
    raise ValueError('unknown quantifier', quantifiername)


def experiment_name(args:argparse.Namespace):
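    # builds a result file name by concatenating all command-line arguments and their values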
    return '__'.join([f'{k}:{getattr(args, k)}' for k in sorted(vars(args).keys())]) + '.csv'


def split_from_index(collection: LabelledCollection, index: np.ndarray):
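    # splits the collection into (documents selected by index, remaining documents)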
    in_index_set = set(index)
    out_index_set = set(range(len(collection))) - in_index_set
    out_index = np.asarray(sorted(out_index_set), dtype=int)
    return collection.sampling_from_index(index), collection.sampling_from_index(out_index)


def move_documents(target: LabelledCollection, origin: LabelledCollection, idx_origin: np.ndarray):
    # moves documents (indexed by idx_origin) from origin to target
    selected, reduced_origin = split_from_index(origin, idx_origin)
    enhanced_target = target + selected
    return enhanced_target, reduced_origin


def uniform_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
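    # selects k documents from the pool uniformly at random (the classifier is ignored)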
    return np.random.choice(len(pool), k, replace=False)


def proportional_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
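    # selects k documents at random, with probability proportional to the posterior probability of relevance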
    prob = classifier.predict_proba(pool.instances)[:, 1].flatten()
    return np.random.choice(len(pool), k, replace=False, p=prob/prob.sum())


def relevance_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
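    # selects the k documents with the highest posterior probability of relevance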
    prob = classifier.predict_proba(pool.instances)[:, 1].flatten()
    top_relevant_idx = np.argsort(-prob)[:k]
    return top_relevant_idx


def uncertainty_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
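    # selects the k documents whose posterior probability of relevance is closest to 0.5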
    prob = classifier.predict_proba(pool.instances)[:, 1].flatten()
    top_uncertain_idx = np.argsort(np.abs(prob - 0.5))[:k]
    return top_uncertain_idx


def mix_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
    # interleaves the relevance and uncertainty rankings, removes duplicates, and keeps the top k
    relevance_idx = relevance_sampling(pool, classifier, k)
    uncertainty_idx = uncertainty_sampling(pool, classifier, k)
    interleave_idx = np.asarray(list(chain.from_iterable(zip(relevance_idx, uncertainty_idx))))
    _, unique_idx = np.unique(interleave_idx, return_index=True)
    top_interleaved_idx = interleave_idx[np.sort(unique_idx)][:k]
    return top_interleaved_idx


def adaptive_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, progress: float):
    # mixes relevance and uncertainty sampling; as progress (a percentage in [0, 100]) grows,
    # the batch shifts from mostly uncertainty sampling towards mostly relevance sampling
    relevance_k = int(k * progress / 100)
    uncertainty_k = k - relevance_k
    relevance_idx = relevance_sampling(pool, classifier, relevance_k)
    uncertainty_idx = uncertainty_sampling(pool, classifier, uncertainty_k)
    idx = np.concatenate([relevance_idx, uncertainty_idx])
    idx = np.unique(idx)
    return idx


def negative_sampling_index(pool: LabelledCollection, classifier: BaseEstimator, k: int):
    # selects the k documents with the highest posterior probability of belonging to the negative class
    prob = classifier.predict_proba(pool.instances)[:, 0].flatten()
    top_negative_idx = np.argsort(-prob)[:k]
    return top_negative_idx


def recall(train_prev, pool_prev, train_size, pool_size):
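    # estimated recall: positives already labelled (train) over the estimated total of positives (train + pool)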
    frac_tr_pos = train_prev[1]
    frac_te_pos = pool_prev[1]
    recall = (frac_tr_pos * train_size) / (frac_tr_pos * train_size + frac_te_pos * pool_size)
    return recall


def create_dataset(datasetname):
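    # loads one of the supported datasets as a binary LabelledCollection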
    if datasetname == 'imdb.10K.75p':
        data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5)
        collection = data.training.sampling(10000, 0.75)
        return collection

    elif datasetname == 'RCV1.C4':
        X, y = sklearn.datasets.fetch_rcv1(subset='train', return_X_y=True)
        y = y.toarray()
        prev = y.mean(axis=0).flatten()
        # choose the first category having a positive prevalence between [0.1,0.2] (realistic scenario for e-Discovery)
        # this category happens to be the cat with id 4
        target_cat = np.argwhere(np.logical_and(prev > 0.1, prev < 0.2)).flatten()[0]
        print('chosen cat', target_cat)
        y = y[:, target_cat].flatten()
        return LabelledCollection(X, y)

    elif datasetname == 'hp':
        data = qp.datasets.fetch_reviews('hp', tfidf=True, min_df=5)
        collection = data.training + data.test
        collection = LabelledCollection(instances=collection.instances, labels=1-collection.labels)
        return collection

    print(f'unknown dataset {datasetname}. Abort')
    sys.exit(1)


def estimate_prev_CC(train, pool: LabelledCollection, classifiername:str):
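    # estimates the pool prevalence with Classify & Count; returns the estimate and the trained classifier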
    q = CC(NewClassifier(classifiername)).fit(train)
    return q.quantify(pool.instances), q.learner


def estimate_prev_Q(train, pool, quantifiername, classifiername):
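    # estimates the pool prevalence with the requested quantifier; returns the estimate and None (no classifier exposed)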
    # q = qp.model_selection.GridSearchQ(
    #     ACC(LogisticRegression()),
    #     param_grid={'C':np.logspace(-3,3,7), 'class_weight':[None, 'balanced']},
    #     sample_size=len(train),
    #     protocol='app',
    #     n_prevpoints=21,
    #     n_repetitions=10)

    q = NewQuantifier(quantifiername, classifiername)
    # q._find_regions((train+pool).instances)
    q.fit(train)

    prev = q.quantify(pool.instances)
    return prev, None