2021-04-29 16:07:39 +02:00
|
|
|
import numpy
|
|
|
|
import pytest
|
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
|
from sklearn.naive_bayes import MultinomialNB
|
|
|
|
from sklearn.svm import LinearSVC
|
|
|
|
|
|
|
|
import quapy as qp
|
2021-04-30 17:22:58 +02:00
|
|
|
from quapy.method import AGGREGATIVE_METHODS
|
2021-04-29 16:07:39 +02:00
|
|
|
|
2021-04-30 17:22:58 +02:00
|
|
|
# Datasets every test in this module is parametrized over.  The fetch calls
# run at import (collection) time; the ``id`` strings are the labels pytest
# shows in the generated test names.
datasets = [
    pytest.param(qp.datasets.fetch_twitter('hcr'), id='hcr'),
    pytest.param(qp.datasets.fetch_UCIDataset('ionosphere'), id='ionosphere'),
]
|
2021-04-29 16:07:39 +02:00
|
|
|
|
|
|
|
# Classifier classes (not instances): the test instantiates each one with
# its default constructor arguments before wrapping it in a quantifier.
learners = [
    LogisticRegression,
    MultinomialNB,
    LinearSVC,
]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('dataset', datasets)
@pytest.mark.parametrize('aggregative_method', AGGREGATIVE_METHODS)
@pytest.mark.parametrize('learner', learners)
def test_aggregative_methods(dataset, aggregative_method, learner):
    """Fit an aggregative quantifier and check the type of its MAE score.

    Runs once for every (dataset, aggregative method, learner) combination
    produced by the three ``parametrize`` decorators.

    Args:
        dataset: one entry of ``datasets``; the test reads its ``binary``,
            ``training`` and ``test`` attributes.
        aggregative_method: one entry of ``AGGREGATIVE_METHODS``; called with
            a freshly constructed learner to build the model.
        learner: one entry of ``learners``; instantiated with no arguments.
    """
    model = aggregative_method(learner())

    # NOTE(review): `model.binary` / `dataset.binary` presumably flag a
    # binary-only quantifier and a two-class dataset -- confirm in quapy.
    # Report the incompatible combination as skipped instead of returning
    # early, which pytest would otherwise count as a pass.
    if model.binary and not dataset.binary:
        pytest.skip('binary quantifier is not applicable to a non-binary dataset')

    model.fit(dataset.training)

    estim_prevalences = model.quantify(dataset.test.instances)
    true_prevalences = dataset.test.prevalence()
    error = qp.error.mae(true_prevalences, estim_prevalences)

    # isinstance is the idiomatic type check (rather than comparing type()).
    assert isinstance(error, numpy.float64)
|