import sys
import sklearn
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC, SVC

import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import EMQ, CC, PACC, PCC, HDy, ACC
import numpy as np
from itertools import chain
import argparse



def NewClassifier(classifiername):
    """Build a fresh, unfitted classifier identified by a short name.

    :param classifiername: ``'lr'`` for a balanced logistic regression, or ``'svm'``
        for a balanced linear-kernel SVC.
    :return: a new (unfitted) scikit-learn estimator.
    :raises ValueError: if ``classifiername`` is not one of the supported names
        (previously the function silently returned ``None``).
    """
    if classifiername == 'lr':
        return LogisticRegression(class_weight='balanced')
    elif classifiername == 'svm':
        # probability=True enables SVC.predict_proba, which the sampling strategies
        # in this module call on the classifier
        return SVC(class_weight='balanced', probability=True, kernel='linear')
    raise ValueError('unknown classifier', classifiername)


def NewQuantifier(quantifiername, classifiername):
    """Build an unfitted quapy quantifier wrapping a freshly created classifier.

    :param quantifiername: one of ``'EMQ'``, ``'HDy'``, ``'PCC'``, ``'ACC'``, ``'PACC'``.
    :param classifiername: short classifier name forwarded to ``NewClassifier``.
    :return: the quantifier instance (not yet fitted).
    :raises ValueError: if ``quantifiername`` is not recognised.
    """
    # factories are lambdas so that only the requested quantifier (and its classifier) is built
    builders = (
        # EMQ receives the classifier wrapped in CalibratedClassifierCV
        ('EMQ', lambda: EMQ(CalibratedClassifierCV(NewClassifier(classifiername)))),
        ('HDy', lambda: HDy(NewClassifier(classifiername))),
        ('PCC', lambda: PCC(NewClassifier(classifiername))),
        # val_split=5 is passed through to quapy (presumably 5-fold CV for the
        # misclassification-rate estimates; confirm against the quapy docs)
        ('ACC', lambda: ACC(NewClassifier(classifiername), val_split=5)),
        ('PACC', lambda: PACC(NewClassifier(classifiername), val_split=5)),
    )
    for known_name, build in builders:
        if quantifiername == known_name:
            return build()
    raise ValueError('unknown quantifier', quantifiername)


def experiment_name(args:argparse.Namespace):
    """Derive a CSV file name from all parsed arguments.

    Each argument becomes ``name:value``; the pairs are ordered by argument name,
    joined with ``'__'`` and suffixed with ``'.csv'``.
    """
    arg_names = sorted(vars(args))
    name_value_pairs = [f'{name}:{getattr(args, name)}' for name in arg_names]
    return '__'.join(name_value_pairs) + '.csv'


def split_from_index(collection: LabelledCollection, index: np.ndarray):
    """Split ``collection`` into the items at ``index`` and all remaining items.

    :param collection: collection to split.
    :param index: positional indices of the items to select.
    :return: tuple ``(selected, rest)``; ``rest`` keeps the remaining items in
        ascending positional order.
    """
    chosen = set(index)
    # iterating range() in order yields the complement already sorted ascending
    leftover = np.asarray([pos for pos in range(len(collection)) if pos not in chosen], dtype=int)
    selected = collection.sampling_from_index(index)
    rest = collection.sampling_from_index(leftover)
    return selected, rest


def move_documents(target: LabelledCollection, origin: LabelledCollection, idx_origin: np.ndarray):
    """Move the documents at ``idx_origin`` from ``origin`` into ``target``.

    Neither input is modified; new collections are returned instead.

    :return: tuple ``(target + moved documents, origin without them)``.
    """
    moved, origin_rest = split_from_index(origin, idx_origin)
    return target + moved, origin_rest


def uniform_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
    """Pick ``k`` distinct pool positions uniformly at random.

    ``classifier`` and any extra positional arguments are accepted only so that all
    sampling strategies share one signature; they are ignored here.
    """
    pool_size = len(pool)
    return np.random.choice(pool_size, k, replace=False)


def proportional_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
    """Pick ``k`` distinct pool positions at random, weighted by the classifier's
    probability for the class in column 1 of ``predict_proba``.

    Extra positional arguments are ignored (uniform strategy signature).
    """
    scores = classifier.predict_proba(pool.instances)[:, 1].flatten()
    weights = scores / scores.sum()  # normalise to a probability distribution
    return np.random.choice(len(pool), k, replace=False, p=weights)


def relevance_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
    """Return the positions of the ``k`` pool items with the highest probability in
    column 1 of ``predict_proba``, most probable first.

    Extra positional arguments are ignored (uniform strategy signature).
    """
    scores = classifier.predict_proba(pool.instances)[:, 1].flatten()
    ranked_desc = np.argsort(-scores)
    return ranked_desc[:k]


def uncertainty_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
    """Return the positions of the ``k`` pool items whose column-1 probability is
    closest to 0.5, most uncertain first.

    Extra positional arguments are ignored (uniform strategy signature).
    """
    scores = classifier.predict_proba(pool.instances)[:, 1].flatten()
    distance_from_boundary = np.abs(scores - 0.5)
    ranked = np.argsort(distance_from_boundary)
    return ranked[:k]


def mix_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
    """Alternate between relevance and uncertainty picks, returning ``k`` distinct positions.

    The two rankings are interleaved (relevance first in each pair); a position
    already taken is skipped, and the first ``k`` survivors are returned in
    interleaving order.
    """
    by_relevance = relevance_sampling(pool, classifier, k)
    by_uncertainty = uncertainty_sampling(pool, classifier, k)
    already_taken = set()
    merged = []
    for pair in zip(by_relevance, by_uncertainty):
        for position in pair:
            if position not in already_taken:
                already_taken.add(position)
                merged.append(position)
    return np.asarray(merged)[:k]


def adaptive_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, progress: float):
    """Select ``k`` pool positions, splitting the budget between relevance and uncertainty.

    A share ``progress/100`` of the budget goes to the most relevant items (highest
    column-1 probability). The rest goes to the most uncertain items (probability
    closest to 0.5) that were not already chosen for relevance. The two selections
    therefore never overlap, and the batch has ``k`` items whenever the pool holds
    at least ``k``.

    :param pool: pool to sample from.
    :param classifier: fitted classifier exposing ``predict_proba``.
    :param k: number of positions to select.
    :param progress: value on a 0-100 scale (it is divided by 100); values outside
        that range are clamped.
    :return: sorted array of distinct positions into ``pool``.
    """
    # clamp so that progress outside [0, 100] cannot make either share negative
    # (a negative count would turn `[:k]` slicing into "all but the last |k|" items)
    relevance_k = min(max(int(k * progress / 100), 0), k)
    uncertainty_k = k - relevance_k
    relevance_idx = relevance_sampling(pool, classifier, relevance_k)
    # rank the whole pool by uncertainty and skip items already picked for relevance;
    # otherwise overlapping picks are collapsed by np.unique and the batch shrinks below k
    uncertainty_ranking = uncertainty_sampling(pool, classifier, len(pool))
    uncertainty_ranking = uncertainty_ranking[~np.isin(uncertainty_ranking, relevance_idx)]
    uncertainty_idx = uncertainty_ranking[:uncertainty_k]
    return np.unique(np.concatenate([relevance_idx, uncertainty_idx]))


def negative_sampling_index(pool: LabelledCollection, classifier: BaseEstimator, k: int):
    """Return the positions of the ``k`` pool items with the highest probability in
    column 0 of ``predict_proba``, most probable first."""
    negative_scores = classifier.predict_proba(pool.instances)[:, 0].flatten()
    most_negative_first = np.argsort(-negative_scores)
    return most_negative_first[:k]


def recall(train_prev, pool_prev, train_size, pool_size):
    """Fraction of all class-1 items that lie in the training set.

    The class-1 count in each part is estimated as prevalence (index 1) times size:
    ``train_pos / (train_pos + pool_pos)``.

    :param train_prev: prevalence vector of the training set (index 1 = class 1).
    :param pool_prev: prevalence vector of the pool (index 1 = class 1).
    :param train_size: number of training items.
    :param pool_size: number of pool items.
    """
    positives_in_train = train_prev[1] * train_size
    positives_overall = positives_in_train + pool_prev[1] * pool_size
    return positives_in_train / positives_overall


def create_dataset(datasetname):
    """Load the labelled collection used by an experiment.

    Supported names:

    * ``'imdb.10K.75p'``: a 10000-document sample of the IMDb training set (tf-idf,
      ``min_df=5``), drawn with ``sampling(10000, 0.75)`` (presumably 0.75 is the
      prevalence of the first class; confirm against quapy's ``sampling``).
    * ``'RCV1.C4'``: RCV1 training split, binarised on the first category whose
      prevalence lies strictly between 0.1 and 0.2.
    * ``'hp'``: HP reviews (training + test, tf-idf, ``min_df=5``) with the labels
      inverted via ``1 - labels``.

    :param datasetname: one of the names above.
    :return: a ``LabelledCollection``.

    For an unknown name, an error message is printed and the process exits with
    status 1 (it previously exited with 0, which signalled success).
    """
    if datasetname == 'imdb.10K.75p':
        data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5)
        collection = data.training.sampling(10000, 0.75)
        return collection

    elif datasetname == 'RCV1.C4':
        # import the submodule explicitly: `import sklearn` alone does not guarantee
        # that `sklearn.datasets` has been loaded
        from sklearn.datasets import fetch_rcv1
        X, y = fetch_rcv1(subset='train', return_X_y=True)
        y = y.toarray()
        prev = y.mean(axis=0).flatten()
        # choose the first category having a positive prevalence in (0.1, 0.2) (realistic scenario for e-Discovery)
        # this category happens to be the cat with id 4
        target_cat = np.argwhere(np.logical_and(prev > 0.1, prev < 0.2)).flatten()[0]
        print('chosen cat', target_cat)
        y = y[:, target_cat].flatten()
        return LabelledCollection(X, y)

    elif datasetname == 'hp':
        data = qp.datasets.fetch_reviews('hp', tfidf=True, min_df=5)
        collection = data.training + data.test
        # invert the labels (assumes binary 0/1 labels): former class 0 becomes class 1
        collection = LabelledCollection(instances=collection.instances, labels=1-collection.labels)
        return collection

    print(f'unknown dataset {datasetname}. Abort', file=sys.stderr)
    # non-zero status so that calling scripts/schedulers see the failure
    sys.exit(1)


def estimate_prev_CC(train, pool: LabelledCollection, classifiername:str):
    """Estimate the pool's class prevalence with Classify & Count.

    A CC quantifier around a new ``classifiername`` classifier is fitted on ``train``
    and applied to all pool instances.

    :return: tuple ``(estimated prevalence, the quantifier's fitted learner)``.
    """
    fitted = CC(NewClassifier(classifiername)).fit(train)
    estimated_prev = fitted.quantify(pool.instances)
    return estimated_prev, fitted.learner


def estimate_prev_Q(train, pool, quantifiername, classifiername):
    """Estimate the pool's class prevalence with the named quapy quantifier.

    The quantifier built by ``NewQuantifier(quantifiername, classifiername)`` is
    fitted on ``train`` and applied to all pool instances.

    :return: tuple ``(estimated prevalence, the quantifier's fitted learner)``.
    """
    quantifier = NewQuantifier(quantifiername, classifiername)
    quantifier.fit(train)
    estimated_prev = quantifier.quantify(pool.instances)
    return estimated_prev, quantifier.learner