From 5959a0d3235e56dfe1aa1e3ff139ba4d725afba6 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Sat, 20 May 2023 20:23:17 +0200 Subject: [PATCH] Refactoring completed --- quacc/data.py | 1 - quacc/error.py | 15 +++++++++++++++ quacc/estimator.py | 6 +++--- quacc/evaluation.py | 46 +++++++++++++++++++++++++++++++++++++++++++-- quacc/main.py | 26 +++++++++++++++++++++---- 5 files changed, 84 insertions(+), 10 deletions(-) create mode 100644 quacc/error.py diff --git a/quacc/data.py b/quacc/data.py index e14bf98..f8fad6d 100644 --- a/quacc/data.py +++ b/quacc/data.py @@ -7,7 +7,6 @@ from typing import List, Optional class ExtendedCollection(LabelledCollection): def __init__( self, - b_coll: LabelledCollection, instances: np.ndarray | sp.csr_matrix, labels: np.ndarray, classes: Optional[List] = None, diff --git a/quacc/error.py b/quacc/error.py new file mode 100644 index 0000000..2ab688e --- /dev/null +++ b/quacc/error.py @@ -0,0 +1,15 @@ +import quapy as qp + +def from_name(err_name): + if err_name == 'f1e': + return f1e + else: + return qp.error.from_name(err_name) + +def f1e(prev): + return 1 - f1_score(prev) + +def f1_score(prev): + recall = prev[0] / (prev[0] + prev[1]) + precision = prev[0] / (prev[0] + prev[2]) + return 2 * (precision * recall) / (precision + recall) diff --git a/quacc/estimator.py b/quacc/estimator.py index 760133c..4afd490 100644 --- a/quacc/estimator.py +++ b/quacc/estimator.py @@ -46,7 +46,7 @@ def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollecti ] ) - return ExtendedCollection(n_x, n_y, [*range(0, n_classes * n_classes)]) + return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)]) class AccuracyEstimator: @@ -65,7 +65,7 @@ class AccuracyEstimator: # self.model.fit(*train.Xy) if isinstance(train, LabelledCollection): pred_prob_train = cross_val_predict( - self.model, train.Xy, method="predict_proba" + self.model, *train.Xy, method="predict_proba" ) self.e_train = _extend_collection(train, pred_prob_train) @@ -84,5 +84,5 @@ class AccuracyEstimator: estim_prev = self.q_model.quantify(e_inst) return _check_prevalence_classes( - e_inst.classes_, self.q_model.classes_, estim_prev + self.e_train.classes_, self.q_model.classes_, estim_prev ) diff --git a/quacc/evaluation.py b/quacc/evaluation.py index 02dfc67..d12d098 100644 --- a/quacc/evaluation.py +++ b/quacc/evaluation.py @@ -2,8 +2,12 @@ from quapy.protocol import ( OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol, ) +import quapy as qp +from typing import Iterable, Callable, Union from .estimator import AccuracyEstimator +import pandas as pd +import quacc.error as error def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol): @@ -21,5 +25,43 @@ def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededPro return base_prevs, true_prevs, estim_prevs -def evaluate(): - pass +def evaluation_report( + estimator: AccuracyEstimator, + protocol: AbstractStochasticSeededProtocol, + error_metrics: Iterable[Union[str, Callable]] = "all", +): + base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol) + + if error_metrics == "all": + error_metrics = ["mae", "rae", "mrae", "kld", "nkld", "f1e"] + + error_funcs = [ + error.from_name(e) if isinstance(e, str) else e for e in error_metrics + ] + assert all(hasattr(e, "__call__") for e in error_funcs), "invalid error function" + error_names = [e.__name__ for e in error_funcs] + + df_cols = ["base_prev", "true_prev", "estim_prev"] + error_names + if "f1e" in df_cols: + df_cols.remove("f1e") + df_cols.extend(["f1e_true", "f1e_estim"]) + lst = [] + for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs): + series = { + "base_prev": base_prev, + "true_prev": true_prev, + "estim_prev": estim_prev, + } + for error_name, error_metric in zip(error_names, error_funcs): + if error_name == "f1e": + series["f1e_true"] = error_metric(true_prev) + series["f1e_estim"] = error_metric(estim_prev) + continue + + score = error_metric(true_prev, estim_prev) + series[error_name] = score + + lst.append(series) + + df = pd.DataFrame(lst, columns=df_cols) + return df diff --git a/quacc/main.py b/quacc/main.py index 0d2423b..51f4b04 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -2,10 +2,16 @@ import numpy as np import quapy as qp import scipy.sparse as sp from quapy.data import LabelledCollection +from quapy.method.aggregative import SLD from quapy.protocol import APP, AbstractStochasticSeededProtocol from sklearn.linear_model import LogisticRegression from sklearn.model_selection import cross_val_predict +import quacc.evaluation as eval +from quacc.estimator import AccuracyEstimator + +qp.environ['SAMPLE_SIZE'] = 100 + # Extended classes # @@ -86,7 +92,7 @@ def extend_and_quantify( pred_prob_test = model.predict_proba(test.X) _test = extend_collection(test, pred_prob_test) _estim_prev = q_model.quantify(_test.instances) - # check that _estim_prev has all the classes and eventually fill the missing + # check that _estim_prev has all the classes and eventually fill the missing # ones with 0 for _cls in _test.classes_: if _cls not in q_model.classes_: @@ -133,9 +139,9 @@ def test_1(dataset_name): orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify( LogisticRegression(), - qp.method.aggregative.SLD(LogisticRegression()), + SLD(LogisticRegression()), train, - APP(test, sample_size=100, n_prevalences=11, repeats=1), + APP(test, n_prevalences=11, repeats=1), ) for orig_prev, true_prev, estim_prev, _errors in zip( @@ -149,6 +155,18 @@ def test_1(dataset_name): print() +def test_2(dataset_name): + train, test = get_dataset(dataset_name) + model = LogisticRegression() + model.fit(*train.Xy) + estimator = AccuracyEstimator(model, SLD(LogisticRegression())) + estimator.fit(train) + df = eval.evaluation_report( + estimator, APP(test, n_prevalences=11, repeats=1) + ) + print(df.to_string()) + + def main(): for dataset_name in [ # "hp", @@ -156,7 +174,7 @@ def main(): "spambase", ]: print(dataset_name) - test_1(dataset_name) + test_2(dataset_name) print("*" * 50)