From 6ac18137fa7c22f939b8d21b3b3a3c5c922325f4 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 18 May 2023 22:56:57 +0200 Subject: [PATCH] estimator refactored, missing evaluation --- quacc/data.py | 1 - quacc/estimator.py | 4 +++- quacc/evaluation.py | 23 ++++++++++++----------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/quacc/data.py b/quacc/data.py index 7d511b8..e14bf98 100644 --- a/quacc/data.py +++ b/quacc/data.py @@ -13,4 +13,3 @@ class ExtendedCollection(LabelledCollection): classes: Optional[List] = None, ): super().__init__(instances, labels, classes=classes) - diff --git a/quacc/estimator.py b/quacc/estimator.py index e0f8520..760133c 100644 --- a/quacc/estimator.py +++ b/quacc/estimator.py @@ -5,7 +5,6 @@ from quapy.method.base import BaseQuantifier from sklearn.base import BaseEstimator from sklearn.model_selection import cross_val_predict -import quacc as qc from .data import ExtendedCollection @@ -15,9 +14,11 @@ def _check_prevalence_classes(true_classes, estim_classes, estim_prev): estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0) return estim_prev + def _get_ex_class(classes, true_class, pred_class): return true_class * classes + pred_class + def _extend_instances(instances, pred_proba): if isinstance(instances, sp.csr_matrix): _pred_proba = sp.csr_matrix(pred_proba) @@ -29,6 +30,7 @@ def _extend_instances(instances, pred_proba): return n_x + def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection: n_classes = base.n_classes diff --git a/quacc/evaluation.py b/quacc/evaluation.py index 336093e..02dfc67 100644 --- a/quacc/evaluation.py +++ b/quacc/evaluation.py @@ -1,24 +1,25 @@ -from quapy.method.base import BaseQuantifier -from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol +from quapy.protocol import ( + OnLabelledCollectionProtocol, + AbstractStochasticSeededProtocol, +) -from .estimator import AccuracyEstimator, _extend_collection +from .estimator import AccuracyEstimator def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol): - # ensure that the protocol returns a LabelledCollection for each iteration - protocol.collator = OnLabelledCollectionProtocol.get_collator('labelled_collection') + protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") base_prevs, true_prevs, estim_prevs = [], [], [] for sample in protocol(): - e_sample = estimator.extend(sample) - estim_prev = estimator.estimate(e_sample.X, ext=True) - base_prevs.append(sample.prevalence()) - true_prevs.append(e_sample.prevalence()) - estim_prevs.append(estim_prev) + e_sample = estimator.extend(sample) + estim_prev = estimator.estimate(e_sample.X, ext=True) + base_prevs.append(sample.prevalence()) + true_prevs.append(e_sample.prevalence()) + estim_prevs.append(estim_prev) return base_prevs, true_prevs, estim_prevs - + def evaluate(): pass