2023-05-18 22:56:57 +02:00
|
|
|
from quapy.protocol import (
|
|
|
|
OnLabelledCollectionProtocol,
|
|
|
|
AbstractStochasticSeededProtocol,
|
|
|
|
)
|
2023-05-18 22:55:10 +02:00
|
|
|
|
2023-05-18 22:56:57 +02:00
|
|
|
from .estimator import AccuracyEstimator
|
2023-05-18 22:55:10 +02:00
|
|
|
|
|
|
|
|
|
|
|
def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
    """Apply ``estimator`` to every sample generated by ``protocol``.

    Note that ``protocol.collator`` is overwritten here and is not restored
    afterwards.

    :param estimator: object providing ``extend(sample)`` and
        ``estimate(X, ext=True)``.
    :param protocol: callable protocol; calling it yields the samples.
    :return: a tuple ``(base_prevs, true_prevs, estim_prevs)`` of three
        parallel lists with one entry per sample:

        * ``base_prevs`` -- ``prevalence()`` of the sample as generated;
        * ``true_prevs`` -- ``prevalence()`` of the object returned by
          ``estimator.extend(sample)``;
        * ``estim_prevs`` -- the value returned by
          ``estimator.estimate(extended.X, ext=True)``.
    """
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    # One (base, true, estimated) triple per sample; the estimate is computed
    # before either prevalence is read, as in the sequential formulation.
    per_sample = []
    for drawn in protocol():
        extended = estimator.extend(drawn)
        predicted = estimator.estimate(extended.X, ext=True)
        per_sample.append((drawn.prevalence(), extended.prevalence(), predicted))

    # Split the triples into three lists (kept as lists, also when empty).
    base_prevs = [triple[0] for triple in per_sample]
    true_prevs = [triple[1] for triple in per_sample]
    estim_prevs = [triple[2] for triple in per_sample]
    return base_prevs, true_prevs, estim_prevs
|
2023-05-18 22:56:57 +02:00
|
|
|
|
2023-05-18 22:55:10 +02:00
|
|
|
|
|
|
|
def evaluate():
    """Placeholder: takes no arguments, performs no work, returns ``None``."""
    return None
|