from copy import deepcopy

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression, LinearRegression

import quapy as qp
from sklearn import clone
from sklearn.metrics import confusion_matrix
import scipy
from scipy.sparse import issparse, csr_matrix

from data import LabelledCollection
from abc import ABC, abstractmethod
from sklearn.model_selection import cross_val_predict

from quapy.protocol import UPP
from quapy.method.base import BaseQuantifier
from quapy.method.aggregative import PACC, AggregativeQuantifier
import quapy.functional as F


class ClassifierAccuracyPrediction(ABC):
    """
    Base class for classifier accuracy prediction (CAP) methods.

    A CAP method wraps an already-trained classifier `h` and an accuracy function `acc`, and estimates the
    value `acc` would take on (unlabelled) test data.
    """

    def __init__(self, h: BaseEstimator, acc: callable):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc: the accuracy function; `true_acc` calls it on a confusion matrix
        """
        self.h = h
        self.acc = acc

    @abstractmethod
    def fit(self, val: LabelledCollection):
        """
        Fits the CAP method on labelled validation data.

        :param val: a `LabelledCollection` with validation data
        """
        ...

    @abstractmethod
    def predict(self, X, oracle_prev=None):
        """
        Evaluates the accuracy function on the predicted contingency table

        :param X: test data
        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
            an oracle. This is meant to test the effect of the errors in CAP that are explained by
            the errors in quantification performance
        :return: float
        """
        return ...

    def true_acc(self, sample: LabelledCollection):
        """
        Computes the actual value of the accuracy function for `h` on a labelled sample.

        :param sample: a `LabelledCollection` providing `X`, `y` and `classes_`
        :return: the value returned by `self.acc` on the confusion matrix of `h` on `sample`
        """
        predictions = self.h.predict(sample.X)
        # rows follow the true labels, columns the predicted ones (sklearn convention)
        table = confusion_matrix(sample.y, y_pred=predictions, labels=sample.classes_)
        return self.acc(table)


class CAPContingencyTable(ClassifierAccuracyPrediction):
    """
    CAP methods that first estimate the contingency table of `h` on the test data (`predict_ct`) and then
    evaluate the accuracy function on that estimated table.
    """

    def __init__(self, h: BaseEstimator, acc: callable):
        # the parent constructor stores `h` and `acc` as attributes
        super().__init__(h, acc)

    def predict(self, X, oracle_prev=None):
        """
        Evaluates the accuracy function on the predicted contingency table

        :param X: test data
        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
            an oracle. This is meant to test the effect of the errors in CAP that are explained by
            the errors in quantification performance
        :return: float, clipped to the interval [0, 1]
        """
        estimated_table = self.predict_ct(X, oracle_prev)
        return np.clip(self.acc(estimated_table), 0, 1)

    @abstractmethod
    def predict_ct(self, X, oracle_prev=None):
        """
        Predicts the contingency table for the test data

        :param X: test data
        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
            an oracle. This is meant to test the effect of the errors in CAP that are explained by
            the errors in quantification performance
        :return: a contingency table
        """
        ...


class NaiveCAP(CAPContingencyTable):
    """
    The Naive CAP is a method that relies on the IID assumption, and thus uses the estimation in the validation data
    as an estimate for the test data.
    """
    def __init__(self, h: BaseEstimator, acc: callable):
        super().__init__(h, acc)

    def fit(self, val: LabelledCollection):
        """
        Stores the confusion matrix (raw counts) of `h` on the validation data.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        predictions = self.h.predict(val.X)
        self.cont_table = confusion_matrix(val.y, y_pred=predictions, labels=val.classes_)
        return self

    def predict_ct(self, test, oracle_prev=None):
        """
        Returns the validation confusion matrix unchanged: under the IID assumption, the confusion matrix on the
        test data is expected to coincide with the one computed on the validation data.

        :param test: test collection (ignored)
        :param oracle_prev: ignored
        :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
        """
        return self.cont_table


class CAPContingencyTableQ(CAPContingencyTable):
    """
    Contingency-table CAP methods that rely on a quantifier `self.q` to estimate test prevalences.
    """

    def __init__(self, h: BaseEstimator, acc: callable, q_class: AggregativeQuantifier, reuse_h=False):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc: the accuracy function
        :param q_class: a quantifier instance (despite the name); it is used as is, or deep-copied when `reuse_h`
        :param reuse_h: if True, `q_class` must be aggregative; a deep copy of it is configured to use `h` as its
            classifier, and that classifier is not re-trained in `quantifier_fit`
        """
        super().__init__(h, acc)
        self.reuse_h = reuse_h
        if not reuse_h:
            self.q = q_class
            return
        assert isinstance(q_class, AggregativeQuantifier), f'quantifier {q_class} is not of type aggregative'
        self.q = deepcopy(q_class)
        self.q.set_params(classifier=h)

    def quantifier_fit(self, val: LabelledCollection):
        """
        Fits the quantifier on the validation data.

        :param val: a `LabelledCollection` with validation data
        """
        if not self.reuse_h:
            self.q.fit(val)
            return
        # keep `h` as is and fit only the aggregation step, on the same validation data
        self.q.fit(val, fit_classifier=False, val_split=val)


class ContTableTransferCAP(CAPContingencyTableQ):
    """
    Transfers the (normalized) validation contingency table to the test data by rescaling each row (true class)
    by the ratio between the estimated test prevalence and the validation prevalence of that class.
    """
    def __init__(self, h: BaseEstimator, acc: callable, q_class, reuse_h=False):
        super().__init__(h, acc, q_class, reuse_h)

    def fit(self, val: LabelledCollection):
        """
        Computes the normalized validation contingency table and prevalence, and fits the quantifier.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        y_hat = self.h.predict(val.X)
        y_true = val.y
        # normalize='all': all cells sum to 1
        self.cont_table = confusion_matrix(y_true=y_true, y_pred=y_hat, labels=val.classes_, normalize='all')
        self.train_prev = val.prevalence()
        self.quantifier_fit(val)
        return self

    def predict_ct(self, test, oracle_prev=None):
        """
        :param test: test data; passed to the quantifier when `oracle_prev` is None, otherwise ignored
        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
            an oracle. This is meant to test the effect of the errors in CAP that are explained by
            the errors in quantification performance
        :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
        """
        if oracle_prev is None:
            prev_hat = self.q.quantify(test)
        else:
            prev_hat = oracle_prev
        prev_hat = np.asarray(prev_hat, dtype=float)
        train_prev = np.asarray(self.train_prev, dtype=float)
        # A class absent from the validation data has an all-zero row in `cont_table`. A plain division would
        # give inf for it and 0 * inf = NaN would poison the whole table, so its adjustment is set to 0 instead.
        adjustment = np.divide(prev_hat, train_prev, out=np.zeros_like(prev_hat), where=train_prev > 0)
        return self.cont_table * adjustment[:, np.newaxis]


class NsquaredEquationsCAP(CAPContingencyTableQ):
    """
    Estimates the normalized test contingency table of `h` (rows: true class, columns: predicted class) by solving
    a square system of n*n linear equations in the n*n unknown cells:

    - 1 equation: all cells sum to 1;
    - (n-1)*(n-1) equations: the class-conditional ratios observed in validation are preserved in test;
    - n-1 equations: the column sums equal the classify & count prevalence estimates of `h`;
    - n-1 equations: the row sums equal the prevalence estimates of the quantifier (or the oracle).

    If the direct solution is not a valid table, or the system is singular, the L2 norm of the residual is
    minimized with `F.optim_minimize` instead.
    """
    def __init__(self, h: BaseEstimator, acc: callable, q_class, reuse_h=False):
        super().__init__(h, acc, q_class, reuse_h)

    def fit(self, val: LabelledCollection):
        """
        Computes the validation confusion matrix, fits the quantifier, and precomputes the test-independent
        part of the system of equations.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        y_hat = self.h.predict(val.X)
        y_true = val.y
        self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_)
        self.quantifier_fit(val)
        self.A, self.partial_b = self._construct_equations()
        return self

    def _construct_equations(self):
        """
        Builds the coefficient matrix A and the partially filled vector b of the system Ax=b. The last 2*(n-1)
        entries of b depend on the test data and are filled in `predict_ct`.

        :return: tuple (A, b), with A.shape=(n*n, n*n) and b.shape=(n*n,)
        """
        # we need a n x n matrix of unknowns
        n = self.cont_table.shape[1]

        # I is the matrix of indexes of unknowns. For example, if we need the counts of
        # all instances belonging to class i that have been classified as belonging to 0, 1, ..., n:
        # the indexes of the corresponding unknowns are given by I[i,:]
        I = np.arange(n * n).reshape(n, n)

        # system of equations: Ax=b, A.shape=(n*n, n*n,), b.shape=(n*n,)
        A = np.zeros(shape=(n * n, n * n))
        b = np.zeros(shape=(n * n))

        # first equation: the sum of all unknowns is 1
        eq_no = 0
        A[eq_no, :] = 1
        b[eq_no] = 1
        eq_no += 1

        # (n-1)*(n-1) equations: the class cond ratios should be the same in training and in test due to the
        # PPS assumptions. Example in three classes, a ratio: a/(a+b+c) [test] = ar [a ratio in training]
        # a / (a + b + c) = ar
        # a = (a + b + c) * ar
        # a = a ar + b ar + c ar
        # a - a ar - b ar - c ar = 0
        # a (1-ar) + b (-ar)  + c (-ar) = 0
        # A class with no validation instances has an all-zero row; its ratios are set to 0 rather than NaN.
        row_totals = self.cont_table.sum(axis=1, keepdims=True)
        class_cond_ratios_tr = np.divide(
            self.cont_table, row_totals,
            out=np.zeros(self.cont_table.shape, dtype=float),
            where=row_totals > 0,
        )
        for i in range(1, n):
            for j in range(1, n):
                ratio_ij = class_cond_ratios_tr[i, j]
                A[eq_no, I[i, :]] = -ratio_ij
                A[eq_no, I[i, j]] = 1 - ratio_ij
                b[eq_no] = 0
                eq_no += 1

        # n-1 equations: the sum of class-cond counts must equal the C&C prevalence prediction
        # (b filled in predict_ct)
        for i in range(1, n):
            A[eq_no, I[:, i]] = 1
            eq_no += 1

        # n-1 equations: the sum of true class-conditional counts must equal the class prevalence in test
        # (b filled in predict_ct)
        for i in range(1, n):
            A[eq_no, I[i, :]] = 1
            eq_no += 1

        return A, b

    def predict_ct(self, test, oracle_prev=None):
        """
        Solves the system of equations for the test data. Prints 'L' when the iterative fallback is used
        and '.' when the direct solution is accepted.

        :param test: test data; classified by `h` and, when `oracle_prev` is None, quantified by `self.q`
        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
            an oracle. This is meant to test the effect of the errors in CAP that are explained by
            the errors in quantification performance
        :return: an n x n contingency table (rows: true class, columns: predicted class) whose cells sum to 1
        """
        n = self.cont_table.shape[1]

        h_label_preds = self.h.predict(test)
        cc_prev_estim = F.prevalence_from_labels(h_label_preds, self.h.classes_)
        if oracle_prev is None:
            q_prev_estim = self.q.quantify(test)
        else:
            q_prev_estim = oracle_prev

        A = self.A
        # work on a copy so that the template computed in `fit` is never mutated
        b = self.partial_b.copy()

        # b is partially filled; we finish the vector by plugging in the classify and count
        # prevalence estimates (n-1 values only), and the quantification estimates (n-1 values only).
        # Offsets are computed from the front, which also behaves for n == 1 (where `-(n-1)` would be `-0`).
        cc_start = 1 + (n - 1) ** 2
        q_start = cc_start + (n - 1)
        b[cc_start:q_start] = cc_prev_estim[1:]
        b[q_start:] = q_prev_estim[1:]

        # try the fast solution (may not be valid, and A may be singular)
        try:
            x = np.linalg.solve(A, b)
            valid = bool(np.all(x >= 0) and np.all(x <= 1) and np.isclose(x.sum(), 1))
        except np.linalg.LinAlgError:
            valid = False

        if not valid:

            print('L', end='')

            # try the iterative solution
            def loss(x):
                return np.linalg.norm(A @ x - b, ord=2)

            x = F.optim_minimize(loss, n_classes=n**2)

        else:
            print('.', end='')

        cont_table_test = x.reshape(n, n)
        return cont_table_test


class SebastianiCAP(ClassifierAccuracyPrediction):
    """
    Estimates the accuracy on a test set from the accuracies measured on validation samples whose prevalence is
    close to the prevalence estimated for the test set.
    """

    def __init__(self, h, acc_fn, q_class, n_val_samples=500, alpha=0.3, predict_train_prev=True):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc_fn: accuracy function, evaluated on each validation sample through `true_acc`
        :param q_class: quantifier class, instantiated as `q_class(h)`
        :param n_val_samples: number of samples drawn with UPP from one half of the validation data
        :param alpha: if > 0, only samples whose prevalence differs from the test estimate by less than `alpha`
            (maximum absolute difference) are kept and their accuracies aggregated by the median; otherwise all
            samples are averaged with weights -log(discrepancy + epsilon)
        :param predict_train_prev: if True, the prevalence of each validation sample is the one estimated by the
            quantifier; if False, its true prevalence is used
        """
        self.h = h
        self.acc = acc_fn
        self.q = q_class(h)
        self.n_val_samples = n_val_samples
        self.alpha = alpha
        self.sample_size = qp.environ['SAMPLE_SIZE']
        self.predict_train_prev = predict_train_prev

    def fit(self, val: LabelledCollection):
        """
        Splits the validation data in two halves: the quantifier is fitted (without re-training `h`) on one half,
        and samples drawn from the other half are used to record accuracies and prevalences.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        v2, v1 = val.split_stratified(train_prop=0.5)
        self.q.fit(v1, fit_classifier=False, val_split=v1)

        # precompute classifier predictions on samples
        gen_samples = UPP(v2, repeats=self.n_val_samples, sample_size=self.sample_size, return_type='labelled_collection')
        self.sigma_acc = [self.true_acc(sigma_i) for sigma_i in gen_samples()]

        # precompute prevalence predictions on samples
        if self.predict_train_prev:
            gen_samples.on_preclassified_instances(self.q.classify(v2.X), in_place=True)
            self.sigma_pred_prevs = [self.q.aggregate(sigma_i.X) for sigma_i in gen_samples()]
        else:
            self.sigma_pred_prevs = [sigma_i.prevalence() for sigma_i in gen_samples()]
        return self

    def predict(self, X, oracle_prev=None):
        """
        :param X: test data
        :param oracle_prev: optional test prevalence to use instead of the quantifier's estimate
        :return: the estimated value of the accuracy function on X
        """
        if oracle_prev is None:
            test_pred_prev = self.q.quantify(X)
        else:
            test_pred_prev = oracle_prev

        # maximum absolute difference between each validation sample's prevalence and the test estimate
        discrepancies = np.asarray(
            [np.max(np.abs(pred_prev_i - test_pred_prev)) for pred_prev_i in self.sigma_pred_prevs]
        )
        accuracies = np.asarray(self.sigma_acc)

        if self.alpha > 0:
            # select samples from V2 with predicted prevalence close to the predicted prevalence for U
            selected = discrepancies < self.alpha
            if not np.any(selected) and discrepancies.size > 0:
                # No sample lies within `alpha`: fall back to the nearest sample(s) rather than returning the
                # (NaN) median of an empty selection.
                selected = discrepancies == discrepancies.min()
            return np.median(accuracies[selected])
        else:
            # weighted mean: samples from V2 are weighted by the closeness of their prevalence to that of U
            epsilon = 10E-4
            weights = -np.log(discrepancies + epsilon)
            return np.sum(weights * accuracies) / np.sum(weights)


class PabloCAP(ClassifierAccuracyPrediction):
    """
    Simulates test samples by resampling the pre-classified validation data at the prevalence estimated for the
    test set, and aggregates the accuracy function computed on each simulated sample.
    """

    def __init__(self, h, acc_fn, q_class, n_val_samples=100, aggr='mean'):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc_fn: accuracy function, evaluated on a confusion matrix
        :param q_class: quantifier class, instantiated as `q_class(deepcopy(h))`
        :param n_val_samples: number of simulated samples drawn in `predict`
        :param aggr: 'mean' or 'median'; how the per-sample accuracies are aggregated
        """
        self.h = h
        self.acc = acc_fn
        self.q = q_class(deepcopy(h))
        self.n_val_samples = n_val_samples
        self.aggr = aggr
        assert aggr in ['mean', 'median'], 'unknown aggregation function, use mean or median'

    def fit(self, val: LabelledCollection):
        """
        Fits the quantifier and stores the validation data with `h`'s predictions as instances and the true
        labels as labels.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        self.q.fit(val)
        label_predictions = self.h.predict(val.X)
        # Declare the classes of `val` explicitly so that the class space (and hence the length of the prevalence
        # vectors passed to `sampling`) matches the quantifier's even if some class has no instance in `val`.
        self.pre_classified = LabelledCollection(
            instances=label_predictions, labels=val.labels, classes=val.classes_
        )
        return self

    def predict(self, X, oracle_prev=None):
        """
        :param X: test data
        :param oracle_prev: optional test prevalence to use instead of the (smoothed) quantifier's estimate
        :return: the aggregated (mean or median) accuracy over the simulated samples
        """
        if oracle_prev is None:
            pred_prev = F.smooth(self.q.quantify(X))
        else:
            pred_prev = oracle_prev
        X_size = X.shape[0]
        acc_estim = []
        for _ in range(self.n_val_samples):
            # a sample of the same size as X, drawn at the estimated prevalence
            sigma_i = self.pre_classified.sampling(X_size, *pred_prev[:-1])
            y_pred, y_true = sigma_i.Xy
            conf_table = confusion_matrix(y_true, y_pred=y_pred, labels=sigma_i.classes_)
            acc_estim.append(self.acc(conf_table))
        if self.aggr == 'mean':
            return np.mean(acc_estim)
        elif self.aggr == 'median':
            return np.median(acc_estim)
        else:
            raise ValueError('unknown aggregation function')


def get_posteriors_from_h(h, X):
    """
    Returns posterior probabilities of classifier `h` on `X`, one column per class.

    `h.predict_proba` is used when available. Otherwise the output of `h.decision_function` is turned into
    probabilities with a row-wise softmax. A 1-D decision function (as scikit-learn returns for binary problems)
    is first expanded into the two columns [-s, s], so that the second column refers to the positive class.

    :param h: a classifier exposing `predict_proba` or `decision_function`
    :param X: instances
    :return: np.ndarray of shape (n_instances, n_classes)
    """
    if hasattr(h, 'predict_proba'):
        return h.predict_proba(X)

    dec_scores = np.asarray(h.decision_function(X))
    # Binary scikit-learn classifiers return one score per instance even though len(h.classes_) == 2,
    # so the shape of the scores, not the number of classes, decides whether they must be expanded.
    if dec_scores.ndim == 1:
        dec_scores = np.vstack([-dec_scores, dec_scores]).T
    return scipy.special.softmax(dec_scores, axis=1)


def max_conf(P, keepdims=False):
    """
    Returns the maximum posterior probability of each row of `P`.

    :param P: np.ndarray of posterior probabilities, one row per instance
    :param keepdims: if True, the result keeps the reduced axis (shape (n, 1)); otherwise its shape is (n,)
    :return: np.ndarray with the row-wise maxima
    """
    return np.max(P, axis=1, keepdims=keepdims)


def neg_entropy(P, keepdims=False):
    """
    Returns the negative Shannon entropy (natural logarithm) of each row of `P`.

    Higher values mean more confident predictions, as with `max_conf`. Callers such as ATC, which count scores
    above a threshold as likely correct, rely on this. `scipy.stats.entropy` normalizes each row to sum to 1.

    :param P: np.ndarray of posterior probabilities, one row per instance
    :param keepdims: if True, the result is reshaped into a column of shape (n, 1); otherwise its shape is (n,)
    :return: np.ndarray with the row-wise negative entropies (values <= 0)
    """
    ne = -scipy.stats.entropy(P, axis=1)
    if keepdims:
        ne = ne.reshape(-1, 1)
    return ne


class QuAcc:
    """
    Mixin that builds the extended covariates ("X dot") used by the QuAcc methods from the posteriors of
    `self.h`. Classes using it must define `h`, `add_X`, `add_posteriors`, `add_maxconf`, `add_negentropy`
    and `add_maxinfsoft`.
    """

    def _get_X_dot(self, X):
        """
        :param X: instances
        :return: the horizontal stack of the enabled covariates, preceded by `X` itself when `add_X`;
            just `X` when `add_X` is the only option enabled
        :raises ValueError: if neither `add_X` nor any additional covariate is enabled
        """
        P = get_posteriors_from_h(self.h, X)

        add_covs = []

        if self.add_posteriors:
            # the first column is dropped (presumably redundant, since posteriors sum to 1 per row)
            add_covs.append(P[:, 1:])

        if self.add_maxconf:
            add_covs.append(max_conf(P, keepdims=True))

        if self.add_negentropy:
            add_covs.append(neg_entropy(P, keepdims=True))

        if self.add_maxinfsoft:
            # A zero posterior would give log(0) = -inf and turn the covariate into inf/NaN. Zeros are
            # raised to the smallest positive normal float, leaving every other value unchanged.
            P_float = np.asarray(P, dtype=float)
            lgP = np.log(np.maximum(P_float, np.finfo(float).tiny))
            mis = np.max(lgP - lgP.mean(axis=1, keepdims=True), axis=1, keepdims=True)
            add_covs.append(mis)

        if not add_covs:
            # previously this case crashed with an UnboundLocalError on X_dot
            if self.add_X:
                return X
            raise ValueError('no covariates enabled: set add_X or at least one of the add_* flags')

        X_dot = np.hstack(add_covs)

        if self.add_X:
            X_dot = safehstack(X, X_dot)

        return X_dot


class QuAcc1xN2(CAPContingencyTableQ, QuAcc):
    """
    QuAcc variant with a single quantifier over n*n classes, one per cell of the contingency table. Cell
    (true=i, predicted=j) is encoded as class i*n + j.
    """

    def __init__(self,
                 h: BaseEstimator,
                 acc: callable,
                 q_class: AggregativeQuantifier,
                 add_X=True,
                 add_posteriors=True,
                 add_maxconf=False,
                 add_negentropy=False,
                 add_maxinfsoft=False):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc: the accuracy function, evaluated on the predicted contingency table
        :param q_class: a quantifier instance, wrapped in an `EmptySafeQuantifier`
        :param add_X: whether the original covariates are part of the quantifier's input
        :param add_posteriors: whether to add the posteriors of `h` (all but the first column)
        :param add_maxconf: whether to add the maximum posterior
        :param add_negentropy: whether to add the negative entropy of the posteriors
        :param add_maxinfsoft: whether to add the max of the centered log-posteriors
        """
        self.h = h
        self.acc = acc
        self.q = EmptySafeQuantifier(q_class)
        self.add_X = add_X
        self.add_posteriors = add_posteriors
        self.add_maxconf = add_maxconf
        self.add_negentropy = add_negentropy
        self.add_maxinfsoft = add_maxinfsoft

    def fit(self, val: LabelledCollection):
        """
        Fits the quantifier on the extended covariates labelled with the contingency-table cell of each instance.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        pred_labels = self.h.predict(val.X)
        true_labels = val.y

        n = val.n_classes
        classes_dot = np.arange(n**2)
        # ct_class_idx[true, pred] is the class that encodes that cell (row-major, like confusion_matrix)
        ct_class_idx = classes_dot.reshape(n, n)

        X_dot = self._get_X_dot(val.X)
        y_dot = ct_class_idx[true_labels, pred_labels]
        val_dot = LabelledCollection(X_dot, y_dot, classes=classes_dot)
        self.q.fit(val_dot)
        return self

    def predict_ct(self, X, oracle_prev=None):
        """
        :param X: test data
        :param oracle_prev: ignored
        :return: an n x n contingency table (rows: true class, columns: predicted class)
        """
        X_dot = self._get_X_dot(X)
        flat_prevs = np.asarray(self.q.quantify(X_dot))
        # the quantifier returns one prevalence per cell (n*n values); fold them back into the table layout
        # used in `fit`, which is what the accuracy function expects
        n = int(round(np.sqrt(flat_prevs.size)))
        return flat_prevs.reshape(n, n)


class QuAccNxN(CAPContingencyTableQ, QuAcc):
    """
    QuAcc variant with one quantifier per predicted class. Each one estimates the distribution of true classes
    among the instances `h` assigns to that class.
    """

    def __init__(self,
                 h: BaseEstimator,
                 acc: callable,
                 q_class: AggregativeQuantifier,
                 add_X=True,
                 add_posteriors=True,
                 add_maxconf=False,
                 add_negentropy=False,
                 add_maxinfsoft=False):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc: the accuracy function, evaluated on the predicted contingency table
        :param q_class: a quantifier instance; one deep copy is made per class of `h`
        :param add_X: whether the original covariates are part of the quantifiers' input
        :param add_posteriors: whether to add the posteriors of `h` (all but the first column)
        :param add_maxconf: whether to add the maximum posterior
        :param add_negentropy: whether to add the negative entropy of the posteriors
        :param add_maxinfsoft: whether to add the max of the centered log-posteriors
        """
        self.h = h
        self.acc = acc
        self.q_class = q_class
        self.add_X = add_X
        self.add_posteriors = add_posteriors
        self.add_maxconf = add_maxconf
        self.add_negentropy = add_negentropy
        self.add_maxinfsoft = add_maxinfsoft

    def fit(self, val: LabelledCollection):
        """
        Fits one quantifier per class of `h`. Each is trained on the validation instances predicted as that
        class, labelled with their true class.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        pred_labels = self.h.predict(val.X)
        true_labels = val.y
        X_dot = self._get_X_dot(val.X)

        self.q = []
        for class_i in self.h.classes_:
            predicted_as_i = pred_labels == class_i
            data_i = LabelledCollection(X_dot[predicted_as_i], true_labels[predicted_as_i], classes=val.classes_)

            q_i = EmptySafeQuantifier(deepcopy(self.q_class))
            q_i.fit(data_i)
            self.q.append(q_i)
        return self

    def predict_ct(self, X, oracle_prev=None):
        """
        :param X: test data
        :param oracle_prev: ignored
        :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
            (rows: true class, columns: predicted class), whose cells sum to 1
        """
        classes = self.h.classes_
        pred_labels = self.h.predict(X)
        X_dot = self._get_X_dot(X)
        pred_prev = F.prevalence_from_labels(pred_labels, classes)
        rows_by_prediction = []
        for class_i, q_i, p_i in zip(classes, self.q, pred_prev):
            X_dot_i = X_dot[pred_labels == class_i]
            # distribution of true classes among instances predicted as class_i, scaled by P(pred = class_i)
            classcond_prevs = q_i.quantify(X_dot_i)
            rows_by_prediction.append(p_i * classcond_prevs)
        # The stacked rows are indexed by the predicted class. Transpose so that rows follow the true class and
        # columns the predicted class, as in sklearn.metrics.confusion_matrix (and in the other CAP methods).
        return np.vstack(rows_by_prediction).T


def safehstack(X, P):
    """
    Horizontally stacks `X` and `P`.

    :param X: dense array or scipy sparse matrix
    :param P: dense array or scipy sparse matrix with the same number of rows as `X`
    :return: a CSR sparse matrix if either input is sparse, a dense numpy array otherwise
    """
    if not (issparse(X) or issparse(P)):
        return np.hstack([X, P])
    return csr_matrix(scipy.sparse.hstack([X, P]))


class EmptySafeQuantifier(BaseQuantifier):
    """
    Wrapper that lets a quantifier cope with training data in which some (or all) classes have no instances.
    The surrogate quantifier is trained only on the non-empty classes, and only if there are at least two.
    """

    def __init__(self, surrogate_quantifier: BaseQuantifier):
        self.surrogate = surrogate_quantifier

    def fit(self, data: LabelledCollection):
        """
        :param data: a `LabelledCollection`; it is compacted to its non-empty classes before fitting the surrogate
        :return: self
        """
        self.n_classes = data.n_classes
        # old_class_idx is used below as the positions, in the original class space, of the compacted classes
        class_compact_data, self.old_class_idx = data.compact_classes()
        if self.num_non_empty_classes() > 1:
            self.surrogate.fit(class_compact_data)
        return self

    def quantify(self, instances):
        """
        :param instances: instances to quantify (anything with a `shape` attribute)
        :return: a prevalence vector of length `n_classes`:
            uniform if no class was seen in training or `instances` is empty;
            all mass on the only non-empty class if there was exactly one;
            otherwise the surrogate's estimate, mapped back to the original class indices
        """
        n_instances = instances.shape[0]
        n_non_empty = self.num_non_empty_classes()

        if n_non_empty == 0 or n_instances == 0:
            return np.full(shape=self.n_classes, fill_value=1. / self.n_classes, dtype=float)

        prevalences = np.zeros(self.n_classes, dtype=float)
        if n_non_empty == 1:
            prevalences[self.old_class_idx[0]] = 1
        else:
            prevalences[self.old_class_idx] = self.surrogate.quantify(instances)
        return prevalences

    def num_non_empty_classes(self):
        """Returns the number of classes with at least one training instance."""
        return len(self.old_class_idx)


# Baselines:
class ATC(ClassifierAccuracyPrediction):
    """
    Average Thresholded Confidence: learns a threshold on a confidence score on validation data and predicts, as
    accuracy, the fraction of test instances whose score reaches that threshold.

    NOTE: `predict` always returns this fraction; it does not evaluate `acc_fn`. Only vanilla accuracy has been
    tested with this method.
    """

    VALID_FUNCTIONS = {'maxconf', 'neg_entropy'}

    def __init__(self, h, acc_fn, scoring_fn='maxconf'):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc_fn: the accuracy function; it is used by the inherited `true_acc`, not by `predict`
        :param scoring_fn: 'maxconf' or 'neg_entropy', the confidence score computed from the posteriors
        """
        assert scoring_fn in ATC.VALID_FUNCTIONS, \
            f'unknown scoring function, use any from {ATC.VALID_FUNCTIONS}'
        self.h = h
        # `acc` is the attribute read by the inherited `true_acc`; `acc_fn` is kept for backward compatibility
        self.acc = acc_fn
        self.acc_fn = acc_fn
        self.scoring_fn = scoring_fn

    def get_scores(self, P):
        """
        :param P: posterior probabilities, one row per instance
        :return: the confidence score of each row
        """
        if self.scoring_fn == 'maxconf':
            scores = max_conf(P)
        else:
            scores = neg_entropy(P)
        return scores

    def fit(self, val: LabelledCollection):
        """
        Learns the ATC threshold on the validation data.

        :param val: a `LabelledCollection` with validation data
        :return: self
        """
        P = get_posteriors_from_h(self.h, val.X)
        # assumes the labels are the column indices 0..n-1 of P — TODO confirm for non-integer labels
        pred_labels = np.argmax(P, axis=1)
        true_labels = val.y
        scores = self.get_scores(P)
        _, self.threshold = self.__find_ATC_threshold(scores=scores, labels=(pred_labels == true_labels))
        return self

    def predict(self, X, oracle_prev=None):
        """
        :param X: test data
        :param oracle_prev: ignored
        :return: the fraction of instances whose score is >= the learned threshold
        """
        P = get_posteriors_from_h(self.h, X)
        scores = self.get_scores(P)
        return self.__get_ATC_acc(self.threshold, scores)

    def __find_ATC_threshold(self, scores, labels):
        # code copy-pasted from https://github.com/saurabhgarg1996/ATC_code/blob/master/ATC_helper.py
        # Scans the scores in increasing order and keeps the threshold that best balances the errors above it
        # (fp) against the correct predictions at or below it (fn).
        sorted_idx = np.argsort(scores)

        sorted_scores = scores[sorted_idx]
        sorted_labels = labels[sorted_idx]

        fp = np.sum(labels == 0)
        fn = 0.0

        min_fp_fn = np.abs(fp - fn)
        thres = 0.0
        for i in range(len(labels)):
            if sorted_labels[i] == 0:
                fp -= 1
            else:
                fn += 1

            if np.abs(fp - fn) < min_fp_fn:
                min_fp_fn = np.abs(fp - fn)
                thres = sorted_scores[i]

        return min_fp_fn, thres

    def __get_ATC_acc(self, thres, scores):
        # code copy-pasted from https://github.com/saurabhgarg1996/ATC_code/blob/master/ATC_helper.py
        return np.mean(scores >= thres)


class DoC(ClassifierAccuracyPrediction):

    def __init__(self, h, acc, sample_size, num_samples=500, clip_vals=(0,1)):
        """
        :param h: the classifier whose accuracy is to be predicted
        :param acc: function called as `acc(y_true, y_pred)` on label vectors
        :param sample_size: size of the samples drawn from the second validation half
        :param num_samples: number of such samples
        :param clip_vals: (min, max) bounds used to clip the prediction, or None to disable clipping
        """
        self.h, self.acc = h, acc
        self.sample_size, self.num_samples = sample_size, num_samples
        self.clip_vals = clip_vals

    def _get_post_stats(self, X, y):
        """
        :param X: instances
        :param y: true labels of `X`
        :return: tuple (max-confidence score of each instance, value of `acc` for `h`'s predictions on `X`)
        """
        posteriors = get_posteriors_from_h(self.h, X)
        confidences = max_conf(posteriors)
        predicted = np.argmax(posteriors, axis=-1)
        return confidences, self.acc(y, predicted)

    def _doc(self, mc1, mc2):
        """Difference of confidences: the mean of `mc2` minus the mean of `mc1`."""
        mean_conf_2 = mc2.mean()
        mean_conf_1 = mc1.mean()
        return mean_conf_2 - mean_conf_1

    def train_regression(self, v2_mcs, v2_accs):
        """
        Fits a linear regression from the DoC between V1 and each V2 sample to the accuracy drop
        (`v1_acc` minus the sample's accuracy).

        :param v2_mcs: max-confidence scores of each V2 sample
        :param v2_accs: accuracy of each V2 sample
        :return: the fitted `LinearRegression`
        """
        doc_features = np.asarray([self._doc(self.v1_mc, mc_i) for mc_i in v2_mcs]).reshape(-1, 1)
        acc_drops = np.asarray([self.v1_acc - acc_i for acc_i in v2_accs])
        return LinearRegression().fit(doc_features, acc_drops)

    def predict_regression(self, test_mc):
        """
        :param test_mc: max-confidence scores on the test data
        :return: length-1 array with `v1_acc` minus the regressed accuracy drop
        """
        test_doc = self._doc(self.v1_mc, test_mc)
        predicted_drop = self.reg_model.predict(np.asarray([test_doc]).reshape(-1, 1))
        return self.v1_acc - predicted_drop

    def fit(self, val: LabelledCollection):
        """
        Splits the validation data into V1 and V2 (stratified, 50/50, random_state=0). It records the
        confidences and accuracy on V1, then fits the regression on `num_samples` UPP samples drawn from V2.

        :param val: a `LabelledCollection` with validation data
        :return: self (consistent with the other CAP methods' `fit`)
        """
        v1, v2 = val.split_stratified(train_prop=0.5, random_state=0)

        self.v1_mc, self.v1_acc = self._get_post_stats(*v1.Xy)

        v2_prot = UPP(v2, sample_size=self.sample_size, repeats=self.num_samples, return_type='labelled_collection')
        v2_stats = [self._get_post_stats(*sample.Xy) for sample in v2_prot()]
        v2_mcs, v2_accs = list(zip(*v2_stats))

        self.reg_model = self.train_regression(v2_mcs, v2_accs)
        return self

    def predict(self, X, oracle_prev=None):
        """
        :param X: test data
        :param oracle_prev: ignored
        :return: the predicted accuracy, clipped to `clip_vals` unless it is None
        """
        test_mc = max_conf(get_posteriors_from_h(self.h, X))
        acc_pred = self.predict_regression(test_mc)[0]
        if self.clip_vals is None:
            return acc_pred
        return np.clip(acc_pred, *self.clip_vals)

    """
    def doc(self,
            c_model: BaseEstimator,
            validation: LabelledCollection,
            protocol: AbstractStochasticSeededProtocol,
            predict_method="predict_proba"):

        c_model_predict = getattr(c_model, predict_method)
        f1_average = "binary" if validation.n_classes == 2 else "macro"

        val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
        val1_probs = c_model_predict(val1.X)
        val1_mc = np.max(val1_probs, axis=-1)
        val1_preds = np.argmax(val1_probs, axis=-1)
        val1_acc = metrics.accuracy_score(val1.y, val1_preds)
        val1_f1 = metrics.f1_score(val1.y, val1_preds, average=f1_average)
        val2_protocol = APP(
            val2,
            n_prevalences=21,
            repeats=100,
            return_type="labelled_collection",
        )
        val2_prot_mc = []
        val2_prot_preds = []
        val2_prot_y = []
        for v2 in val2_protocol():
            _probs = c_model_predict(v2.X)
            _mc = np.max(_probs, axis=-1)
            _preds = np.argmax(_probs, axis=-1)
            val2_prot_mc.append(_mc)
            val2_prot_preds.append(_preds)
            val2_prot_y.append(v2.y)

        val_scores = np.array([doclib.get_doc(val1_mc, v2_mc) for v2_mc in val2_prot_mc])
        val_targets_acc = np.array(
            [
                val1_acc - metrics.accuracy_score(v2_y, v2_preds)
                for v2_y, v2_preds in zip(val2_prot_y, val2_prot_preds)
            ]
        )
        reg_acc = LinearRegression().fit(val_scores[:, np.newaxis], val_targets_acc)
        val_targets_f1 = np.array(
            [
                val1_f1 - metrics.f1_score(v2_y, v2_preds, average=f1_average)
                for v2_y, v2_preds in zip(val2_prot_y, val2_prot_preds)
            ]
        )
        reg_f1 = LinearRegression().fit(val_scores[:, np.newaxis], val_targets_f1)

        report = EvaluationReport(name="doc")
        for test in protocol():
            test_probs = c_model_predict(test.X)
            test_preds = np.argmax(test_probs, axis=-1)
            test_mc = np.max(test_probs, axis=-1)
            acc_score = (
                    val1_acc
                    - reg_acc.predict(np.array([[doclib.get_doc(val1_mc, test_mc)]]))[0]
            )
            f1_score = (
                    val1_f1 - reg_f1.predict(np.array([[doclib.get_doc(val1_mc, test_mc)]]))[0]
            )
            meta_acc = abs(acc_score - metrics.accuracy_score(test.y, test_preds))
            meta_f1 = abs(
                f1_score - metrics.f1_score(test.y, test_preds, average=f1_average)
            )
            report.append_row(
                test.prevalence(),
                acc=meta_acc,
                acc_score=acc_score,
                f1=meta_f1,
                f1_score=f1_score,
            )

        return report

    def get_doc(probs1, probs2):
        return np.mean(probs2) - np.mean(probs1)
        """