QuAcc/quacc/evaluation.py

106 lines
3.4 KiB
Python

import itertools
from quapy.protocol import (
OnLabelledCollectionProtocol,
AbstractStochasticSeededProtocol,
)
from typing import Iterable, Callable, Union
from .estimator import AccuracyEstimator
import pandas as pd
import numpy as np
import quacc.error as error
def estimate(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
):
    """Run ``estimator`` on every sample generated by ``protocol``.

    For each sample, the estimator first extends it (``estimator.extend``).
    It then estimates the prevalence of the extended sample from its features
    (``estimator.estimate(e_sample.X, ext=True)``).

    The protocol's ``collator`` is temporarily replaced so that each iteration
    yields a ``LabelledCollection``. The previous collator is restored before
    returning, even if an exception is raised, so the caller's protocol object
    is left as it was found.

    Returns:
        A tuple ``(base_prevs, true_prevs, estim_prevs)`` of lists with one
        entry per sample. The entries are, in order: the prevalence of the
        original sample, the prevalence of the extended sample, and the
        estimator's estimate for the extended sample.
    """
    previous_collator = protocol.collator
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
    base_prevs, true_prevs, estim_prevs = [], [], []
    try:
        for sample in protocol():
            e_sample = estimator.extend(sample)
            estim_prev = estimator.estimate(e_sample.X, ext=True)
            base_prevs.append(sample.prevalence())
            true_prevs.append(e_sample.prevalence())
            estim_prevs.append(estim_prev)
    finally:
        # Undo the override: the protocol belongs to the caller and may be
        # reused with its original output format.
        protocol.collator = previous_collator
    return base_prevs, true_prevs, estim_prevs
# Column labels for the two-level (pd.MultiIndex) evaluation report built by
# _report_columns and filled by _dict_prev. Each *_col_0 list holds top-level
# labels. Each *_col_1 list holds the second-level labels that are paired with
# them via itertools.product.

# Prevalence of the original (non-extended) sample: ("base", "0"), ("base", "1").
# NOTE(review): only two class labels, so this presumably assumes a binary
# problem — confirm.
_bprev_col_0 = ["base"]
_bprev_col_1 = ["0", "1"]
# True and estimated prevalence of the extended sample. _prev_col_1 supplies
# one label per entry of each of those prevalence vectors.
# NOTE(review): the labels presumably encode T/F x class 0/1 outcomes produced
# by AccuracyEstimator.extend — confirm that the order matches.
_prev_col_0 = ["true", "estim"]
_prev_col_1 = ["T0", "F1", "F0", "T1"]
# Top-level label shared by all per-metric error columns.
_err_col_0 = ["errors"]
def _report_columns(err_names):
    """Build the two-level column index of an evaluation report.

    Columns are ordered in three groups:

    1. the ``base`` prevalence columns;
    2. the ``true``/``estim`` prevalence columns;
    3. one ``("errors", name)`` column for each entry of ``err_names``.
    """
    groups = (
        itertools.product(_bprev_col_0, _bprev_col_1),
        itertools.product(_prev_col_0, _prev_col_1),
        itertools.product(_err_col_0, err_names),
    )
    tuples = list(itertools.chain.from_iterable(groups))
    return pd.MultiIndex.from_tuples(tuples)
def _dict_prev(base_prev, true_prev, estim_prev):
    """Map each prevalence column of the report to its value.

    The ``base`` columns come first, followed by the ``true``/``estim``
    columns. They are paired positionally with the concatenation of
    ``base_prev``, ``true_prev`` and ``estim_prev``. Pairing stops at the
    shorter of the two sequences (``zip`` semantics).
    """
    columns = [
        *itertools.product(_bprev_col_0, _bprev_col_1),
        *itertools.product(_prev_col_0, _prev_col_1),
    ]
    values = np.concatenate((base_prev, true_prev, estim_prev), axis=0)
    return dict(zip(columns, values))
def _resolve_error_metrics(error_metrics):
    """Resolve ``error_metrics`` into parallel lists ``(funcs, names)``.

    Raises:
        TypeError: if a resolved metric is not callable.
    """
    if isinstance(error_metrics, str):
        # "all" selects the default metric set. Any other string is a single
        # metric name; it must not be iterated character by character.
        if error_metrics == "all":
            error_metrics = ["mae", "rae", "mrae", "kld", "nkld", "f1e"]
        else:
            error_metrics = [error_metrics]
    funcs = [error.from_name(e) if isinstance(e, str) else e for e in error_metrics]
    invalid = [e for e in funcs if not callable(e)]
    if invalid:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise TypeError(f"invalid error function: {invalid!r}")
    return funcs, [e.__name__ for e in funcs]


def _error_columns(error_names):
    """Return the ``errors`` sub-column names for ``error_names``.

    ``f1e`` is reported as two columns, ``f1e_true`` and ``f1e_estim``. They
    are placed after the other metrics' columns.
    """
    cols = [name for name in error_names if name != "f1e"]
    if "f1e" in error_names:
        cols.extend(["f1e_true", "f1e_estim"])
    return cols


def evaluation_report(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
    error_metrics: Iterable[Union[str, Callable]] = "all",
):
    """Evaluate ``estimator`` on ``protocol`` and tabulate prevalences and errors.

    Args:
        estimator: accuracy estimator, passed to ``estimate``.
        protocol: sampling protocol, passed to ``estimate``.
        error_metrics: one of the following:

            * ``"all"`` for the default metric set;
            * a single metric name;
            * an iterable of metric names (resolved via ``error.from_name``)
              and/or callables.

    Returns:
        A ``pd.DataFrame`` with one row per protocol sample and two-level
        columns ``("base", ...)``, ``("true", ...)``, ``("estim", ...)`` and
        ``("errors", <metric name>)``. A metric named ``"f1e"`` is called with
        a single prevalence and fills ``f1e_true`` and ``f1e_estim``. Every
        other metric is called as ``metric(true_prev, estim_prev)``.

    Raises:
        TypeError: if a resolved metric is not callable.
    """
    # Resolve and validate the metrics before the (potentially expensive)
    # estimation loop, so that a bad metric fails fast.
    error_funcs, error_names = _resolve_error_metrics(error_metrics)
    df_cols = _report_columns(_error_columns(error_names))

    base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)

    rows = []
    for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
        row = _dict_prev(base_prev, true_prev, estim_prev)
        for name, func in zip(error_names, error_funcs):
            if name == "f1e":
                # f1e scores one prevalence at a time, so it yields two columns.
                row[("errors", "f1e_true")] = func(true_prev)
                row[("errors", "f1e_estim")] = func(estim_prev)
            else:
                row[("errors", name)] = func(true_prev, estim_prev)
        rows.append(row)

    return pd.DataFrame(rows, columns=df_cols)