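"""Evaluation utilities for AccuracyEstimator instances.

Given a quapy sampling protocol, these helpers collect, for every generated
sample, the base class prevalence, the true prevalence of the extended
(TN/FP/FN/TP) sample and the prevalence estimated by the AccuracyEstimator,
and assemble them, together with the requested error metrics, into a pandas
DataFrame report.
"""
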
import itertools
import statistics as stats
from typing import Callable, Iterable, Union

import numpy as np
import pandas as pd
from quapy.protocol import (
    AbstractStochasticSeededProtocol,
    OnLabelledCollectionProtocol,
)

import quacc.error as error

from .estimator import AccuracyEstimator


def estimate(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
):
    """Collect base, true and estimated prevalences for every sample drawn from `protocol`."""
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    base_prevs, true_prevs, estim_prevs = [], [], []
    for sample in protocol():
        # extend the sample with the classifier's predictions and estimate its prevalence
        e_sample = estimator.extend(sample)
        estim_prev = estimator.estimate(e_sample.X, ext=True)
        base_prevs.append(sample.prevalence())
        true_prevs.append(e_sample.prevalence())
        estim_prevs.append(estim_prev)

    return base_prevs, true_prevs, estim_prevs
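

# The three lists returned by estimate() line up with the report columns defined
# below: each base_prevs[i] has two cells (classes "0" and "1"), while
# true_prevs[i] and estim_prevs[i] each have four cells (TN, FP, FN, TP).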

# column groups used to build the report's MultiIndex
_bprev_col_0 = ["base"]
_bprev_col_1 = ["0", "1"]
_prev_col_0 = ["true", "estim"]
_prev_col_1 = ["TN", "FP", "FN", "TP"]
_err_col_0 = ["errors"]
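
# For err_names == ["ae", "f1_true", "f1_estim", "f1_dist"], _report_columns
# below produces the MultiIndex columns
#   ("base", "0"), ("base", "1"),
#   ("true", "TN"), ("true", "FP"), ("true", "FN"), ("true", "TP"),
#   ("estim", "TN"), ("estim", "FP"), ("estim", "FN"), ("estim", "TP"),
#   ("errors", "ae"), ("errors", "f1_true"), ("errors", "f1_estim"), ("errors", "f1_dist")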


def _report_columns(err_names):
    """Build the report's MultiIndex from the prevalence column groups and `err_names`."""
    bprev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1))
    prev_cols = list(itertools.product(_prev_col_0, _prev_col_1))
    err_cols = list(itertools.product(_err_col_0, err_names))

    cols = bprev_cols + prev_cols + err_cols

    return pd.MultiIndex.from_tuples(cols)


def _report_avg_groupby_distribution(lst, error_names):
    """Average report rows that share the same base prevalence.

    Rows in `lst` are assumed to be ordered so that equal base prevalences are
    contiguous (as produced by the sampling protocol). The averaged true/estim
    prevalences are renormalized to sum to 1.
    """

    def _bprev(s):
        return (s[("base", "0")], s[("base", "1")])

    def _normalize_prev(r, prev_name):
        # renormalize the four averaged cells of `prev_name` so they sum to 1
        raw_prev = [v for ((k0, _), v) in r.items() if k0 == prev_name]
        norm_prev = [v / sum(raw_prev) for v in raw_prev]
        for n, v in zip(itertools.product([prev_name], _prev_col_1), norm_prev):
            r[n] = v

        return r

    # group consecutive rows having the same base prevalence
    current_bprev = _bprev(lst[0])
    bprev_cnt = 0
    g_lst = [[]]
    for s in lst:
        if _bprev(s) == current_bprev:
            g_lst[bprev_cnt].append(s)
        else:
            g_lst.append([])
            bprev_cnt += 1
            current_bprev = _bprev(s)
            g_lst[bprev_cnt].append(s)

    # average each group cell-wise
    r_lst = []
    for gs in g_lst:
        assert len(gs) > 0
        r = {}
        r[("base", "0")], r[("base", "1")] = _bprev(gs[0])

        for pn in itertools.product(_prev_col_0, _prev_col_1):
            r[pn] = stats.mean(s[pn] for s in gs)

        r = _normalize_prev(r, "true")
        r = _normalize_prev(r, "estim")

        for en in itertools.product(_err_col_0, error_names):
            r[en] = stats.mean(s[en] for s in gs)

        r_lst.append(r)

    return r_lst
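
# Illustration: two rows sharing base prevalence (0.5, 0.5) whose ("true", *) cells
# are (0.2, 0.3, 0.1, 0.4) and (0.4, 0.1, 0.3, 0.2) collapse into one row whose
# ("true", *) cells are the renormalized means (0.3, 0.2, 0.2, 0.3); the
# ("errors", *) cells are plain means.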


def evaluation_report(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
    error_metrics: Iterable[Union[str, Callable]] = "all",
    aggregate: bool = True,
):
    """Evaluate `estimator` over the samples generated by `protocol` and build a report.

    `error_metrics` can be "all" (equivalent to ["ae", "f1"]), an iterable of quacc
    error names, or callables. With `aggregate=True`, rows sharing the same base
    prevalence are averaged via `_report_avg_groupby_distribution`.
    """
    base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)

    if error_metrics == "all":
        error_metrics = ["ae", "f1"]

    error_funcs = [
        error.from_name(e) if isinstance(e, str) else e for e in error_metrics
    ]
    assert all(callable(e) for e in error_funcs), "invalid error function"
    error_names = [e.__name__ for e in error_funcs]

    # f1 and f1e are reported as separate true/estim (and distance) columns
    error_cols = error_names.copy()
    if "f1" in error_cols:
        error_cols.remove("f1")
        error_cols.extend(["f1_true", "f1_estim", "f1_dist"])
    if "f1e" in error_cols:
        error_cols.remove("f1e")
        error_cols.extend(["f1e_true", "f1e_estim"])

    df_cols = _report_columns(error_cols)

    # prevalence columns shared by every row
    prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
        itertools.product(_prev_col_0, _prev_col_1)
    )

    lst = []
    for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
        series = {
            k: v
            for (k, v) in zip(
                prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0)
            )
        }
        for error_name, error_metric in zip(error_names, error_funcs):
            if error_name == "f1e":
                series[("errors", "f1e_true")] = error_metric(true_prev)
                series[("errors", "f1e_estim")] = error_metric(estim_prev)
                continue
            if error_name == "f1":
                f1_true, f1_estim = error_metric(true_prev), error_metric(estim_prev)
                series[("errors", "f1_true")] = f1_true
                series[("errors", "f1_estim")] = f1_estim
                series[("errors", "f1_dist")] = abs(f1_estim - f1_true)
                continue

            score = error_metric(true_prev, estim_prev)
            series[("errors", error_name)] = score

        lst.append(series)

    lst = _report_avg_groupby_distribution(lst, error_cols) if aggregate else lst
    df = pd.DataFrame(lst, columns=df_cols)
    return df
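

# Minimal usage sketch (kept as a comment, not executed). It assumes a fitted
# AccuracyEstimator and a quapy sampling protocol; `validation` and
# `acc_estimator` are placeholder names, and the APP constructor arguments are
# illustrative rather than prescriptive.
#
#   from quapy.protocol import APP
#
#   protocol = APP(validation, sample_size=100, n_prevalences=21, repeats=10)
#   report = evaluation_report(
#       acc_estimator,
#       protocol,
#       error_metrics=["ae", "f1"],
#       aggregate=True,
#   )
#   print(report)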