estimator refactored, missing evaluation
This commit is contained in:
parent
755fbad588
commit
6ac18137fa
|
@ -13,4 +13,3 @@ class ExtendedCollection(LabelledCollection):
|
|||
classes: Optional[List] = None,
|
||||
):
|
||||
super().__init__(instances, labels, classes=classes)
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ from quapy.method.base import BaseQuantifier
|
|||
from sklearn.base import BaseEstimator
|
||||
from sklearn.model_selection import cross_val_predict
|
||||
|
||||
import quacc as qc
|
||||
from .data import ExtendedCollection
|
||||
|
||||
|
||||
|
@ -15,9 +14,11 @@ def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
|
|||
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
||||
return estim_prev
|
||||
|
||||
|
||||
def _get_ex_class(classes, true_class, pred_class):
|
||||
return true_class * classes + pred_class
|
||||
|
||||
|
||||
def _extend_instances(instances, pred_proba):
|
||||
if isinstance(instances, sp.csr_matrix):
|
||||
_pred_proba = sp.csr_matrix(pred_proba)
|
||||
|
@ -29,6 +30,7 @@ def _extend_instances(instances, pred_proba):
|
|||
|
||||
return n_x
|
||||
|
||||
|
||||
def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
|
||||
n_classes = base.n_classes
|
||||
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
from quapy.method.base import BaseQuantifier
|
||||
from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol
|
||||
from quapy.protocol import (
|
||||
OnLabelledCollectionProtocol,
|
||||
AbstractStochasticSeededProtocol,
|
||||
)
|
||||
|
||||
from .estimator import AccuracyEstimator, _extend_collection
|
||||
from .estimator import AccuracyEstimator
|
||||
|
||||
|
||||
def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
|
||||
|
||||
# ensure that the protocol returns a LabelledCollection for each iteration
|
||||
protocol.collator = OnLabelledCollectionProtocol.get_collator('labelled_collection')
|
||||
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
|
||||
|
||||
base_prevs, true_prevs, estim_prevs = [], [], []
|
||||
for sample in protocol():
|
||||
|
|
Loading…
Reference in New Issue