from typing import Callable, Union

from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol

import quacc as qc
from quacc.deprecated.method.base import BaseAccuracyEstimator


def evaluate(
    estimator: BaseAccuracyEstimator,
    protocol: AbstractProtocol,
    error_metric: Union[Callable | str],
) -> float:
    if isinstance(error_metric, str):
        error_metric = qc.error.from_name(error_metric)

    collator_bck_ = protocol.collator
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    estim_prevs, true_prevs = [], []
    for sample in protocol():
        e_sample = estimator.extend(sample)
        estim_prev = estimator.estimate(e_sample.eX)
        estim_prevs.append(estim_prev)
        true_prevs.append(e_sample.e_prevalence())

    protocol.collator = collator_bck_

    # true_prevs = np.array(true_prevs)
    # estim_prevs = np.array(estim_prevs)

    return error_metric(true_prevs, estim_prevs)