from functools import wraps from typing import List import numpy as np import quapy as qp from quacc.data import ExtendedPrev def from_name(err_name): assert err_name in ERROR_NAMES, f"unknown error {err_name}" callable_error = globals()[err_name] return callable_error # def f1(prev): # # https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure # if prev[0] == 0 and prev[1] == 0 and prev[2] == 0: # return 1.0 # elif prev[0] == 0 and prev[1] > 0 and prev[2] == 0: # return 0.0 # elif prev[0] == 0 and prev[1] == 0 and prev[2] > 0: # return float('NaN') # else: # recall = prev[0] / (prev[0] + prev[1]) # precision = prev[0] / (prev[0] + prev[2]) # return 2 * (precision * recall) / (precision + recall) def nae(prevs: np.ndarray, prevs_hat: np.ndarray) -> np.ndarray: _ae = qp.error.ae(prevs, prevs_hat) # _zae = (2.0 * (1.0 - prevs.min())) / prevs.shape[1] _zae = 2.0 / prevs.shape[1] return _ae / _zae def f1(prev: np.ndarray | ExtendedPrev) -> float: if isinstance(prev, ExtendedPrev): prev = prev.A def _score(idx): _tp = prev[idx, idx] _fn = prev[idx, :].sum() - _tp _fp = prev[:, idx].sum() - _tp _den = 2.0 * _tp + _fp + _fn return 0.0 if _den == 0.0 else (2.0 * _tp) / _den if prev.shape[0] == 2: return _score(1) else: _idxs = np.arange(prev.shape[0]) return np.array([_score(idx) for idx in _idxs]).mean() def f1e(prev): return 1 - f1(prev) def acc(prev: np.ndarray | ExtendedPrev) -> float: if isinstance(prev, ExtendedPrev): prev = prev.A return np.diag(prev).sum() / prev.sum() def accd( true_prevs: List[np.ndarray | ExtendedPrev], estim_prevs: List[np.ndarray | ExtendedPrev], ) -> np.ndarray: a_tp = np.array([acc(tp) for tp in true_prevs]) a_ep = np.array([acc(ep) for ep in estim_prevs]) return np.abs(a_tp - a_ep) def maccd( true_prevs: List[np.ndarray | ExtendedPrev], estim_prevs: List[np.ndarray | ExtendedPrev], ) -> float: return accd(true_prevs, estim_prevs).mean() ACCURACY_ERROR = {maccd} ACCURACY_ERROR_SINGLE = {accd} ACCURACY_ERROR_NAMES = {func.__name__ for func in ACCURACY_ERROR} ACCURACY_ERROR_SINGLE_NAMES = {func.__name__ for func in ACCURACY_ERROR_SINGLE} ERROR_NAMES = ACCURACY_ERROR_NAMES | ACCURACY_ERROR_SINGLE_NAMES