estimator refactored, missing evaluation

This commit is contained in:
Lorenzo Volpi 2023-05-18 22:56:57 +02:00
parent 755fbad588
commit 6ac18137fa
3 changed files with 15 additions and 13 deletions

View File

@ -13,4 +13,3 @@ class ExtendedCollection(LabelledCollection):
classes: Optional[List] = None,
):
super().__init__(instances, labels, classes=classes)

View File

@ -5,7 +5,6 @@ from quapy.method.base import BaseQuantifier
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_predict
import quacc as qc
from .data import ExtendedCollection
@ -15,9 +14,11 @@ def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
return estim_prev
def _get_ex_class(classes, true_class, pred_class):
return true_class * classes + pred_class
def _extend_instances(instances, pred_proba):
if isinstance(instances, sp.csr_matrix):
_pred_proba = sp.csr_matrix(pred_proba)
@ -29,6 +30,7 @@ def _extend_instances(instances, pred_proba):
return n_x
def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
n_classes = base.n_classes

View File

@ -1,13 +1,14 @@
from quapy.method.base import BaseQuantifier
from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol
from quapy.protocol import (
OnLabelledCollectionProtocol,
AbstractStochasticSeededProtocol,
)
from .estimator import AccuracyEstimator, _extend_collection
from .estimator import AccuracyEstimator
def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
# ensure that the protocol returns a LabelledCollection for each iteration
protocol.collator = OnLabelledCollectionProtocol.get_collator('labelled_collection')
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
base_prevs, true_prevs, estim_prevs = [], [], []
for sample in protocol():