from typing import Callable, Iterable, Union

import pandas as pd
from quapy.protocol import (
    AbstractStochasticSeededProtocol,
    OnLabelledCollectionProtocol,
)

import quacc.error as error

from .estimator import AccuracyEstimator


def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
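    """Run `protocol` and, for each sample it yields, collect the base
    prevalence of the sample, the true prevalence of the extended sample,
    and the prevalence estimated by `estimator`; return the three lists.
    """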
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    base_prevs, true_prevs, estim_prevs = [], [], []
    for sample in protocol():
        # extend the sample via the estimator, then estimate the prevalence
        # of the extended sample from its extended feature matrix
        e_sample = estimator.extend(sample)
        estim_prev = estimator.estimate(e_sample.X, ext=True)
        base_prevs.append(sample.prevalence())
        true_prevs.append(e_sample.prevalence())
        estim_prevs.append(estim_prev)

    return base_prevs, true_prevs, estim_prevs


def evaluation_report(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
    error_metrics: Iterable[Union[str, Callable]] = "all",
):
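    """Evaluate `estimator` over `protocol` and return a pandas DataFrame
    with one row per sample, reporting the base, true and estimated
    prevalences along with the requested error metrics ("all" selects
    mae, rae, mrae, kld, nkld and f1e).
    """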
    base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)

    if error_metrics == "all":
        error_metrics = ["mae", "rae", "mrae", "kld", "nkld", "f1e"]

    # resolve metric names to the corresponding functions in quacc.error
    error_funcs = [
        error.from_name(e) if isinstance(e, str) else e for e in error_metrics
    ]
    assert all(callable(e) for e in error_funcs), "invalid error function"
    error_names = [e.__name__ for e in error_funcs]

    # f1e takes a single prevalence vector, so it contributes two columns:
    # one computed on the true prevalence and one on the estimated prevalence
    df_cols = ["base_prev", "true_prev", "estim_prev"] + error_names
    if "f1e" in df_cols:
        df_cols.remove("f1e")
        df_cols.extend(["f1e_true", "f1e_estim"])

    lst = []
    for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
        series = {
            "base_prev": base_prev,
            "true_prev": true_prev,
            "estim_prev": estim_prev,
        }
        for error_name, error_metric in zip(error_names, error_funcs):
            if error_name == "f1e":
                series["f1e_true"] = error_metric(true_prev)
                series["f1e_estim"] = error_metric(estim_prev)
                continue

            series[error_name] = error_metric(true_prev, estim_prev)

        lst.append(series)

    return pd.DataFrame(lst, columns=df_cols)
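
# A minimal usage sketch (illustrative, not part of the module): fit an
# AccuracyEstimator on training data (`build_estimator` below is a
# hypothetical helper, not defined here), wrap the test set in one of
# quapy's sampling protocols, and pass both to evaluation_report:
#
#   import quapy as qp
#   from quapy.protocol import APP
#
#   dataset = qp.datasets.fetch_reviews("imdb", tfidf=True)
#   estimator = build_estimator(dataset.training)  # hypothetical helper
#   prot = APP(dataset.test, sample_size=100, repeats=10)
#   report = evaluation_report(estimator, prot, error_metrics=["mae", "f1e"])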