from typing import Dict

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate
from quapy.data import LabelledCollection

import garg22_ATC.ATC_helper as atc
import jiang18_trustscore.trustscore as trustscore
import guillory21_doc.doc as doc


def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
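    """k-fold cross-validation on the validation set.

    Returns a dict with the mean macro-averaged F1 score across folds.
    """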
    scoring = ["f1_macro"]
    scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring)
    return {"f1_score": float(np.mean(scores["test_f1_macro"]))}


def atc_mc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
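    """ATC-MC (Average Thresholded Confidence with maximum confidence) from
    Garg et al., 2022. A threshold is fit on the in-distribution validation
    scores so that the fraction of scores above it matches the validation
    accuracy; the test accuracy is then estimated as the fraction of test
    scores above that threshold.

    Returns a dict with the true and the ATC-predicted test accuracy, both
    in percent.
    """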
    c_model_predict = getattr(c_model, predict_method)

    # In-distribution validation probabilities and labels
    val_probs, val_labels = c_model_predict(validation.X), validation.y

    # (Possibly shifted) test probabilities
    test_probs = c_model_predict(test.X)

    # Score function: maximum softmax confidence
    val_scores = atc.get_max_conf(val_probs)
    val_preds = np.argmax(val_probs, axis=-1)
    test_scores = atc.get_max_conf(test_probs)

    # Fit the ATC threshold on validation correctness, then estimate test accuracy
    _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
    atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)

    return {
        "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
        "pred_acc": atc_accuracy,
    }


def atc_ne(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
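    """ATC-NE (Average Thresholded Confidence with negative entropy) from
    Garg et al., 2022. Identical to ATC-MC except that the score function is
    the negative entropy of the predicted class distribution instead of the
    maximum confidence.

    Returns a dict with the true and the ATC-predicted test accuracy, both
    in percent.
    """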
    c_model_predict = getattr(c_model, predict_method)

    # In-distribution validation probabilities and labels
    val_probs, val_labels = c_model_predict(validation.X), validation.y

    # (Possibly shifted) test probabilities
    test_probs = c_model_predict(test.X)

    # Score function: negative entropy of the predicted class distribution
    val_scores = atc.get_entropy(val_probs)
    val_preds = np.argmax(val_probs, axis=-1)
    test_scores = atc.get_entropy(test_probs)

    # Fit the ATC threshold on validation correctness, then estimate test accuracy
    _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
    atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)

    return {
        "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
        "pred_acc": atc_accuracy,
    }


def trust_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict",
):
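    """Trust Score (Jiang et al., 2018): for each test point, the ratio of the
    distance to the nearest class other than the predicted one over the
    distance to the predicted class. Unlike the other baselines, this returns
    a per-instance score array rather than a dict.
    """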
    c_model_predict = getattr(c_model, predict_method)

    # Model predictions on the test set
    test_pred = c_model_predict(test.X)

    # Fit the trust-score model on validation data, then score the test predictions
    trust_model = trustscore.TrustScore()
    trust_model.fit(validation.X, validation.y)

    return trust_model.get_score(test.X, test_pred)


def doc_feat(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
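    """DoC-Feat (Guillory et al., 2021): predicts test accuracy by correcting
    the validation accuracy with the difference of confidences (DoC) between
    the validation and test sets.
    """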
    c_model_predict = getattr(c_model, predict_method)

    val_probs, val_labels = c_model_predict(validation.X), validation.y
    test_probs = c_model_predict(test.X)

    # Maximum-confidence scores and predictions on both splits
    val_scores = np.max(val_probs, axis=-1)
    test_scores = np.max(test_probs, axis=-1)
    val_preds = np.argmax(val_probs, axis=-1)

    # Validation accuracy (percent), corrected by the difference of confidences
    v1acc = np.mean(val_preds == val_labels) * 100
    return v1acc + doc.get_doc(val_scores, test_scores)
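

if __name__ == "__main__":
    # Minimal usage sketch on synthetic data. Assumptions: quapy's
    # LabelledCollection is constructed as LabelledCollection(instances, labels),
    # and the helper modules imported at the top are available on the path.
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
    X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.5, random_state=0)
    X_val, X_test, y_val, y_test = train_test_split(X_rest, y_rest, test_size=0.5, random_state=0)

    model = LogisticRegression().fit(X_train, y_train)
    val = LabelledCollection(X_val, y_val)
    test = LabelledCollection(X_test, y_test)

    print("kfcv:", kfcv(model, val))
    print("atc_mc:", atc_mc(model, val, test))
    print("atc_ne:", atc_ne(model, val, test))
    print("doc_feat:", doc_feat(model, val, test))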