estimator refactored, missing evaluation
This commit is contained in:
parent
755fbad588
commit
6ac18137fa
|
@ -13,4 +13,3 @@ class ExtendedCollection(LabelledCollection):
|
||||||
classes: Optional[List] = None,
|
classes: Optional[List] = None,
|
||||||
):
|
):
|
||||||
super().__init__(instances, labels, classes=classes)
|
super().__init__(instances, labels, classes=classes)
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ from quapy.method.base import BaseQuantifier
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
from sklearn.model_selection import cross_val_predict
|
from sklearn.model_selection import cross_val_predict
|
||||||
|
|
||||||
import quacc as qc
|
|
||||||
from .data import ExtendedCollection
|
from .data import ExtendedCollection
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,9 +14,11 @@ def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
|
||||||
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
||||||
return estim_prev
|
return estim_prev
|
||||||
|
|
||||||
|
|
||||||
def _get_ex_class(classes, true_class, pred_class):
|
def _get_ex_class(classes, true_class, pred_class):
|
||||||
return true_class * classes + pred_class
|
return true_class * classes + pred_class
|
||||||
|
|
||||||
|
|
||||||
def _extend_instances(instances, pred_proba):
|
def _extend_instances(instances, pred_proba):
|
||||||
if isinstance(instances, sp.csr_matrix):
|
if isinstance(instances, sp.csr_matrix):
|
||||||
_pred_proba = sp.csr_matrix(pred_proba)
|
_pred_proba = sp.csr_matrix(pred_proba)
|
||||||
|
@ -29,6 +30,7 @@ def _extend_instances(instances, pred_proba):
|
||||||
|
|
||||||
return n_x
|
return n_x
|
||||||
|
|
||||||
|
|
||||||
def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
|
def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
|
||||||
n_classes = base.n_classes
|
n_classes = base.n_classes
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,25 @@
|
||||||
from quapy.method.base import BaseQuantifier
|
from quapy.protocol import (
|
||||||
from quapy.protocol import OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol
|
OnLabelledCollectionProtocol,
|
||||||
|
AbstractStochasticSeededProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
from .estimator import AccuracyEstimator, _extend_collection
|
from .estimator import AccuracyEstimator
|
||||||
|
|
||||||
|
|
||||||
def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
|
def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
|
||||||
|
|
||||||
# ensure that the protocol returns a LabelledCollection for each iteration
|
# ensure that the protocol returns a LabelledCollection for each iteration
|
||||||
protocol.collator = OnLabelledCollectionProtocol.get_collator('labelled_collection')
|
protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
|
||||||
|
|
||||||
base_prevs, true_prevs, estim_prevs = [], [], []
|
base_prevs, true_prevs, estim_prevs = [], [], []
|
||||||
for sample in protocol():
|
for sample in protocol():
|
||||||
e_sample = estimator.extend(sample)
|
e_sample = estimator.extend(sample)
|
||||||
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
estim_prev = estimator.estimate(e_sample.X, ext=True)
|
||||||
base_prevs.append(sample.prevalence())
|
base_prevs.append(sample.prevalence())
|
||||||
true_prevs.append(e_sample.prevalence())
|
true_prevs.append(e_sample.prevalence())
|
||||||
estim_prevs.append(estim_prev)
|
estim_prevs.append(estim_prev)
|
||||||
|
|
||||||
return base_prevs, true_prevs, estim_prevs
|
return base_prevs, true_prevs, estim_prevs
|
||||||
|
|
||||||
|
|
||||||
def evaluate():
|
def evaluate():
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue