from typing import Union, Callable

import numpy as np
import quapy as qp
from MultiLabel.mlquantification import MLAggregativeQuantifier
from mldata import MultilabelledCollection
import itertools
from tqdm import tqdm


def __check_error(error_metric):
    if isinstance(error_metric, str):
        error_metric = qp.error.from_name(error_metric)

    assert hasattr(error_metric, '__call__'), 'invalid error function'
    return error_metric


def _ml_prevalence_predictions(model,
                               test: MultilabelledCollection,
                               test_indexes):

    predict_batch_fn = _predict_quantification_batch
    if isinstance(model, MLAggregativeQuantifier):
        test = MultilabelledCollection(model.preclassify(test.instances), test.labels)
        predict_batch_fn = _predict_aggregative_batch

    args = tuple([model, test, test_indexes])
    true_prevs, estim_prevs = predict_batch_fn(args)
    return true_prevs, estim_prevs


def ml_natural_prevalence_prediction(model,
                                     test:MultilabelledCollection,
                                     sample_size,
                                     repeats=100,
                                     random_seed=42):

    with qp.util.temp_seed(random_seed):
        test_indexes = list(test.natural_sampling_index_generator(sample_size=sample_size, repeats=repeats))

    return _ml_prevalence_predictions(model, test, test_indexes)


def ml_natural_prevalence_evaluation(model,
                                     test:MultilabelledCollection,
                                     sample_size,
                                     repeats=100,
                                     error_metric:Union[str,Callable]='mae',
                                     random_seed=42):

    error_metric = __check_error(error_metric)

    true_prevs, estim_prevs = ml_natural_prevalence_prediction(model, test, sample_size, repeats, random_seed)

    errs = [error_metric(true_prev_i, estim_prev_i) for true_prev_i, estim_prev_i in zip(true_prevs, estim_prevs)]
    return np.mean(errs)


def ml_artificial_prevalence_prediction(model,
                                        test:MultilabelledCollection,
                                        sample_size,
                                        n_prevalences=21,
                                        repeats=10,
                                        random_seed=42):

    nested_test_indexes = []
    with qp.util.temp_seed(random_seed):
        for cat in test.classes_:
            nested_test_indexes.append(list(test.artificial_sampling_index_generator(sample_size=sample_size,
                                                                              category=cat,
                                                                              n_prevalences=n_prevalences,
                                                                              repeats=repeats)))
    def _predict_batch(test_indexes):
        return _ml_prevalence_predictions(model, test, test_indexes)

    predictions = qp.util.parallel(_predict_batch, nested_test_indexes, n_jobs=-1)
    true_prevs = list(itertools.chain.from_iterable(trues for trues, estims in predictions))
    estim_prevs = list(itertools.chain.from_iterable(estims for trues, estims in predictions))
    return true_prevs, estim_prevs


def ml_artificial_prevalence_evaluation(model,
                                        test:MultilabelledCollection,
                                        sample_size,
                                        n_prevalences=21,
                                        repeats=10,
                                        error_metric:Union[str,Callable]='mae',
                                        random_seed=42):

    error_metric = __check_error(error_metric)

    true_prevs, estim_prevs = ml_artificial_prevalence_prediction(model, test, sample_size, n_prevalences, repeats, random_seed)

    errs = [error_metric(true_prev_i, estim_prev_i) for true_prev_i, estim_prev_i in zip(true_prevs, estim_prevs)]
    return np.mean(errs)


def _predict_quantification_batch(args):
    model, test, indexes = args
    return __predict_batch_fn(args, model.quantify)


def _predict_aggregative_batch(args):
    model, test, indexes = args
    return __predict_batch_fn(args, model.aggregate)


def __predict_batch_fn(args, quant_fn):
    model, test, indexes = args
    trues, estims = [], []
    for index in indexes:
        sample = test.sampling_from_index(index)
        estims.append(quant_fn(sample.instances))
        trues.append(sample.prevalence())
    return trues, estims